aboutsummaryrefslogtreecommitdiff
path: root/etc
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-08-17 00:24:05 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-08-17 00:24:05 -0400
commit6de8d499c797c0f107d1bdc13ad785f7c5dbdf69 (patch)
tree3d6348f133169f89f73eac8ab6d6d5e7f213c4d8 /etc
parent0ac04e0b1f28c5c9e2073335809adf4837c04cc5 (diff)
More WIP on register allocation
The current allocation is terrible, probably because we are currently requiring that all instructions output to registers. My current guess at a decent thing to do is to make a pass, after register allocation, and eliminate all registers that simply get stored to memory, replacing the relevant instructions with the memory-using versions. Then we can re-register allocate, ignoring values that go straight to memory.
Diffstat (limited to 'etc')
-rwxr-xr-xetc/compile-by-zinc/heuristic-search.py84
1 files changed, 78 insertions, 6 deletions
diff --git a/etc/compile-by-zinc/heuristic-search.py b/etc/compile-by-zinc/heuristic-search.py
index 57784183e..e2d0b43fd 100755
--- a/etc/compile-by-zinc/heuristic-search.py
+++ b/etc/compile-by-zinc/heuristic-search.py
@@ -498,6 +498,8 @@ def schedule(data, basepoint, do_print):
next_next_var_dict = dict((register_ranges[reg][loc+2], reg) for reg in register_ranges.keys())
if var in var_dict.keys():
return [var_dict[var]]
+ elif var + '_low' in var_dict.keys() and var + '_high' in var_dict.keys():
+ return ['%s:%s' % (var_dict[var + '_high'], var_dict[var + '_low'])]
elif var + '_low' in var_dict.keys() and var + '_high' in next_var_dict.keys():
return ['%s:%s' % (next_var_dict[var + '_high'], var_dict[var + '_low'])]
elif var + '_low' in next_var_dict.keys() and var + '_high' in next_next_var_dict.keys():
@@ -507,9 +509,17 @@ def schedule(data, basepoint, do_print):
def update_source(line, loc, register_ranges):
source = line['source']
- for var in sorted([line['out']] + list(line['args']), key=len):
- for reg in lookup_var_in_reg(register_ranges, loc, var):
- source = source.replace(var, reg)
+ if line['op'] in ('&', '|', '^') and line['type'] == 'uint64_t' and line['args'][1][:2] == '0x':
+ for reg in lookup_var_in_reg(register_ranges, loc, line['out']):
+ source = source.replace(line['out'], reg)
+ for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0] + '_low'):
+ source = source.replace(line['args'][0], reg)
+ for reg in lookup_var_in_reg(register_ranges, loc, line['args'][0]):
+ source = source.replace(line['args'][0], reg)
+ else:
+ for var in sorted([line['out']] + list(line['args']), key=len):
+ for reg in lookup_var_in_reg(register_ranges, loc, var):
+ source = source.replace(var, reg)
return source
def get_next_registers_use(loc, register_ranges, live_ranges):
@@ -554,8 +564,18 @@ def schedule(data, basepoint, do_print):
register_ranges[reg][new_loc] = None
return reg, register_ranges
+ def prune_stores(stores_available, movs_needed):
+ movs_used = set(arg for mov_type, reg, arg, loc in movs_needed)
+ for mov_type, arg, reg, loc in stores_available:
+ if arg in movs_used:
+ yield (mov_type, arg, reg, loc)
+
+ def insert_movs(schedule_with_cycles, register_ranges, live_registers):
+ pass
def linear_allocate(var_to_line, schedule_with_cycles, live_ranges, register_ranges):
+ movs_needed = []
+ stores_available = []
for (var, locs, core, args) in schedule_with_cycles:
registers = free_registers(locs['start'], register_ranges, live_ranges)
line = var_to_line[var]
@@ -571,10 +591,36 @@ def schedule(data, basepoint, do_print):
reg, registers = registers[0], registers[1:]
assert register_ranges[reg][locs['start'] + latency] is None
register_ranges[reg][locs['start'] + latency] = arg + bits
+ movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency))
else:
reg = found[0]
+ if reg == 'RDX':
+ if locs['start'] + latency == 0:
+ movs_needed.append(('MOVmr', arg + bits, reg, locs['start'] + latency))
+ elif register_ranges['RDX'][locs['start'] + latency - 1] != arg + bits:
+ # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr
+ mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys()
+ if register_ranges['RDX'][locs['start'] + latency - 1] == arg + bits]
+ + [('MOVmr', arg + bits)])[0]
+ movs_needed.append((mov_type, from_val, reg, locs['start'] + latency))
if argi == 0:
register_ranges[reg][locs['start'] + latency + 1] = line['out'] + bits
+ stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency + 1))
+ elif line['type'] == 'uint64_t' and line['op'] == '>>' and line['args'][0] in var_to_line.keys() and var_to_line[line['args'][0]]['type'] == 'uint128_t' and line['args'][1][:2] == '0x':
+ for arg, latency in ((line['out'], core['latency']), (line['args'][0] + '_low', 0), (line['args'][0] + '_high', 0)):
+ found = list(lookup_var_in_reg(register_ranges, locs['start'], arg))
+ if len(found) == 0:
+ if len(registers) == 0:
+ reg, register_ranges = spill_register(locs['start'] + latency, register_ranges, line['args'])
+ else:
+ reg, registers = registers[0], registers[1:]
+ assert register_ranges[reg][locs['start'] + latency] is None
+ register_ranges[reg][locs['start'] + latency] = arg
+ if arg != line['out']:
+ for c in range(latency):
+ movs_needed.append(('MOVmr', arg, reg, locs['start'] + c))
+ else:
+ stores_available.append(('MOVrm', reg, line['out'] + bits, locs['start'] + latency))
else:
if line['type'] == 'uint128_t':
out_args = [line['out'] + '_high', line['out'] + '_low']
@@ -582,18 +628,44 @@ def schedule(data, basepoint, do_print):
out_args = [line['out']]
for arg in sorted(out_args + list(line['args']), key=len):
if arg[:2] == '0x': continue
- if len(list(lookup_var_in_reg(register_ranges, locs['start'], arg))) == 0:
+ if line['type'] == 'uint64_t' and arg in var_to_line.keys() and var_to_line[arg]['type'] == 'uint128_t':
+ if line['op'] in ('&', '|', '^'): arg = arg + '_low'
+ found = list(lookup_var_in_reg(register_ranges, locs['start'], arg))
+ if len(found) == 0:
if len(registers) == 0:
reg, register_ranges = spill_register(locs['start'], register_ranges, out_args + list(line['args']))
else:
reg, registers = registers[0], registers[1:]
assert register_ranges[reg][locs['start']] is None
- for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
- register_ranges[reg][locs['start'] + latency] = arg
if arg in out_args:
for latency in range(core['latency']+1):
register_ranges[reg][locs['start'] + latency] = arg
+ stores_available.append(('MOVrm', reg, arg, locs['start'] + core['latency']))
+ else:
+ for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
+ register_ranges[reg][locs['start'] + latency] = arg
+ movs_needed.append(('MOVmr', arg, reg, locs['start']))
+ elif arg in out_args:
+ assert False
+ else:
+ reg = found[0]
+ if reg == 'RDX':
+ if locs['start'] == 0:
+ movs_needed.append(('MOVmr', arg, reg, locs['start']))
+ elif register_ranges['RDX'][locs['start'] - 1] != arg:
+ # MOVrr if the value is in some register, otherwise (if the filtered list is empty), MOVmr
+ mov_type, from_val = ([('MOVrr', reg_from) for reg_from in register_ranges.keys()
+ if register_ranges['RDX'][locs['start'] - 1] == arg]
+ + [('MOVmr', arg)])[0]
+ movs_needed.append((mov_type, from_val, reg, locs['start']))
+
print(var_to_line[var]['source'] + ' // ' + update_source(var_to_line[var], locs['start'], register_ranges))
+ movs_needed = sorted(movs_needed, key=(lambda v: v[3]))
+ stores_needed = sorted(prune_stores(stores_available, movs_needed), key=(lambda v: v[3]))
+ print(len(movs_needed))
+ print(len(stores_needed))
+ for i in movs_needed: print(i)
+ for i in stores_needed: print(i)
# sys.exit(0)
# def insert_possible_registers(live_ranges):