diff options
Diffstat (limited to 'etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py')
-rwxr-xr-x | etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py | 309 |
1 files changed, 208 insertions, 101 deletions
diff --git a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py index 4ad14e50f..f37bc3ff1 100755 --- a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py +++ b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py @@ -8,8 +8,8 @@ LAMBDA = u'\u03bb' OP_NAMES = {'*':'MUL', '+':'ADD', '>>':'SHL', '<<':'SHR', '|':'OR', '&':'AND'} -REGISTERS = tuple(['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] #, 'RSP'] # RSP is stack pointer? - + ['r%d' % i for i in range(8, 19)]) +REGISTERS = tuple(#['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] + #, 'RSP'] # RSP is stack pointer? + ['r%d' % i for i in range(13)]) REGISTER_COLORS = ['color="black"', 'color="white",fillcolor="black"', 'color="maroon"', 'color="green"', 'fillcolor="olive"', 'color="navy"', 'color="purple"', 'fillcolor="teal"', 'fillcolor="silver"', 'fillcolor="gray"', 'fillcolor="red"', 'fillcolor="lime"', 'fillcolor="yellow"', 'fillcolor="blue"', 'fillcolor="fuschia"', 'fillcolor="aqua"'] @@ -197,18 +197,6 @@ def to_graph(input_data): return graph -def print_dependencies(input_data, dependencies): - in_vars = get_input_var_names(input_data) - out_vars = get_output_var_names(input_data) - registers = assign_registers(input_data, dependencies) - body = ( - ''.join(' %s [label="%s (%s)",%s];\n' % (var, var, reg, COLOR_FOR_REGISTER[reg.split(':')[0]]) for var, reg in registers.items()) + - ''.join(' in -> %s ;\n' % var for var in in_vars) + - ''.join(' %s -> out ;\n' % var for var in out_vars) + - ''.join(''.join(' %s -> %s ;\n' % (out_var, in_var) for out_var in sorted(dependencies[in_var])) - for in_var in sorted(dependencies.keys())) - ) - return ('digraph G {\n' + body + '}\n') def adjust_bits(input_data, graph): for line in input_data['lines']: if line['type'] == 'uint128_t': @@ -237,7 +225,11 @@ def is_temp(node): return True return False +def is_allocated_to_reg(full_map, node): + return node['out'] in full_map.keys() and all(reg in REGISTERS for reg in full_map[node['out']].split(':')) + def deps_allocated(full_map, node): + if node['op'] == 'INPUT': return True if node['out'] not in full_map.keys(): return False return all(deps_allocated(full_map, dep) for dep in node['deps']) @@ -260,7 +252,7 @@ def allocate_node(existing, node, *args): if reg in all_temps: if reg not in free_temps: free_temps.append(reg) - else: + elif reg in REGISTERS: if reg not in free_list: print('freeing %s from %s' % (reg, var)) free_list.append(reg) @@ -269,6 +261,7 @@ def allocate_node(existing, node, *args): if node['out'] in full_map.keys(): for dep in node['deps']: if dep['out'] in freed or dep['out'] not in full_map.keys(): continue + if not is_allocated_to_reg(full_map, dep): continue if (all(deps_allocated(full_map, rdep) for rdep in dep['rev_deps']) or all(reg in all_temps for reg in full_map[dep['out']].split(':'))): do_free(dep['out']) @@ -277,18 +270,18 @@ def allocate_node(existing, node, *args): do_free_deps(node) return do_ret() #print('alloc: %s (of %d)' % (node['out'], len(free_list))) - if node['op'] in ('GET_HIGH', 'GET_LOW') and len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) <= 2 and all(n['op'] in ('GET_HIGH', 'GET_LOW') for n in node['deps'][0]['rev_deps']) and node['deps'][0]['out'] in full_map.keys(): + if node['op'] in ('GET_HIGH', 'GET_LOW') and len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) <= 2 and all(n['op'] in ('GET_HIGH', 'GET_LOW') for n in node['deps'][0]['rev_deps']) and is_allocated_to_reg(full_map, node['deps'][0]): reg_idx = {'GET_LOW':0, 'GET_HIGH':1}[node['op']] cur_map[node['out']] = full_map[node['deps'][0]['out']].split(':')[reg_idx] emit_vars.append(node) return do_ret() - if len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) == 1 and node['deps'][0]['out'] in full_map.keys() and node['type'] == node['deps'][0]['type']: + if len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) == 1 and is_allocated_to_reg(full_map, node['deps'][0]) and node['type'] == node['deps'][0]['type']: cur_map[node['out']] = full_map[node['deps'][0]['out']] emit_vars.append(node) return do_ret() if len(node['deps']) == 0 and node['op'] == 'INPUT': assert(node['type'] == 'uint64_t') - cur_map[node['out']] = free_list.pop() + cur_map[node['out']] = 'r' + node['out'] # free_list.pop() emit_vars.append(node) return do_ret() if is_temp(node): @@ -314,7 +307,8 @@ def allocate_node(existing, node, *args): if node['op'] == '*' and node['type'] == 'uint64_t' and len(node['deps']) == 1: dep = node['deps'][0] assert(dep['out'] in full_map.keys()) - if all(rdep is node or (rdep['out'] in full_map.keys() and full_map[rdep['out']] != full_map[dep['out']]) + if is_allocated_to_reg(full_map, dep) and \ + all(rdep is node or (is_allocated_to_reg(full_map, rdep) and full_map[rdep['out']] != full_map[dep['out']]) for rdep in dep['rev_deps']): cur_map[node['out']] = full_map[dep['out']] freed += [dep['out']] @@ -322,7 +316,7 @@ def allocate_node(existing, node, *args): cur_map[node['out']] = free_list.pop() emit_vars.append(node) return do_ret() - raw_input([node['out'], node['op'], node['type'], len(node['deps'])]) + raw_input([node['out'], node['op'], node['type'], [(dep['out'], full_map.get(dep['out'])) for dep in node['deps']]]) return do_ret() def allocate_deps(existing, node, *args): @@ -491,7 +485,9 @@ def fix_emit_vars(emit_vars): waiting.append(node) new_waiting = [] for wnode in waiting: - if all(dep['out'] in seen for dep in wnode['deps']): + if wnode['out'] in seen: + continue + elif all(dep['out'] in seen for dep in wnode['deps']): ret.append(wnode) seen.add(wnode['out']) else: @@ -502,7 +498,9 @@ def fix_emit_vars(emit_vars): print(list(sorted(node['out'] for node in waiting))) new_waiting = [] for wnode in waiting: - if all(dep['out'] in seen for dep in wnode['deps']): + if wnode['out'] in seen: + continue + elif all(dep['out'] in seen for dep in wnode['deps']): ret.append(wnode) seen.add(wnode['out']) else: @@ -510,6 +508,114 @@ def fix_emit_vars(emit_vars): waiting = new_waiting return tuple(ret) +def print_input(reg_out, mem_in): + #return '%s <- LOAD %s;\n' % (reg_out, mem_in) + #return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out) + return "" + +def print_load_specific_reg(reg, specific_reg='rdx'): + ret = '' + ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) + ret += '"mov %%[%s], %%%s\\t\\n"\n' % (reg, specific_reg) + return ret, (specific_reg,) +def print_unload_specific_reg(specific_reg='rdx'): + ret = '' + ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg) + return ret +def print_load(*regs): + TEMP_REG = ['arg%d' % d for d in reversed(range(15))] + ret, out_reg = '', [] + for reg in regs: + if reg in REGISTERS: + out_reg.append(reg) + continue + else: + cur_reg = TEMP_REG.pop() + ret += '"mov %%[%s], %%[%s]\\t\\n"\n' % (reg, cur_reg) + out_reg.append(cur_reg) + if len(out_reg) == 1: return ret, out_reg[0] + return ret, tuple(out_reg) + +def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src): + #return '%s:%s <- MULX %s, %s; // %s\n' % (reg_out_low, reg_out_high, rx1, rx2, src) + ret = '' + ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx') + ret3, actual_rx2 = print_load(rx2) + ret += ret2 + ret3 + ('"mulx %%[%s], %%[%s], %%[%s]\\t\\n" // %s\n' % (actual_rx2, reg_out_high, reg_out_low, src)) + ret += print_unload_specific_reg('rdx') + return ret + +def print_mov_bucket(reg_out, reg_in, bucket): + #return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket) + ret, reg_in = print_load(reg_in) + return ret + ('"mov %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_in, reg_out, bucket)) + +def print_mov(reg_out, reg_in): + #return '%s <- MOV %s;\n' % (reg_out, reg_in) + ret, reg_in = print_load(reg_in) + return ret + ('"mov %%[%s], %%[%s]\\t\\n"\n' % (reg_in, reg_out)) + +LAST_CARRY = None + +def print_mul_by_constant(reg_out, reg_in, constant, src): + #return '%s <- MULX %s, %s; // %s\n' % (ret_out, reg_in, constant, src) + #assert(LAST_CARRY is None) + global LAST_CARRY + ret, reg_in = print_load(reg_in) + if constant == '0x13': + return ret + ('FIXME: lea for %s\n' % src) + else: + LAST_CARRY = None + return ret + ('"imul %%[%s], $%s, %%[%s]\\t\\n" // %s\n' % (reg_in, constant, reg_out, src)) + +def print_adx(reg_out, rx1, rx2, bucket): + #return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket) + assert(rx1 == reg_out) + ret, rx2 = print_load(rx2) + return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + +def print_add(reg_out, cf, rx1, rx2, bucket): + #return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket) + global LAST_CARRY + assert(reg_out == rx1) + #assert(LAST_CARRY is None or LAST_CARRY == cf) + LAST_CARRY = cf + ret, rx2 = print_load(rx2) + return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + +def print_adc(reg_out, cf, rx1, rx2, bucket): + #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf, cf, rx1, rx2, bucket) + assert(reg_out == rx1) + ret = '' + global LAST_CARRY + if LAST_CARRY != cf: + ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf) + LAST_CARRY = cf + ret2, rx2 = print_load(rx2) + ret += ret2 + return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket)) + +def print_adcx(reg_out, cf, bucket): + #return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket) + assert(LAST_CARRY == cf) + return ('"adcx $0, %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, bucket)) + +def print_and(reg_out, rx1, rx2, src): + #return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src) + global LAST_CARRY + LAST_CARRY = None + if reg_out != rx1: + return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src) + else: + if rx2[:2] == '0x': + return ('"and $%s, %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) + else: + ret, rx2 = print_load(rx2) + return ret + ('"and %%[%s], %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src)) + +#def print_shr(reg_out, rx1, imm, src): + #return '%s <- SHR %s, %s;\n' % + def schedule(input_data, existing, emit_vars): ret = '' buckets_seen = set() @@ -518,33 +624,35 @@ def schedule(input_data, existing, emit_vars): ret += ('// Convention is low_reg:high_reg\n') for node in emit_vars: if node['op'] == 'INPUT': - ret += ('%s <- LOAD %s;\n' % (existing[node['out']], node['out'])) + ret += print_input(existing[node['out']], node['out']) elif node['op'] == '*' and len(node['deps']) == 2: - ret += ('%s <- MULX %s, %s; // %s = %s * %s\n' - % (existing[node['out']], - existing[node['deps'][0]['out']], - existing[node['deps'][1]['out']], - node['out'], - node['deps'][0]['out'], - node['deps'][1]['out'])) + assert(len(existing[node['out']].split(':')) == 2) + out_low, out_high = existing[node['out']].split(':') + ret += print_mulx(out_low, out_high, + existing[node['deps'][0]['out']], + existing[node['deps'][1]['out']], + '%s = %s * %s' + % (node['out'], + node['deps'][0]['out'], + node['deps'][1]['out'])) elif node['op'] == '*' and len(node['deps']) == 1: extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0] - ret += ('%s <- MULX %s, %s; // %s = %s * %s\n' - % (existing[node['out']], - existing[node['deps'][0]['out']], - extra_arg, - node['out'], - node['deps'][0]['out'], - extra_arg)) + ret += print_mul_by_constant(existing[node['out']], + existing[node['deps'][0]['out']], + extra_arg, + '%s = %s * %s' + % (node['out'], + node['deps'][0]['out'], + extra_arg)) elif node['op'] == '&' and len(node['deps']) == 1: extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0] - ret += ('%s <- AND %s, %s; // %s = %s & %s\n' - % (existing[node['out']], - existing[node['deps'][0]['out']], - extra_arg, - node['out'], - node['deps'][0]['out'], - extra_arg)) + ret += print_and(existing[node['out']], + existing[node['deps'][0]['out']], + extra_arg, + '%s = %s & %s' + % (node['out'], + node['deps'][0]['out'], + extra_arg)) elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['op'] == 'COMBINE': extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0] ret += ('%s <- SHR %s:%s, %s; // %s = %s:%s >> %s\n' @@ -567,52 +675,45 @@ def schedule(input_data, existing, emit_vars): extra_arg)) elif node['op'] in ('GET_HIGH', 'GET_LOW'): if node['rev_deps'][0]['out'] not in buckets_seen: - ret += ('%s <- MOV %s; // bucket: %s\n' - % (existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))) + ret += print_mov_bucket(existing[node['rev_deps'][0]['out']], + existing[node['out']], + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) buckets_seen.add(node['rev_deps'][0]['out']) elif node['op'] == 'GET_HIGH': - ret += ('%s <- ADX %s, %s; // bucket: %s\n' - % (existing[node['rev_deps'][0]['out']], - existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))) + ret += print_adx(existing[node['rev_deps'][0]['out']], + existing[node['rev_deps'][0]['out']], + existing[node['out']], + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) elif node['op'] == 'GET_LOW': carry = 'c' + node['rev_deps'][0]['out'][:-len('_low')] if node['rev_deps'][0]['out'] not in buckets_carried: - ret += ('%s, (%s) <- ADD %s, %s; // bucket: %s\n' - % (existing[node['rev_deps'][0]['out']], - carry, - existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))) + ret += print_add(existing[node['rev_deps'][0]['out']], + carry, + existing[node['rev_deps'][0]['out']], + existing[node['out']], + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) buckets_carried.add(node['rev_deps'][0]['out']) else: - ret += ('%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' - % (existing[node['rev_deps'][0]['out']], - carry, - carry, - existing[node['rev_deps'][0]['out']], - existing[node['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))) + ret += print_adc(existing[node['rev_deps'][0]['out']], + carry, + existing[node['rev_deps'][0]['out']], + existing[node['out']], + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) elif node['op'] in ('GET_CARRY',): carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')] - ret += ('%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' - % (existing[node['rev_deps'][0]['out']], - carry, - existing[node['rev_deps'][0]['out']], - ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))) + ret += print_adcx(existing[node['rev_deps'][0]['out']], + carry, + ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out'])))) elif node['op'] == '+' and len(node['extra_out']) > 0: pass elif node['op'] == '+' and len(node['deps']) == 2 and node['type'] == 'uint64_t': - ret += ('%s <- ADX %s, %s; // %s = %s + %s\n' - % (existing[node['out']], - existing[node['deps'][0]['out']], - existing[node['deps'][1]['out']], - node['out'], - node['deps'][0]['out'], - node['deps'][1]['out'])) + ret += print_adx(existing[node['out']], + existing[node['deps'][0]['out']], + existing[node['deps'][1]['out']], + '%s = %s + %s' + % (node['out'], + node['deps'][0]['out'], + node['deps'][1]['out'])) elif node['op'] in ('COMBINE',): pass else: @@ -621,35 +722,30 @@ def schedule(input_data, existing, emit_vars): for rdep in node['rev_deps']: if len(rdep['extra_out']) > 0 and rdep['op'] == '+': if rdep['out'] not in buckets_seen: - ret += ('%s <- MOV %s; // bucket: %s\n' - % (existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))) + ret += print_mov_bucket(existing[rdep['out']], + existing[node['out']], + ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) buckets_seen.add(rdep['out']) elif 'high' in rdep['out']: - ret += ('%s <- ADX %s, %s; // bucket: %s\n' - % (existing[rdep['out']], - existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))) + ret += print_adx(existing[rdep['out']], + existing[rdep['out']], + existing[node['out']], + ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) elif 'low' in rdep['out']: carry = 'c' + rdep['out'][:-len('_low')] if rdep['out'] not in buckets_carried: - ret += ('%s, (%s) <- ADD %s, %s; // bucket: %s\n' - % (existing[rdep['out']], - carry, - existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))) + ret += print_add(existing[rdep['out']], + carry, + existing[rdep['out']], + existing[node['out']], + ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) buckets_carried.add(rdep['out']) else: - ret += ('%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' - % (existing[rdep['out']], - carry, - carry, - existing[rdep['out']], - existing[node['out']], - ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))) + ret += print_adc(existing[rdep['out']], + carry, + existing[rdep['out']], + existing[node['out']], + ' + '.join(sorted([rdep['out']] + list(rdep['extra_out'])))) else: assert(False) return ret @@ -680,7 +776,18 @@ for i, data in enumerate(data_list): if 'tmp' not in v: ret += list(vars_for(v, rec=False)) return tuple(ret) - for var in list(vars_for('x10')) + list(vars_for('x11')) + list(vars_for('x9')) + list(vars_for('x7')) + list(vars_for('x5')): # tuple(): #('x20_tmp', 'x49_tmp', 'x51_tmp', 'x55_tmp', 'x53_tmp'): + def vars_for_bucket(var): + if '_' not in var: + return tuple(list(vars_for_bucket(var + '_low')) + list(vars_for_bucket(var + '_high'))) + ret = [] + for dep in objs[var]['deps']: + if dep['op'] in ('GET_HIGH', 'GET_LOW'): + assert(len(dep['deps']) == 1) + assert('tmp' in dep['deps'][0]['out']) + ret.append(dep['deps'][0]['out']) + return tuple(ret) +# for var in list(vars_for('x10')) + list(vars_for('x11')) + list(vars_for('x9')) + list(vars_for('x7')) + list(vars_for('x5')): # tuple(): #('x20_tmp', 'x49_tmp', 'x51_tmp', 'x55_tmp', 'x53_tmp'): + for var in list(vars_for_bucket('x56')) + list(vars_for_bucket('x71')) + list(vars_for_bucket('x74')) + list(vars_for_bucket('x77')) + list(vars_for_bucket('x80')): # + list(vars_for('x11')) + list(vars_for('x9')) + list(vars_for('x7')) + list(vars_for('x5')): # tuple(): #('x20_tmp', 'x49_tmp', 'x51_tmp', 'x55_tmp', 'x53_tmp'): print(var) cur_possible_nodes = [n for n in possible_nodes if n['out'] == var] cur_possible_nodes, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars \ |