From 2bbbfed14c2d45fe5a1be6e079408b7be7c33587 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 12 Sep 2017 20:06:44 -0400 Subject: Be better about asm syntax dialects With some help from stackoverflow, https://stackoverflow.com/questions/46186592/how-do-i-refer-to-literal-registers-in-gcc-inline-assembly-in-att-syntax and https://stackoverflow.com/questions/46185788/how-can-i-pass-an-immediate-value-to-shr-in-assembly-in-intel-syntax --- register-allocate.py | 168 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 101 insertions(+), 67 deletions(-) (limited to 'register-allocate.py') diff --git a/register-allocate.py b/register-allocate.py index 16faee062..e05e346cd 100755 --- a/register-allocate.py +++ b/register-allocate.py @@ -5,11 +5,13 @@ import codecs, re, sys, os LAMBDA = u'\u03bb' NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RBP', 'RSI', 'RDI') +NUMBERED_REGISTERS = tuple('r%d' % i for i in range(16)) RESERVED_REGISTERS = ('RSP', ) TO_BE_RESTORED_REGISTERS = ('RBP', ) NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS)) -REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + ['r%d' % i for i in range(8, 16)]) +REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + list(NUMBERED_REGISTERS)) REGISTERS = ['reg%d' % i for i in range(13)] +DEFAULT_DIALECT = 'att' def get_lines(filename): with codecs.open(filename, 'r', encoding='utf8') as f: @@ -516,18 +518,33 @@ def fix_emit_vars(emit_vars): ret = [] waiting = [] seen = set() + get_high_waiting = None for node in emit_vars: waiting.append(node) + early_new_waiting = [] new_waiting = [] for wnode in waiting: if wnode['out'] in seen: continue + elif wnode['op'] == 'GET_HIGH' and wnode['deps'][0]['out'] == get_high_waiting: + ret.append(wnode) + seen.add(wnode['out']) + get_high_waiting = None + elif wnode['op'] == 'GET_HIGH' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+': + new_waiting.append(wnode) + elif get_high_waiting is None and wnode['op'] == 'GET_LOW' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+': + ret.append(wnode) + seen.add(wnode['out']) + assert(len(wnode['deps']) == 1) + get_high_waiting = wnode['deps'][0]['out'] + elif get_high_waiting is not None: + new_waiting.append(wnode) elif all(dep['out'] in seen for dep in wnode['deps']): ret.append(wnode) seen.add(wnode['out']) else: new_waiting.append(wnode) - waiting = new_waiting + waiting = early_new_waiting + new_waiting while len(waiting) > 0: # print('Waiting on...') # print(list(sorted(node['out'] for node in waiting))) @@ -548,36 +565,54 @@ def print_input(reg_out, mem_in): #return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out) return "" -def print_val(reg): - if reg.upper() in NAMED_REGISTERS: - return '%%%s' % reg +def print_val(reg, dialect=DEFAULT_DIALECT, numbered_registers=False, final_pass=False): + assert(dialect in ('intel', 'att')) + if reg.upper() in NAMED_REGISTERS or (numbered_registers and reg.lower() in NUMBERED_REGISTERS): + if dialect == 'intel': + if final_pass: + return reg + else: + return '%%%s' % reg + elif dialect == 'att': + return '%%%%%s' % reg if reg[:2] == '0x': - return '$%s' % reg + if dialect == 'intel': + return '%s' % reg + elif dialect == 'att': + return '$%s' % reg return '%%[%s]' % reg -def print_mov_no_adjust(reg_out, reg_in, comment=None): - #return '%s <- MOV %s;\n' % (reg_out, reg_in) - #ret, reg_in = print_load(reg_in) - ret = '"mov %s, %s\\t\\n"' % (reg_out, reg_in) +# args should be (outputs, inputs), as in intel syntax, regardless of what dialect says +def print_instr(instr, args, comment=None, dialect=DEFAULT_DIALECT, do_print_val=True): + if do_print_val: + args = tuple(print_val(arg, dialect=dialect) for arg in args) + if dialect == 'att': + args = tuple(reversed(args)) + ret ='"%s %s\\t\\n"' % (instr, ', '.join(args)) if comment is not None: ret += ' // %s' % comment ret += '\n' return ret +def print_mov_no_adjust(reg_out, reg_in, comment=None, do_print_val=False): + #return '%s <- MOV %s;\n' % (reg_out, reg_in) + #ret, reg_in = print_load(reg_in) + return print_instr('mov', (reg_out, reg_in), comment=comment, do_print_val=do_print_val) + def print_mov(reg_out, reg_in): #return '%s <- MOV %s;\n' % (reg_out, reg_in) #ret, reg_in = print_load(reg_in) - return print_mov_no_adjust(print_val(reg_out), print_val(reg_in)) + return print_mov_no_adjust(reg_out, reg_in, do_print_val=True) def print_load_constant(reg_out, imm): assert(imm[:2] == '0x') - return print_mov_no_adjust(print_val(reg_out), print_val(imm)) + return print_mov_no_adjust(reg_out, imm, do_print_val=True) def print_load_specific_reg(reg, specific_reg='rdx'): ret = '' #ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) if reg != specific_reg: - ret += print_mov_no_adjust(print_val(specific_reg), print_val(reg)) + ret += print_mov_no_adjust(specific_reg, reg, do_print_val=True) return ret, specific_reg def print_unload_specific_reg(specific_reg='rdx'): ret = '' @@ -602,7 +637,7 @@ def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src): ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx') assert(rx2 != actual_rx1) ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1]) - ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(reg_out_high), print_val(reg_out_low), print_val(actual_rx2), src)) + ret += ret2 + ret3 + print_instr('mulx', (reg_out_high, reg_out_low, actual_rx2), comment=src) ret += print_unload_specific_reg('rdx') return ret @@ -627,7 +662,16 @@ def print_adx(reg_out, rx1, rx2, bucket): #return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket) assert(rx1 == reg_out) ret, rx2 = print_load(rx2, dont_clobber=[rx1]) - return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) + return ret + print_instr('adx', (reg_out, rx2), 'bucket: ' + bucket) + +def print_adc(reg_out, carry_out, carry_in, rx1, rx2, bucket): + #return '%s <- ADCX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket) + global LAST_CARRY + assert(LAST_CARRY == carry_in) + LAST_CARRY = carry_out + assert(rx1 == reg_out) + ret, rx2 = print_load(rx2, dont_clobber=[rx1]) + return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket) def print_add(reg_out, cf, rx1, rx2, bucket): #return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket) @@ -636,24 +680,24 @@ def print_add(reg_out, cf, rx1, rx2, bucket): #assert(LAST_CARRY is None or LAST_CARRY == cf) LAST_CARRY = cf ret, rx2 = print_load(rx2, dont_clobber=[rx1]) - return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) + return ret + print_instr('add', (reg_out, rx2), 'bucket: ' + bucket) -def print_adc(reg_out, cf, rx1, rx2, bucket): - #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf, cf, rx1, rx2, bucket) +def print_adc(reg_out, cf_out, cf_in, rx1, rx2, bucket): + #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf_out, cf_in, rx1, rx2, bucket) assert(reg_out == rx1) ret = '' global LAST_CARRY - if LAST_CARRY != cf: - ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf) - LAST_CARRY = cf + if LAST_CARRY != cf_in: + ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf_in) + LAST_CARRY = cf_out ret2, rx2 = print_load(rx2, dont_clobber=[rx1]) ret += ret2 - return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket)) + return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket) def print_adcx(reg_out, cf, bucket): #return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket) assert(LAST_CARRY == cf) - return ('"adcx %%[%s], $0\\t\\n" // bucket: %s\n' % (reg_out, bucket)) + return print_instr('adcx', (reg_out, '0x0'), 'bucket: ' + bucket) def print_and(reg_out, rx1, rx2, src): #return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src) @@ -662,10 +706,8 @@ def print_and(reg_out, rx1, rx2, src): if reg_out != rx1: return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src) else: - ret = '' - if rx2[:2] != '0x': - ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1]) - return ret + ('"and %s, %s\\t\\n" // %s\n' % (print_val(reg_out), print_val(rx2), src)) + ret, rx2 = print_load(rx2, can_clobber=[reg_out, 'rdx'], dont_clobber=[rx1]) + return ret + print_instr('and', (reg_out, rx2), src) def print_shr(reg_out, rx1, imm, src): @@ -674,7 +716,7 @@ def print_shr(reg_out, rx1, imm, src): LAST_CARRY = None assert(rx1 == reg_out) assert(imm[:2] == '0x') - return ('"shr %%[%s], $%s\\t\\n" // %s\n' % (reg_out, imm, src)) + return print_instr('shr', (reg_out, imm), src) def print_shrd(reg_out, rx_low, rx_high, imm, src): #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src) @@ -687,13 +729,12 @@ def print_shrd(reg_out, rx_low, rx_high, imm, src): print_shrd(reg_out, rx_high, rx_low, imm, src) assert(rx_low == reg_out) assert(imm[:2] == '0x') - return ('"shrd %%[%s], %%[%s], $%s\\t\\n" // %s\n' % (rx_low, rx_high, imm, src)) + return print_instr('shrd', (rx_low, rx_high, imm), src) def schedule(input_data, existing, emit_vars): ret = '' buckets_seen = set() - buckets_carried = set() emit_vars = fix_emit_vars(emit_vars) ret += ('// Convention is low_reg:high_reg\n') for node in emit_vars: @@ -754,34 +795,31 @@ def schedule(input_data, existing, emit_vars): ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) buckets_seen.add(node['rev_deps'][0]['out']) elif node['op'] == 'GET_HIGH': - ret += print_adx(existing[node['rev_deps'][0]['out']], + carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')] + ret += print_adc(existing[node['rev_deps'][0]['out']], + None, + carry, existing[node['rev_deps'][0]['out']], existing[node['out']], ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) elif node['op'] == 'GET_LOW': carry = 'c' + node['rev_deps'][0]['out'][:-len('_low')] - if node['rev_deps'][0]['out'] not in buckets_carried: - ret += print_add(existing[node['rev_deps'][0]['out']], - carry, - existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) - buckets_carried.add(node['rev_deps'][0]['out']) - else: - ret += print_adc(existing[node['rev_deps'][0]['out']], - carry, - existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) + ret += print_add(existing[node['rev_deps'][0]['out']], + carry, + existing[node['rev_deps'][0]['out']], + existing[node['out']], + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) elif node['op'] in ('GET_CARRY',): - carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')] - ret += print_adcx(existing[node['rev_deps'][0]['out']], - carry, - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) + #carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')] + #ret += print_adc(existing[node['rev_deps'][0]['out']], + # carry, + # ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) + pass elif node['op'] == '+' and len(node['extra_out']) > 0: pass elif node['op'] == '+' and len(node['deps']) == 2 and node['type'] == 'uint64_t': - ret += print_adx(existing[node['out']], + ret += print_add(existing[node['out']], + None, existing[node['deps'][0]['out']], existing[node['deps'][1]['out']], '%s = %s + %s' @@ -801,25 +839,20 @@ def schedule(input_data, existing, emit_vars): ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) buckets_seen.add(rdep['out']) elif 'high' in rdep['out']: - ret += print_adx(existing[rdep['out']], + carry = 'c' + rdep['out'][:-len('_high')] + ret += print_adc(existing[rdep['out']], + None, + carry, existing[rdep['out']], existing[node['out']], ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) elif 'low' in rdep['out']: carry = 'c' + rdep['out'][:-len('_low')] - if rdep['out'] not in buckets_carried: - ret += print_add(existing[rdep['out']], - carry, - existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) - buckets_carried.add(rdep['out']) - else: - ret += print_adc(existing[rdep['out']], - carry, - existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) + ret += print_add(existing[rdep['out']], + carry, + existing[rdep['out']], + existing[node['out']], + ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) else: assert(False) return ret @@ -845,19 +878,20 @@ def inline_schedule(sched, input_vars, output_vars): [reg for reg in available_registers[count:] if reg.upper() not in TO_BE_RESTORED_REGISTERS] renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):])) for from_reg, to_reg in renaming.items(): - sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg) + sched = sched.replace('%%[%s]' % from_reg, print_val(to_reg, numbered_registers=True)) transient_regs = [renaming[reg] for reg in transient_regs] for reg in REAL_REGISTERS: - sched = sched.replace('%' + reg.lower(), reg.lower()) + sched = sched.replace(print_val(reg.lower(), numbered_registers=True), + print_val(reg.lower(), numbered_registers=True, final_pass=True)) ret = '' ret += 'uint64_t %s;\n' % ', '.join(output_vars[reg] for reg in output_regs) ret += 'uint64_t %s;\n\n' % ', '.join(reg.lower() for reg in TO_BE_RESTORED_REGISTERS) ret += 'asm (\n' for reg in map(str.lower, TO_BE_RESTORED_REGISTERS): - ret += print_mov_no_adjust('%%[%s]' % reg, reg) + ret += print_mov_no_adjust('%%[%s]' % reg, print_val(reg, numbered_registers=True, final_pass=True)) ret += sched for reg in map(str.lower, TO_BE_RESTORED_REGISTERS): - ret += print_mov_no_adjust(reg, '%%[%s]' % reg) + ret += print_mov_no_adjust(print_val(reg, final_pass=True), '%%[%s]' % reg) ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n' ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars] + ['[%s] "m" (%s)' % (reg, reg) for reg in map(str.lower, TO_BE_RESTORED_REGISTERS)]) + '\n' -- cgit v1.2.3