From 98daac5a86136cf0f1018b292baf8a676f6fd579 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 12 Sep 2017 16:33:09 -0400 Subject: Update register allocation more Switch over to intel syntax, because I can't figure out how to name registers in AT&T / GAS. --- register-allocate.py | 84 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 28 deletions(-) (limited to 'register-allocate.py') diff --git a/register-allocate.py b/register-allocate.py index 69f7ef7e9..16faee062 100755 --- a/register-allocate.py +++ b/register-allocate.py @@ -4,8 +4,11 @@ import codecs, re, sys, os LAMBDA = u'\u03bb' -NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RSI', 'RDI') +NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RBP', 'RSI', 'RDI') +RESERVED_REGISTERS = ('RSP', ) +TO_BE_RESTORED_REGISTERS = ('RBP', ) NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS)) +REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + ['r%d' % i for i in range(8, 16)]) REGISTERS = ['reg%d' % i for i in range(13)] def get_lines(filename): @@ -552,11 +555,29 @@ def print_val(reg): return '$%s' % reg return '%%[%s]' % reg +def print_mov_no_adjust(reg_out, reg_in, comment=None): + #return '%s <- MOV %s;\n' % (reg_out, reg_in) + #ret, reg_in = print_load(reg_in) + ret = '"mov %s, %s\\t\\n"' % (reg_out, reg_in) + if comment is not None: + ret += ' // %s' % comment + ret += '\n' + return ret + +def print_mov(reg_out, reg_in): + #return '%s <- MOV %s;\n' % (reg_out, reg_in) + #ret, reg_in = print_load(reg_in) + return print_mov_no_adjust(print_val(reg_out), print_val(reg_in)) + +def print_load_constant(reg_out, imm): + assert(imm[:2] == '0x') + return print_mov_no_adjust(print_val(reg_out), print_val(imm)) + def print_load_specific_reg(reg, specific_reg='rdx'): ret = '' #ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) if reg != specific_reg: - ret += '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(specific_reg)) + ret += print_mov_no_adjust(print_val(specific_reg), print_val(reg)) return ret, specific_reg def print_unload_specific_reg(specific_reg='rdx'): ret = '' @@ -572,7 +593,7 @@ def print_load(reg, can_clobber=tuple(), dont_clobber=tuple()): return ('', reg) else: cur_reg = can_clobber.pop() - ret = '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(cur_reg)) + ret = print_mov_no_adjust(print_val(cur_reg), print_val(reg)) return (ret, cur_reg) def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src): @@ -581,23 +602,14 @@ def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src): ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx') assert(rx2 != actual_rx1) ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1]) - ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(actual_rx2), print_val(reg_out_high), print_val(reg_out_low), src)) + ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(reg_out_high), print_val(reg_out_low), print_val(actual_rx2), src)) ret += print_unload_specific_reg('rdx') return ret def print_mov_bucket(reg_out, reg_in, bucket): #return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket) #ret, reg_in = print_load(reg_in, can_clobber=[reg_out]) - return ('"mov %s, %s\\t\\n" // bucket: %s\n' % (print_val(reg_in), print_val(reg_out), bucket)) - -def print_mov(reg_out, reg_in): - #return '%s <- MOV %s;\n' % (reg_out, reg_in) - #ret, reg_in = print_load(reg_in) - return ('"mov %s, %s\\t\\n"\n' % (print_val(reg_in), print_val(reg_out))) - -def print_load_constant(reg_out, imm): - assert(imm[:2] == '0x') - return ('"mov $%s, %s\\t\\n"\n' % (imm, print_val(reg_out))) + return print_mov_no_adjust(print_val(reg_out), print_val(reg_in), 'bucket: ' + bucket) LAST_CARRY = None @@ -615,7 +627,7 @@ def print_adx(reg_out, rx1, rx2, bucket): #return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket) assert(rx1 == reg_out) ret, rx2 = print_load(rx2, dont_clobber=[rx1]) - return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) def print_add(reg_out, cf, rx1, rx2, bucket): #return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket) @@ -624,7 +636,7 @@ def print_add(reg_out, cf, rx1, rx2, bucket): #assert(LAST_CARRY is None or LAST_CARRY == cf) LAST_CARRY = cf ret, rx2 = print_load(rx2, dont_clobber=[rx1]) - return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) def print_adc(reg_out, cf, rx1, rx2, bucket): #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf, cf, rx1, rx2, bucket) @@ -636,12 +648,12 @@ def print_adc(reg_out, cf, rx1, rx2, bucket): LAST_CARRY = cf ret2, rx2 = print_load(rx2, dont_clobber=[rx1]) ret += ret2 - return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) def print_adcx(reg_out, cf, bucket): #return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket) assert(LAST_CARRY == cf) - return ('"adcx $0, %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, bucket)) + return ('"adcx %%[%s], $0\\t\\n" // bucket: %s\n' % (reg_out, bucket)) def print_and(reg_out, rx1, rx2, src): #return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src) @@ -650,11 +662,11 @@ def print_and(reg_out, rx1, rx2, src): if reg_out != rx1: return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src) else: - if rx2[:2] == '0x': - return ('"and $%s, %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) - else: + ret = '' + if rx2[:2] != '0x': ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1]) - return ret + ('"and %%[%s], %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) + return ret + ('"and %s, %s\\t\\n" // %s\n' % (print_val(reg_out), print_val(rx2), src)) + def print_shr(reg_out, rx1, imm, src): #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src) @@ -662,7 +674,7 @@ def print_shr(reg_out, rx1, imm, src): LAST_CARRY = None assert(rx1 == reg_out) assert(imm[:2] == '0x') - return ('"shr $%s, %%[%s]\\t\\n" // %s\n' % (imm, reg_out, src)) + return ('"shr %%[%s], $%s\\t\\n" // %s\n' % (reg_out, imm, src)) def print_shrd(reg_out, rx_low, rx_high, imm, src): #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src) @@ -675,7 +687,7 @@ def print_shrd(reg_out, rx_low, rx_high, imm, src): print_shrd(reg_out, rx_high, rx_low, imm, src) assert(rx_low == reg_out) assert(imm[:2] == '0x') - return ('"shrd $%s, %%[%s], %%[%s]\\t\\n" // %s\n' % (imm, rx_low, rx_high, src)) + return ('"shrd %%[%s], %%[%s], $%s\\t\\n" // %s\n' % (rx_low, rx_high, imm, src)) def schedule(input_data, existing, emit_vars): @@ -820,19 +832,35 @@ def inline_schedule(sched, input_vars, output_vars): mems, variables = [i for i in variables if i[:2] == 'mx'], [i for i in variables if i[:2] != 'mx'] special_reg, variables = [i for i in variables if i.upper() in NAMED_REGISTERS], [i for i in variables if i.upper() not in NAMED_REGISTERS] transient_regs, output_regs = [i for i in variables if i not in output_vars.values()], [i for i in variables if i in output_vars.keys()] - available_registers = ['r%d' % i for i in range(16) - if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg] + available_registers = [NAMED_REGISTER_MAPPING.get('r%d' % i, 'r%d' % i).lower() for i in range(16) + if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or (NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg + and NAMED_REGISTER_MAPPING['r%d' % i] not in RESERVED_REGISTERS)] + assert(len(available_registers) >= len(transient_regs)) for reg in output_regs: sched = sched.replace('%%[%s]' % reg, '%%[r%s]' % output_vars[reg]) + available_registers = available_registers[-len(transient_regs):] + assert(len(available_registers) > len(TO_BE_RESTORED_REGISTERS)) # makes the replacement of low registers with ones we have to handle specially easier + count = len([reg for reg in TO_BE_RESTORED_REGISTERS if reg.lower() not in available_registers]) + available_registers = [reg.lower() for reg in TO_BE_RESTORED_REGISTERS] + \ + [reg for reg in available_registers[count:] if reg.upper() not in TO_BE_RESTORED_REGISTERS] renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):])) for from_reg, to_reg in renaming.items(): sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg) transient_regs = [renaming[reg] for reg in transient_regs] + for reg in REAL_REGISTERS: + sched = sched.replace('%' + reg.lower(), reg.lower()) ret = '' - ret += '__asm__ (\n' + ret += 'uint64_t %s;\n' % ', '.join(output_vars[reg] for reg in output_regs) + ret += 'uint64_t %s;\n\n' % ', '.join(reg.lower() for reg in TO_BE_RESTORED_REGISTERS) + ret += 'asm (\n' + for reg in map(str.lower, TO_BE_RESTORED_REGISTERS): + ret += print_mov_no_adjust('%%[%s]' % reg, reg) ret += sched + for reg in map(str.lower, TO_BE_RESTORED_REGISTERS): + ret += print_mov_no_adjust(reg, '%%[%s]' % reg) ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n' - ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n' + ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars] + + ['[%s] "m" (%s)' % (reg, reg) for reg in map(str.lower, TO_BE_RESTORED_REGISTERS)]) + '\n' ret += ': ' + ', '.join(['"cc"'] + ['"%s"' % reg for reg in special_reg] + ['"%s"' % reg for reg in transient_regs]) + '\n' -- cgit v1.2.3