From bd4714ebbb3552b2d85222f37fc6274052e3176e Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 12 Sep 2017 13:15:31 -0400 Subject: Fix assembly --- .../make-graph-with-reg-by-ac-buckets.py | 180 ++++++++++++++------- 1 file changed, 123 insertions(+), 57 deletions(-) (limited to 'etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py') diff --git a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py index f37bc3ff1..1083846d8 100755 --- a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py +++ b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py @@ -6,10 +6,12 @@ import subprocess LAMBDA = u'\u03bb' -OP_NAMES = {'*':'MUL', '+':'ADD', '>>':'SHL', '<<':'SHR', '|':'OR', '&':'AND'} - +NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RSI', 'RDI') +NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS)) REGISTERS = tuple(#['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] + #, 'RSP'] # RSP is stack pointer? - ['r%d' % i for i in range(13)]) + ['reg%d' % i for i in range(13)]) +#REAL_REGISTERS = tuple(['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] + #, 'RSP'] # RSP is stack pointer? +# ['reg%d' % i for i in range(13)]) REGISTER_COLORS = ['color="black"', 'color="white",fillcolor="black"', 'color="maroon"', 'color="green"', 'fillcolor="olive"', 'color="navy"', 'color="purple"', 'fillcolor="teal"', 'fillcolor="silver"', 'fillcolor="gray"', 'fillcolor="red"', 'fillcolor="lime"', 'fillcolor="yellow"', 'fillcolor="blue"', 'fillcolor="fuschia"', 'fillcolor="aqua"'] @@ -281,7 +283,7 @@ def allocate_node(existing, node, *args): return do_ret() if len(node['deps']) == 0 and node['op'] == 'INPUT': assert(node['type'] == 'uint64_t') - cur_map[node['out']] = 'r' + node['out'] # free_list.pop() + cur_map[node['out']] = 'm' + node['out'] # free_list.pop() emit_vars.append(node) return do_ret() if is_temp(node): @@ -513,65 +515,76 @@ def print_input(reg_out, mem_in): #return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out) return "" +def print_val(reg): + if reg.upper() in NAMED_REGISTERS: + return '%%%s' % reg + if reg[:2] == '0x': + return '$%s' % reg + return '%%[%s]' % reg + def print_load_specific_reg(reg, specific_reg='rdx'): ret = '' - ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) - ret += '"mov %%[%s], %%%s\\t\\n"\n' % (reg, specific_reg) - return ret, (specific_reg,) + #ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) + if reg != specific_reg: + ret += '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(specific_reg)) + return ret, specific_reg def print_unload_specific_reg(specific_reg='rdx'): ret = '' - ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) + #ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) return ret -def print_load(*regs): - TEMP_REG = ['arg%d' % d for d in reversed(range(15))] - ret, out_reg = '', [] - for reg in regs: - if reg in REGISTERS: - out_reg.append(reg) - continue - else: - cur_reg = TEMP_REG.pop() - ret += '"mov %%[%s], %%[%s]\\t\\n"\n' % (reg, cur_reg) - out_reg.append(cur_reg) - if len(out_reg) == 1: return ret, out_reg[0] - return ret, tuple(out_reg) +#def get_arg_reg(d): +# return 'arg%d' % d +def print_load(reg, can_clobber=tuple(), dont_clobber=tuple()): + assert(not isinstance(can_clobber, str)) + assert(not isinstance(dont_clobber, str)) + can_clobber = [i for i in reversed(can_clobber) if i not in dont_clobber] + if reg in REGISTERS: + return ('', reg) + else: + cur_reg = can_clobber.pop() + ret = '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(cur_reg)) + return (ret, cur_reg) def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src): #return '%s:%s <- MULX %s, %s; // %s\n' % (reg_out_low, reg_out_high, rx1, rx2, src) ret = '' ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx') - ret3, actual_rx2 = print_load(rx2) - ret += ret2 + ret3 + ('"mulx %%[%s], %%[%s], %%[%s]\\t\\n" // %s\n' % (actual_rx2, reg_out_high, reg_out_low, src)) + assert(rx2 != actual_rx1) + ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1]) + ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(actual_rx2), print_val(reg_out_high), print_val(reg_out_low), src)) ret += print_unload_specific_reg('rdx') return ret def print_mov_bucket(reg_out, reg_in, bucket): #return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket) - ret, reg_in = print_load(reg_in) - return ret + ('"mov %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_in, reg_out, bucket)) + #ret, reg_in = print_load(reg_in, can_clobber=[reg_out]) + return ('"mov %s, %s\\t\\n" // bucket: %s\n' % (print_val(reg_in), print_val(reg_out), bucket)) def print_mov(reg_out, reg_in): #return '%s <- MOV %s;\n' % (reg_out, reg_in) - ret, reg_in = print_load(reg_in) - return ret + ('"mov %%[%s], %%[%s]\\t\\n"\n' % (reg_in, reg_out)) + #ret, reg_in = print_load(reg_in) + return ('"mov %s, %s\\t\\n"\n' % (print_val(reg_in), print_val(reg_out))) + +def print_load_constant(reg_out, imm): + assert(imm[:2] == '0x') + return ('"mov $%s, %s\\t\\n"\n' % (imm, print_val(reg_out))) LAST_CARRY = None def print_mul_by_constant(reg_out, reg_in, constant, src): #return '%s <- MULX %s, %s; // %s\n' % (ret_out, reg_in, constant, src) - #assert(LAST_CARRY is None) - global LAST_CARRY - ret, reg_in = print_load(reg_in) + ret = '' if constant == '0x13': - return ret + ('FIXME: lea for %s\n' % src) - else: - LAST_CARRY = None - return ret + ('"imul %%[%s], $%s, %%[%s]\\t\\n" // %s\n' % (reg_in, constant, reg_out, src)) + ret += ('// FIXME: lea for %s\n' % src) + assert(constant[:2] == '0x') + return ret + \ + print_load_constant('rdx', constant) + \ + print_mulx(reg_out, 'rdx', 'rdx', reg_in, src) def print_adx(reg_out, rx1, rx2, bucket): #return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket) assert(rx1 == reg_out) - ret, rx2 = print_load(rx2) + ret, rx2 = print_load(rx2, dont_clobber=[rx1]) return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) def print_add(reg_out, cf, rx1, rx2, bucket): @@ -580,7 +593,7 @@ def print_add(reg_out, cf, rx1, rx2, bucket): assert(reg_out == rx1) #assert(LAST_CARRY is None or LAST_CARRY == cf) LAST_CARRY = cf - ret, rx2 = print_load(rx2) + ret, rx2 = print_load(rx2, dont_clobber=[rx1]) return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) def print_adc(reg_out, cf, rx1, rx2, bucket): @@ -591,7 +604,7 @@ def print_adc(reg_out, cf, rx1, rx2, bucket): if LAST_CARRY != cf: ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf) LAST_CARRY = cf - ret2, rx2 = print_load(rx2) + ret2, rx2 = print_load(rx2, dont_clobber=[rx1]) ret += ret2 return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) @@ -610,11 +623,30 @@ def print_and(reg_out, rx1, rx2, src): if rx2[:2] == '0x': return ('"and $%s, %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) else: - ret, rx2 = print_load(rx2) + ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1]) return ret + ('"and %%[%s], %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) -#def print_shr(reg_out, rx1, imm, src): - #return '%s <- SHR %s, %s;\n' % +def print_shr(reg_out, rx1, imm, src): + #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src) + global LAST_CARRY + LAST_CARRY = None + assert(rx1 == reg_out) + assert(imm[:2] == '0x') + return ('"shr $%s, %%[%s]\\t\\n" // %s\n' % (imm, reg_out, src)) + +def print_shrd(reg_out, rx_low, rx_high, imm, src): + #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src) + global LAST_CARRY + LAST_CARRY = None + if rx_low != reg_out and rx_high == reg_out: + return print_mov('rdx', rx_low) + \ + print_mov(rx_high, rx_low) + \ + print_mov(rx_low, 'rdx') + \ + print_shrd(reg_out, rx_high, rx_low, imm, src) + assert(rx_low == reg_out) + assert(imm[:2] == '0x') + return ('"shrd $%s, %%[%s], %%[%s]\\t\\n" // %s\n' % (imm, rx_low, rx_high, src)) + def schedule(input_data, existing, emit_vars): ret = '' @@ -655,24 +687,24 @@ def schedule(input_data, existing, emit_vars): extra_arg)) elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['op'] == 'COMBINE': extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0] - ret += ('%s <- SHR %s:%s, %s; // %s = %s:%s >> %s\n' - % (existing[node['out']], - existing[node['deps'][0]['deps'][0]['out']], - existing[node['deps'][0]['deps'][1]['out']], - extra_arg, - node['out'], - node['deps'][0]['deps'][0]['out'], - node['deps'][0]['deps'][1]['out'], - extra_arg)) + ret += print_shrd(existing[node['out']], + existing[node['deps'][0]['deps'][0]['out']], + existing[node['deps'][0]['deps'][1]['out']], + extra_arg, + '%s = %s:%s >> %s' + % (node['out'], + node['deps'][0]['deps'][0]['out'], + node['deps'][0]['deps'][1]['out'], + extra_arg)) elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['type'] == 'uint64_t': extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0] - ret += ('%s <- SHR %s, %s; // %s = %s >> %s\n' - % (existing[node['out']], - existing[node['deps'][0]['deps'][0]['out']], - extra_arg, - node['out'], - node['deps'][0]['deps'][0]['out'], - extra_arg)) + ret += print_shr(existing[node['out']], + existing[node['deps'][0]['deps'][0]['out']], + extra_arg, + '%s = %s >> %s' + % (node['out'], + node['deps'][0]['deps'][0]['out'], + extra_arg)) elif node['op'] in ('GET_HIGH', 'GET_LOW'): if node['rev_deps'][0]['out'] not in buckets_seen: ret += print_mov_bucket(existing[node['rev_deps'][0]['out']], @@ -750,6 +782,38 @@ def schedule(input_data, existing, emit_vars): assert(False) return ret +def inline_schedule(sched, input_vars, output_vars): + KNOWN_CONSTRAINTS = dict(('r%sx' % l, l) for l in 'abcd') + def int_or_zero_key(v): + orig = v + v = v.strip('abcdefghijklmnopqrstuvwxyz') + if v.isdigit(): return (int(v), orig) + return (0, orig) + variables = list(sorted(set(list(re.findall('%\[([a-zA-Z0-9_]*)\]', sched)) + + list(re.findall('%([a-zA-Z0-9_]+)', sched))), + key=int_or_zero_key)) + mems, variables = [i for i in variables if i[:2] == 'mx'], [i for i in variables if i[:2] != 'mx'] + special_reg, variables = [i for i in variables if i.upper() in NAMED_REGISTERS], [i for i in variables if i.upper() not in NAMED_REGISTERS] + transient_regs, output_regs = [i for i in variables if i not in output_vars.values()], [i for i in variables if i in output_vars.keys()] + available_registers = ['r%d' % i for i in range(16) + if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg] + for reg in output_regs: + sched = sched.replace('%%[%s]' % reg, '%%[r%s]' % output_vars[reg]) + renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):])) + for from_reg, to_reg in renaming.items(): + sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg) + transient_regs = [renaming[reg] for reg in transient_regs] + ret = '' + ret += 'asm (\n' + ret += sched + ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n' + ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n' + ret += ': ' + ', '.join(['"cc"'] + + ['"%s"' % reg for reg in special_reg] + + ['"%s"' % reg for reg in transient_regs]) + '\n' + ret += ');\n' + return ret + data_list = parse_lines(get_lines('femulDisplay.log')) for i, data in enumerate(data_list): graph = to_graph(data) @@ -807,7 +871,9 @@ for i, data in enumerate(data_list): #mul_node = possible_nodes[0] #print([n['out'] for n in mul_node['deps']]) #cur_map, free_temps, free_list, all_temps = allocate_subgraph(existing, mul_node, cur_map, free_temps, free_list, all_temps) - sched = schedule(data, existing, emit_vars) + sched = inline_schedule(schedule(data, existing, emit_vars), + dict((existing[n['out']], n['out']) for n in graph['in'].values()), + dict((existing[n['out']], n['out']) for n in graph['out'].values())) #fill_deps(buckets[0]) deps = adjust_bits(data, print_graph(graph, existing)) with codecs.open('femulData%d.dot' % i, 'w', encoding='utf8') as f: -- cgit v1.2.3