From 6de8d499c797c0f107d1bdc13ad785f7c5dbdf69 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 17 Aug 2017 00:24:05 -0400 Subject: More WIP on register allocation The current allocation is terrible, probably because we are currently requiring that all instructions output to registers. My current guess at a decent thing to do is to make a pass, after register allocation, and eliminate all registers that simply get stored to memory, replacing the relevant instructions with the memory-using versions. Then we can re-register allocate, ignoring values that go straight to memory. --- etc/compile-by-zinc/heuristic-search.py | 84 ++++++++++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 6 deletions(-) (limited to 'etc') diff --git a/etc/compile-by-zinc/heuristic-search.py b/etc/compile-by-zinc/heuristic-search.py index 57784183e..e2d0b43fd 100755 --- a/etc/compile-by-zinc/heuristic-search.py +++ b/etc/compile-by-zinc/heuristic-search.py @@ -498,6 +498,8 @@ def schedule(data, basepoint, do_print): next_next_var_dict = dict((register_ranges[reg][loc+2], reg) for reg in register_ranges.keys()) if var in var_dict.keys(): return [var_dict[var]] + elif var + '_low' in var_dict.keys() and var + '_high' in var_dict.keys(): + return ['%s:%s' % (var_dict[var + '_high'], var_dict[var + '_low'])] elif var + '_low' in var_dict.keys() and var + '_high' in next_var_dict.keys(): return ['%s:%s' % (next_var_dict[var + '_high'], var_dict[var + '_low'])] elif var + '_low' in next_var_dict.keys() and var + '_high' in next_next_var_dict.keys(): @@ -507,9 +509,17 @@ def schedule(data, basepoint, do_print): def update_source(line, loc, register_ranges): source = line['source'] - for var in sorted([line['out']] + list(line['args']), key=len): - for reg in lookup_var_in_reg(register_ranges, loc, var): - source = source.replace(var, reg) + if line['op'] in ('&', '|', '^') and line['type'] == 'uint64_t' and line['args'][1][:2] == '0x': + for reg in lookup_var_in_reg(register_ranges, loc, line['out']): + source = source.replace(line['out'], reg) + for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0] + '_low'): + source = source.replace(line['args'][0], reg) + for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0]): + source = source.replace(line['args'][0], reg) + else: + for var in sorted([line['out']] + list(line['args']), key=len): + for reg in lookup_var_in_reg(register_ranges, loc, var): + source = source.replace(var, reg) return source def get_next_registers_use(loc, register_ranges, live_ranges): @@ -554,8 +564,18 @@ def schedule(data, basepoint, do_print): register_ranges[reg][new_loc] = None return reg, register_ranges + def prune_stores(stores_available, movs_needed): + movs_used = set(arg for mov_type, reg, arg, loc in movs_needed) + for mov_type, arg, reg, loc in stores_available: + if arg in movs_used: + yield (mov_type, arg, reg, loc) + + def insert_movs(schedule_with_cycles, register_ranges, live_registers): + pass def linear_allocate(var_to_line, schedule_with_cycles, live_ranges, register_ranges): + movs_needed = [] + stores_available = [] for (var, locs, core, args) in schedule_with_cycles: registers = free_registers(locs['start'], register_ranges, live_ranges) line = var_to_line[var] @@ -571,10 +591,36 @@ def schedule(data, basepoint, do_print): reg, registers = registers[0], registers[1:] assert register_ranges[reg][locs['start'] + latency] is None register_ranges[reg][locs['start'] + latency] = arg + bits + movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency)) else: reg = found[0] + if reg == 'RDX': + if locs['start'] + latency == 0: + movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency)) + elif register_ranges['RDX'][locs['start'] + latency - 1] != arg + bits: + # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr + mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys() + if register_ranges['RDX'][locs['start'] + latency - 1] == arg + bits] + + [('MOVmr', arg + bits)])[0] + movs_needed.append((mov_type, from_val, reg, locs['start'] + latency)) if argi == 0: register_ranges[reg][locs['start'] + latency + 1] = line['out'] + bits + stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency + 1)) + elif line['type'] == 'uint64_t' and line['op'] == '>>' and line['args'][0] in var_to_line.keys() and var_to_line[line['args'][0]]['type'] == 'uint128_t' and line['args'][1][:2] == '0x': + for arg, latency in ((line['out'], core['latency']), (line['args'][0] + '_low', 0), (line['args'][0] + '_high', 0)): + found = list(lookup_var_in_reg(register_ranges, locs['start'], arg)) + if len(found) == 0: + if len(registers) == 0: + reg, register_ranges = spill_register(locs['start'] + latency, register_ranges, line['args']) + else: + reg, registers = registers[0], registers[1:] + assert register_ranges[reg][locs['start'] + latency] is None + register_ranges[reg][locs['start'] + latency] = arg + if arg != line['out']: + for c in range(latency): + movs_needed.append(('MOVmr', arg, reg, locs['start'] + c)) + else: + stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency)) else: if line['type'] == 'uint128_t': out_args = [line['out'] + '_high', line['out'] + '_low'] @@ -582,18 +628,44 @@ def schedule(data, basepoint, do_print): out_args = [line['out']] for arg in sorted(out_args + list(line['args']), key=len): if arg[:2] == '0x': continue - if len(list(lookup_var_in_reg(register_ranges, locs['start'], arg))) == 0: + if line['type'] == 'uint64_t' and arg in var_to_line.keys() and var_to_line[arg]['type'] == 'uint128_t': + if line['op'] in ('&', '|', '^'): arg = arg + '_low' + found = list(lookup_var_in_reg(register_ranges, locs['start'], arg)) + if len(found) == 0: if len(registers) == 0: reg, register_ranges = spill_register(locs['start'], register_ranges, out_args + list(line['args'])) else: reg, registers = registers[0], registers[1:] assert register_ranges[reg][locs['start']] is None - for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx - register_ranges[reg][locs['start'] + latency] = arg if arg in out_args: for latency in range(core['latency']+1): register_ranges[reg][locs['start'] + latency] = arg + stores_available.append(('MOVrm', reg, arg, locs['start'] + core['latency'])) + else: + for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx + register_ranges[reg][locs['start'] + latency] = arg + movs_needed.append(('MOVmr', arg, reg, locs['start'])) + elif arg in out_args: + assert False + else: + reg = found[0] + if reg == 'RDX': + if locs['start'] == 0: + movs_needed.append(('MOVmr', arg, reg, locs['start'])) + elif register_ranges['RDX'][locs['start'] - 1] != arg: + # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr + mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys() + if register_ranges['RDX'][locs['start'] - 1] == arg] + + [('MOVmr', arg)])[0] + movs_needed.append((mov_type, from_val, reg, locs['start'])) + print(var_to_line[var]['source'] + ' // ' + update_source(var_to_line[var], locs['start'], register_ranges)) + movs_needed = sorted(movs_needed, key=(lambda v: v[3])) + stores_needed = sorted(prune_stores(stores_available, movs_needed), key=(lambda v: v[3])) + print(len(movs_needed)) + print(len(stores_needed)) + for i in movs_needed: print(i) + for i in stores_needed: print(i) # sys.exit(0) # def insert_possible_registers(live_ranges): -- cgit v1.2.3