aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Jason Gross <jgross@mit.edu> 2017-08-17 00:24:05 -0400
committer: Jason Gross <jgross@mit.edu> 2017-08-17 00:24:05 -0400
commit: 6de8d499c797c0f107d1bdc13ad785f7c5dbdf69 (patch)
tree: 3d6348f133169f89f73eac8ab6d6d5e7f213c4d8
parent: 0ac04e0b1f28c5c9e2073335809adf4837c04cc5 (diff)
More WIP on register allocation
The current allocation is terrible, probably because we currently require every instruction to write its output to a register. My current guess at a decent thing to do is to make a pass, after register allocation, that eliminates every register whose value is simply stored to memory, replacing the relevant instructions with their memory-operand versions. Then we can re-run register allocation, ignoring values that go straight to memory.
-rwxr-xr-x etc/compile-by-zinc/heuristic-search.py 84
1 file changed, 78 insertions, 6 deletions
diff --git a/etc/compile-by-zinc/heuristic-search.py b/etc/compile-by-zinc/heuristic-search.py
index 57784183e..e2d0b43fd 100755
--- a/etc/compile-by-zinc/heuristic-search.py
+++ b/etc/compile-by-zinc/heuristic-search.py
@@ -498,6 +498,8 @@ def schedule(data, basepoint, do_print):
next_next_var_dict = dict((register_ranges[reg][loc+2], reg) for reg in register_ranges.keys())
if var in var_dict.keys():
return [var_dict[var]]
+ elif var + '_low' in var_dict.keys() and var + '_high' in var_dict.keys():
+ return ['%s:%s' % (var_dict[var + '_high'], var_dict[var + '_low'])]
elif var + '_low' in var_dict.keys() and var + '_high' in next_var_dict.keys():
return ['%s:%s' % (next_var_dict[var + '_high'], var_dict[var + '_low'])]
elif var + '_low' in next_var_dict.keys() and var + '_high' in next_next_var_dict.keys():
@@ -507,9 +509,17 @@ def schedule(data, basepoint, do_print):
def update_source(line, loc, register_ranges):
source = line['source']
- for var in sorted([line['out']] + list(line['args']), key=len):
- for reg in lookup_var_in_reg(register_ranges, loc, var):
- source = source.replace(var, reg)
+ if line['op'] in ('&', '|', '^') and line['type'] == 'uint64_t' and line['args'][1][:2] == '0x':
+ for reg in lookup_var_in_reg(register_ranges, loc, line['out']):
+ source = source.replace(line['out'], reg)
+ for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0] + '_low'):
+ source = source.replace(line['args'][0], reg)
+ for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0]):
+ source = source.replace(line['args'][0], reg)
+ else:
+ for var in sorted([line['out']] + list(line['args']), key=len):
+ for reg in lookup_var_in_reg(register_ranges, loc, var):
+ source = source.replace(var, reg)
return source
def get_next_registers_use(loc, register_ranges, live_ranges):
@@ -554,8 +564,18 @@ def schedule(data, basepoint, do_print):
register_ranges[reg][new_loc] = None
return reg, register_ranges
+ def prune_stores(stores_available, movs_needed):
+ movs_used = set(arg for mov_type, reg, arg, loc in movs_needed)
+ for mov_type, arg, reg, loc in stores_available:
+ if arg in movs_used:
+ yield (mov_type, arg, reg, loc)
+
+ def insert_movs(schedule_with_cycles, register_ranges, live_registers):
+ pass
def linear_allocate(var_to_line, schedule_with_cycles, live_ranges, register_ranges):
+ movs_needed = []
+ stores_available = []
for (var, locs, core, args) in schedule_with_cycles:
registers = free_registers(locs['start'], register_ranges, live_ranges)
line = var_to_line[var]
@@ -571,10 +591,36 @@ def schedule(data, basepoint, do_print):
reg, registers = registers[0], registers[1:]
assert register_ranges[reg][locs['start'] + latency] is None
register_ranges[reg][locs['start'] + latency] = arg + bits
+ movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency))
else:
reg = found[0]
+ if reg == 'RDX':
+ if locs['start'] + latency == 0:
+ movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency))
+ elif register_ranges['RDX'][locs['start'] + latency - 1] != arg + bits:
+ # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr
+ mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys()
+ if register_ranges['RDX'][locs['start'] + latency - 1] == arg + bits]
+ + [('MOVmr', arg + bits)])[0]
+ movs_needed.append((mov_type, from_val, reg, locs['start'] + latency))
if argi == 0:
register_ranges[reg][locs['start'] + latency + 1] = line['out'] + bits
+ stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency + 1))
+ elif line['type'] == 'uint64_t' and line['op'] == '>>' and line['args'][0] in var_to_line.keys() and var_to_line[line['args'][0]]['type'] == 'uint128_t' and line['args'][1][:2] == '0x':
+ for arg, latency in ((line['out'], core['latency']), (line['args'][0] + '_low', 0), (line['args'][0] + '_high', 0)):
+ found = list(lookup_var_in_reg(register_ranges, locs['start'], arg))
+ if len(found) == 0:
+ if len(registers) == 0:
+ reg, register_ranges = spill_register(locs['start'] + latency, register_ranges, line['args'])
+ else:
+ reg, registers = registers[0], registers[1:]
+ assert register_ranges[reg][locs['start'] + latency] is None
+ register_ranges[reg][locs['start'] + latency] = arg
+ if arg != line['out']:
+ for c in range(latency):
+ movs_needed.append(('MOVmr', arg, reg, locs['start'] + c))
+ else:
+ stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency))
else:
if line['type'] == 'uint128_t':
out_args = [line['out'] + '_high', line['out'] + '_low']
@@ -582,18 +628,44 @@ def schedule(data, basepoint, do_print):
out_args = [line['out']]
for arg in sorted(out_args + list(line['args']), key=len):
if arg[:2] == '0x': continue
- if len(list(lookup_var_in_reg(register_ranges, locs['start'], arg))) == 0:
+ if line['type'] == 'uint64_t' and arg in var_to_line.keys() and var_to_line[arg]['type'] == 'uint128_t':
+ if line['op'] in ('&', '|', '^'): arg = arg + '_low'
+ found = list(lookup_var_in_reg(register_ranges, locs['start'], arg))
+ if len(found) == 0:
if len(registers) == 0:
reg, register_ranges = spill_register(locs['start'], register_ranges, out_args + list(line['args']))
else:
reg, registers = registers[0], registers[1:]
assert register_ranges[reg][locs['start']] is None
- for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
- register_ranges[reg][locs['start'] + latency] = arg
if arg in out_args:
for latency in range(core['latency']+1):
register_ranges[reg][locs['start'] + latency] = arg
+ stores_available.append(('MOVrm', reg, arg, locs['start'] + core['latency']))
+ else:
+ for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
+ register_ranges[reg][locs['start'] + latency] = arg
+ movs_needed.append(('MOVmr', arg, reg, locs['start']))
+ elif arg in out_args:
+ assert False
+ else:
+ reg = found[0]
+ if reg == 'RDX':
+ if locs['start'] == 0:
+ movs_needed.append(('MOVmr', arg, reg, locs['start']))
+ elif register_ranges['RDX'][locs['start'] - 1] != arg:
+ # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr
+ mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys()
+ if register_ranges['RDX'][locs['start'] - 1] == arg]
+ + [('MOVmr', arg)])[0]
+ movs_needed.append((mov_type, from_val, reg, locs['start']))
+
print(var_to_line[var]['source'] + ' // ' + update_source(var_to_line[var], locs['start'], register_ranges))
+ movs_needed = sorted(movs_needed, key=(lambda v: v[3]))
+ stores_needed = sorted(prune_stores(stores_available, movs_needed), key=(lambda v: v[3]))
+ print(len(movs_needed))
+ print(len(stores_needed))
+ for i in movs_needed: print(i)
+ for i in stores_needed: print(i)
# sys.exit(0)
# def insert_possible_registers(live_ranges):