author    Jason Gross <jgross@mit.edu>  2017-08-14 16:14:37 -0400
committer Jason Gross <jgross@mit.edu>  2017-08-14 16:14:37 -0400
commit    0ac04e0b1f28c5c9e2073335809adf4837c04cc5 (patch)
tree      342ce27d6896894eaedc311740f78dfbdf965853 /etc
parent    d9b3dc8017e837f9ce43736ee486db22b7e03f0f (diff)
Handle most of register allocation
Diffstat (limited to 'etc')
-rw-r--r--  etc/compile-by-zinc/femulScheduled.log    76
-rwxr-xr-x  etc/compile-by-zinc/heuristic-search.py  252
2 files changed, 271 insertions, 57 deletions
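This commit wires a register allocator onto the instruction scheduler: it derives a live range for every variable from the chosen schedule, pins RDX wherever a multiply needs its implicit operand, and greedily assigns the remaining registers. As orientation before the diff, a minimal self-contained sketch of the live-range idea (the toy schedule and field names here are hypothetical, not the script's API):

    # Each variable lives from the cycle that defines it to the last cycle
    # that reads it; the real code also tracks RDX-only reads separately.
    schedule = [('x1', 0, []), ('x2', 1, ['x1']), ('x3', 4, ['x1', 'x2'])]
    live = {}
    for var, start, args in schedule:
        live[var] = {'start': start, 'accessed': []}
        for arg in args:
            live[arg]['accessed'].append(start)
    for var in live:
        live[var]['end'] = max(live[var]['accessed'] or [live[var]['start']])
    # live['x1'] == {'start': 0, 'accessed': [1, 4], 'end': 4}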
diff --git a/etc/compile-by-zinc/femulScheduled.log b/etc/compile-by-zinc/femulScheduled.log
index 70e059c1d..0312dcea0 100644
--- a/etc/compile-by-zinc/femulScheduled.log
+++ b/etc/compile-by-zinc/femulScheduled.log
@@ -9,47 +9,47 @@ uint128_t x27 = (uint128_t) x7 * x15; // MULX r64,r64,r64, start: 6, end: 10
uint128_t x32 = (uint128_t) x7 * x17; // MULX r64,r64,r64, start: 7, end: 11
uint128_t x41 = (uint128_t) x7 * x19; // MULX r64,r64,r64, start: 8, end: 12
uint128_t x23 = x21 + x22; // ADD; ADC(X), start: 9, end: 11
-uint128_t x25 = (uint128_t) x9 * x13; // MULX r64,r64,r64, start: 9, end: 13
-uint128_t x34 = (uint128_t) x9 * x15; // MULX r64,r64,r64, start: 10, end: 14
-uint128_t x43 = (uint128_t) x9 * x17; // MULX r64,r64,r64, start: 11, end: 15
-uint128_t x30 = (uint128_t) x11 * x13; // MULX r64,r64,r64, start: 12, end: 16
-uint128_t x26 = x24 + x25; // ADD; ADC(X), start: 13, end: 15
-uint128_t x39 = (uint128_t) x11 * x15; // MULX r64,r64,r64, start: 13, end: 17
-uint128_t x37 = (uint128_t) x10 * x13; // MULX r64,r64,r64, start: 14, end: 18
-uint128_t x28 = x26 + x27; // ADD; ADC(X), start: 15, end: 17
-uint64_t x45 = x10 * 0x13; // IMUL r64,r64,i, start: 15, end: 18
-uint128_t x31 = x29 + x30; // ADD; ADC(X), start: 16, end: 18
+uint64_t x46 = x7 * 0x13; // IMUL r64,r64,i, start: 9, end: 12
+uint128_t x25 = (uint128_t) x9 * x13; // MULX r64,r64,r64, start: 10, end: 14
+uint128_t x34 = (uint128_t) x9 * x15; // MULX r64,r64,r64, start: 11, end: 15
+uint128_t x43 = (uint128_t) x9 * x17; // MULX r64,r64,r64, start: 12, end: 16
+uint64_t x47 = x9 * 0x13; // IMUL r64,r64,i, start: 13, end: 16
+uint128_t x26 = x24 + x25; // ADD; ADC(X), start: 14, end: 16
+uint128_t x30 = (uint128_t) x11 * x13; // MULX r64,r64,r64, start: 14, end: 18
+uint128_t x39 = (uint128_t) x11 * x15; // MULX r64,r64,r64, start: 15, end: 19
+uint128_t x28 = x26 + x27; // ADD; ADC(X), start: 16, end: 18
uint64_t x48 = x11 * 0x13; // IMUL r64,r64,i, start: 16, end: 19
-uint64_t x47 = x9 * 0x13; // IMUL r64,r64,i, start: 17, end: 20
-uint128_t x33 = x31 + x32; // ADD; ADC(X), start: 18, end: 20
-uint128_t x38 = x36 + x37; // ADD; ADC(X), start: 18, end: 20
-uint64_t x46 = x7 * 0x13; // IMUL r64,r64,i, start: 18, end: 21
-uint128_t x49 = (uint128_t) x45 * x15; // MULX r64,r64,r64, start: 19, end: 23
-uint128_t x35 = x33 + x34; // ADD; ADC(X), start: 20, end: 22
-uint128_t x40 = x38 + x39; // ADD; ADC(X), start: 20, end: 22
+uint128_t x37 = (uint128_t) x10 * x13; // MULX r64,r64,r64, start: 17, end: 21
+uint128_t x31 = x29 + x30; // ADD; ADC(X), start: 18, end: 20
+uint64_t x45 = x10 * 0x13; // IMUL r64,r64,i, start: 18, end: 21
+uint128_t x51 = (uint128_t) x46 * x18; // MULX r64,r64,r64, start: 19, end: 23
+uint128_t x33 = x31 + x32; // ADD; ADC(X), start: 20, end: 22
uint128_t x53 = (uint128_t) x47 * x19; // MULX r64,r64,r64, start: 20, end: 24
-uint128_t x61 = (uint128_t) x48 * x19; // MULX r64,r64,r64, start: 21, end: 25
-uint128_t x42 = x40 + x41; // ADD; ADC(X), start: 22, end: 24
-uint128_t x63 = (uint128_t) x45 * x19; // MULX r64,r64,r64, start: 22, end: 26
-uint128_t x50 = x20 + x49; // ADD; ADC(X), start: 23, end: 25
-uint128_t x51 = (uint128_t) x46 * x18; // MULX r64,r64,r64, start: 23, end: 27
-uint128_t x44 = x42 + x43; // ADD; ADC(X), start: 24, end: 26
-uint128_t x59 = (uint128_t) x47 * x18; // MULX r64,r64,r64, start: 24, end: 28
-uint128_t x65 = (uint128_t) x48 * x18; // MULX r64,r64,r64, start: 25, end: 29
+uint128_t x38 = x36 + x37; // ADD; ADC(X), start: 21, end: 23
+uint128_t x59 = (uint128_t) x47 * x18; // MULX r64,r64,r64, start: 21, end: 25
+uint128_t x35 = x33 + x34; // ADD; ADC(X), start: 22, end: 24
+uint128_t x49 = (uint128_t) x45 * x15; // MULX r64,r64,r64, start: 22, end: 26
+uint128_t x40 = x38 + x39; // ADD; ADC(X), start: 23, end: 25
+uint128_t x57 = (uint128_t) x45 * x17; // MULX r64,r64,r64, start: 23, end: 27
+uint128_t x63 = (uint128_t) x45 * x19; // MULX r64,r64,r64, start: 24, end: 28
+uint128_t x42 = x40 + x41; // ADD; ADC(X), start: 25, end: 27
+uint128_t x67 = (uint128_t) x45 * x18; // MULX r64,r64,r64, start: 25, end: 29
+uint128_t x50 = x20 + x49; // ADD; ADC(X), start: 26, end: 28
uint128_t x55 = (uint128_t) x48 * x17; // MULX r64,r64,r64, start: 26, end: 30
-uint128_t x64 = x28 + x63; // ADD; ADC(X), start: 26, end: 28
-uint128_t x52 = x50 + x51; // ADD; ADC(X), start: 27, end: 29
-uint128_t x57 = (uint128_t) x45 * x17; // MULX r64,r64,r64, start: 27, end: 31
-uint128_t x67 = (uint128_t) x45 * x18; // MULX r64,r64,r64, start: 28, end: 32
-uint128_t x54 = x52 + x53; // ADD; ADC(X), start: 29, end: 31
-uint128_t x66 = x64 + x65; // ADD; ADC(X), start: 29, end: 31
-uint128_t x56 = x54 + x55; // ADD; ADC(X), start: 31, end: 33
-uint128_t x58 = x23 + x57; // ADD; ADC(X), start: 31, end: 33
-uint128_t x60 = x58 + x59; // ADD; ADC(X), start: 33, end: 35
-uint128_t x68 = x35 + x67; // ADD; ADC(X), start: 33, end: 35
-uint64_t x69 = (uint64_t) (x56 >> 0x33); // SHRD r,r,i, start: 33, end: 36
-uint64_t x70 = (uint64_t) x56 & 0x7ffffffffffff; // AND, start: 33, end: 34
-uint128_t x62 = x60 + x61; // ADD; ADC(X), start: 35, end: 37
+uint128_t x44 = x42 + x43; // ADD; ADC(X), start: 27, end: 29
+uint128_t x61 = (uint128_t) x48 * x19; // MULX r64,r64,r64, start: 27, end: 31
+uint128_t x52 = x50 + x51; // ADD; ADC(X), start: 28, end: 30
+uint128_t x65 = (uint128_t) x48 * x18; // MULX r64,r64,r64, start: 28, end: 32
+uint128_t x58 = x23 + x57; // ADD; ADC(X), start: 29, end: 31
+uint128_t x54 = x52 + x53; // ADD; ADC(X), start: 30, end: 32
+uint128_t x60 = x58 + x59; // ADD; ADC(X), start: 31, end: 33
+uint128_t x56 = x54 + x55; // ADD; ADC(X), start: 32, end: 34
+uint128_t x62 = x60 + x61; // ADD; ADC(X), start: 33, end: 35
+uint128_t x64 = x28 + x63; // ADD; ADC(X), start: 34, end: 36
+uint64_t x69 = (uint64_t) (x56 >> 0x33); // SHRD r,r,i, start: 34, end: 37
+uint64_t x70 = (uint64_t) x56 & 0x7ffffffffffff; // AND, start: 34, end: 35
+uint128_t x68 = x35 + x67; // ADD; ADC(X), start: 35, end: 37
+uint128_t x66 = x64 + x65; // ADD; ADC(X), start: 36, end: 38
uint128_t x71 = x69 + x62; // ADD; ADC(X), start: 37, end: 39
uint64_t x72 = (uint64_t) (x71 >> 0x33); // SHRD r,r,i, start: 39, end: 42
uint64_t x73 = (uint64_t) x71 & 0x7ffffffffffff; // AND, start: 39, end: 40
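femulScheduled.log is the scheduled body of a curve25519 field multiplication in radix 2^51: the 0x13 immediates are the reduction constant 19 from p = 2^255 - 19, and the carry chain shifts by 0x33 = 51 and masks with 2^51 - 1. The rescheduling above hoists the IMUL-by-19 products (x45..x48) much earlier, apparently so each result is already sitting in RDX when the MULX instructions that consume it issue; that is the pattern the Python changes below reward. The constants, checked:

    assert 0x13 == 19                        # reduction constant from p = 2^255 - 19
    assert 0x33 == 51                        # limb width in bits
    assert 0x7ffffffffffff == (1 << 51) - 1  # 51-bit limb mask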
diff --git a/etc/compile-by-zinc/heuristic-search.py b/etc/compile-by-zinc/heuristic-search.py
index c09e93627..57784183e 100755
--- a/etc/compile-by-zinc/heuristic-search.py
+++ b/etc/compile-by-zinc/heuristic-search.py
@@ -12,8 +12,8 @@ MAX_INSTRUCTION_WINDOW = 1000
INSTRUCTIONS_PER_CYCLE = 4
-REGISTERS = tuple(['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP', 'RSP']
- + ['r%d' % i for i in range(8, 16)])
+REGISTERS = tuple(['r%d' % i for i in range(8, 16)]
+ + ['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP', 'RSP'])
CORE_DATA = tuple(('p%d' % i, 1) for i in range(8))
CORES = tuple(name for name, count in CORE_DATA)
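Reordering REGISTERS matters because free_registers (added below) builds its candidate list in REGISTERS order and Python's stable sort preserves that order on ties, so the allocator now hands out r8-r15 before touching the architecturally special registers, keeping RDX (the implicit MULX source) and RSP (the stack pointer) free for as long as possible.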
@@ -285,19 +285,33 @@ def schedule(data, basepoint, do_print):
def unfreeze(v):
return unfreeze_gen(v, unfreeze)
- def update_register_vals_with_core_args(core, args, register_vals):
- new_rdx = register_vals['RDX']
- if 'MULX' in core['instruction']:
- new_rdx = sorted(args, key=(lambda x: int(x.lstrip('0xrs'))))[0]
- changed = (register_vals['RDX'] != new_rdx)
- register_vals['RDX'] = new_rdx
- return changed, register_vals
+
+ def is_mul(core, use_imul=False):
+ return 'MULX' in core['instruction'] or (use_imul and 'IMUL' in core['instruction'])
+
+    def get_mulx_rdx_args(core, args, use_imul=False):
+        # The argument that must be placed in RDX for this instruction, if any.
+        if is_mul(core, use_imul=use_imul):
+            if args[0][:2] == '0x': return (args[1],)  # never pick an immediate
+            return tuple(args[:1])
+        return tuple()
+
+ def update_register_vals_with_core_args(var, core, args, register_vals):
+ changed = False
+ rdx_is_last_imul_output = False
+ if 'last_imul_output' not in register_vals: register_vals['last_imul_output'] = None
+ for new_rdx in get_mulx_rdx_args(core, args, use_imul=True):
+ rdx_is_last_imul_output = (register_vals['last_imul_output'] == new_rdx)
+ changed = (register_vals['RDX'] != new_rdx)
+ register_vals['RDX'] = new_rdx
+ if 'IMUL' in core['instruction']: register_vals['last_imul_output'] = var
+ return changed, rdx_is_last_imul_output, register_vals
@memoize
def update_core_state(var, core, args, core_state):
core = unfreeze(core)
(vars_remaining_cycles, core_remaining_cycle_count, cur_instructions_in_cycle, register_vals) = unfreeze(core_state)
- changed, register_vals = update_register_vals_with_core_args(core, args, register_vals)
+ changed, rdx_is_last_imul_output, register_vals = update_register_vals_with_core_args(var, core, args, register_vals)
cost = 0
if cur_instructions_in_cycle >= INSTRUCTIONS_PER_CYCLE:
cost += 1
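For reference, what the new RDX bookkeeping selects, shown on argument shapes matching the calls above (toy values; in the script these functions close over the schedule state):

    mulx = {'instruction': 'MULX r64,r64,r64'}
    imul = {'instruction': 'IMUL r64,r64,i'}
    add  = {'instruction': 'ADD; ADC(X)'}
    get_mulx_rdx_args(mulx, ('x46', 'x18'))                  # ('x46',) must be in RDX
    get_mulx_rdx_args(imul, ('x10', '0x13'), use_imul=True)  # ('x10',), immediates skipped
    get_mulx_rdx_args(add,  ('x21', 'x22'))                  # (): no implicit RDX operand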
@@ -340,16 +354,20 @@ def schedule(data, basepoint, do_print):
def get_wait_times(var_cores, core_state):
for var, core, args in var_cores:
(vars_remaining_cycles, core_remaining_cycle_count, cur_instructions_in_cycle, register_vals) = unfreeze(core_state)
- changed, register_vals = update_register_vals_with_core_args(core, args, register_vals)
+ changed, rdx_is_last_imul_output, register_vals = update_register_vals_with_core_args(var, core, args, register_vals)
cost, new_core_state = update_core_state(var, freeze(core), args, core_state)
- yield (cost, -len(reverse_dependencies.get(var, [])), changed, -core['latency'], var, core, args, new_core_state)
+ yield (cost, -len(reverse_dependencies.get(var, [])), changed, -core['latency'], not rdx_is_last_imul_output, var, get_mulx_rdx_args(core, args, use_imul=True), core, args, new_core_state)
def cmp_wait_time(v1, v2):
- (cost1, neg_len_deps1, changed1, neg_latency1, var1, core1, args1, new_core_state1) = v1
- (cost2, neg_len_deps2, changed2, neg_latency2, var2, core2, args2, new_core_state2) = v2
+        (cost1, neg_len_deps1, changed1, neg_latency1, not_rdx_is_last_imul_output1, var1, mulx_rdx_args1, core1, args1, new_core_state1) = v1
+        (cost2, neg_len_deps2, changed2, neg_latency2, not_rdx_is_last_imul_output2, var2, mulx_rdx_args2, core2, args2, new_core_state2) = v2
if cost1 != cost2: return cmp(cost1, cost2)
- if core1['instruction'] == core2['instruction']:
+ if core1['instruction'] == core2['instruction'] or (is_mul(core1, use_imul=True) and is_mul(core2, use_imul=True)):
if changed1 != changed2: return cmp(changed1, changed2)
+ if not_rdx_is_last_imul_output1 != not_rdx_is_last_imul_output2: return cmp(not_rdx_is_last_imul_output1, not_rdx_is_last_imul_output2)
+ if core1['instruction'] != core2['instruction'] and \
+ (True or get_mulx_rdx_args(core1, args1, use_imul=True) == get_mulx_rdx_args(core2, args2, use_imul=True)):
+ return cmp(var1, var2)
if neg_len_deps1 != neg_len_deps2: return cmp(neg_len_deps1, neg_len_deps2)
if neg_latency1 != neg_latency2: return cmp(neg_latency1, neg_latency2)
if var1 != var2: return cmp(var1, var2)
@@ -367,7 +385,8 @@ def schedule(data, basepoint, do_print):
cur_cycle += cost
schedule_with_cycle_info.append((var,
{'start':cur_cycle, 'finish':cur_cycle + core['latency']},
- core))
+ core,
+ args))
return schedule_with_cycle_info
def evaluate_cost(schedule_with_cycle_info):
@@ -393,12 +412,25 @@ def schedule(data, basepoint, do_print):
if len(sorted_next_statements) > 0:
pre_min_cost = sorted_next_statements[0][0]
# print((pre_min_cost, tuple(var for cost2, var, core, new_core_state in sorted_next_statements if pre_min_cost == cost2)))
+ rdx_with_imul = set(arg
+ for cost, reverse_dep_count, changed, neg_latency, not_rdx_is_last_imul_output, var, mulx_rdx_args, core, args, new_core_state in sorted_next_statements
+ for arg in mulx_rdx_args
+ if pre_min_cost == cost and 'IMUL' in core['instruction'])
sorted_subset_next_statements \
- = tuple((cost, var, core, args, new_core_state) for cost, reverse_dep_count, changed, neg_latency, var, core, args, new_core_state in sorted_next_statements
+ = tuple((cost, var, core, args, new_core_state, mulx_rdx_args) for cost, reverse_dep_count, changed, neg_latency, not_rdx_is_last_imul_output, var, mulx_rdx_args, core, args, new_core_state in sorted_next_statements
if pre_min_cost == cost)
+ if False and any(all(arg not in rdx_with_imul for arg in mulx_rdx_args)
+ for cost, var, core, args, new_core_state, mulx_rdx_args in sorted_subset_next_statements
+ if len(mulx_rdx_args) > 0):
+ sorted_subset_next_statements = tuple([(cost, var, core, args, new_core_state, mulx_rdx_args)
+ for cost, var, core, args, new_core_state, mulx_rdx_args in sorted_subset_next_statements
+ if all(arg not in rdx_with_imul for arg in mulx_rdx_args)]
+ + [(cost, var, core, args, new_core_state, mulx_rdx_args)
+ for cost, var, core, args, new_core_state, mulx_rdx_args in sorted_subset_next_statements
+ if not all(arg not in rdx_with_imul for arg in mulx_rdx_args)])
sorted_subset_next_statements = sorted_subset_next_statements[:1]
if pre_min_cost == 0: sorted_subset_next_statements = sorted_subset_next_statements[:1]
- for cost, var, core, args, new_core_state in sorted_subset_next_statements:
+ for cost, var, core, args, new_core_state, mulx_rdx_args in sorted_subset_next_statements:
cost, schedule = make_schedule(var, core, args)
if min_cost is None or cost < min_cost:
min_cost, min_schedule = cost, schedule
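The rdx_with_imul machinery computes, among the tied cheapest candidates, which ones would put a pending IMUL's source in RDX, but the reordering it feeds is switched off by the `if False and` guard, so the search still takes only the single cheapest statement (note the duplicated [:1] truncation just above the loop).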
@@ -407,10 +439,192 @@ def schedule(data, basepoint, do_print):
min_cost, min_schedule = evaluate_cost_memo(freeze_gen((freeze([]), core_state))), []
return min_cost, min_schedule
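+    # Live range per variable: the cycle it is produced ('start'), every cycle it is read
+    # ('accessed'), and the reads that must come from RDX for a multiply ('mul_accessed').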
+ def get_live_ranges(var_to_line, schedule_with_cycles, RET_loc):
+ var_names = get_var_names(data)
+ input_var_names = get_input_var_names(data)
+ output_var_names = get_output_var_names(data)
+ ret = dict((var, {'start':0, 'accessed':[], 'mul_accessed':[]}) for var in input_var_names)
+ for (var, locs, core, args) in schedule_with_cycles:
+ assert var not in ret.keys()
+ line = var_to_line[var]
+ ret[var] = {'start':locs['start'], 'accessed':[], 'mul_accessed':[]}
+ for arg in line['args']:
+ if arg in var_names + input_var_names:
+ for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
+ ret[arg]['accessed'].append(locs['start'] + latency)
+ elif arg[:2] == '0x':
+ pass
+ else:
+ print(arg)
+ for arg in get_mulx_rdx_args(core, args, use_imul=True):
+ ret[arg]['mul_accessed'].append(locs['start'])
+ for var in output_var_names:
+ ret[var]['accessed'].append(RET_loc)
+ for var in ret.keys():
+ ret[var]['end'] = max(ret[var]['accessed'])
+ for var in ret.keys():
+ if not (len(ret[var]['mul_accessed']) == 0 or tuple(ret[var]['mul_accessed']) == tuple(ret[var]['accessed'])):
+ print((var, ret[var]['accessed'], ret[var]['mul_accessed']))
+ assert False
+ return ret
+
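+    # Recompute each range's 'end' and attach the set of other ranges it overlaps in time.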
+ def remake_overlaps(live_ranges):
+ live_ranges = dict(live_ranges)
+ for var in live_ranges.keys():
+ live_ranges[var] = dict(live_ranges[var])
+ live_ranges[var]['end'] = max(live_ranges[var]['accessed'])
+ live_ranges[var]['overlaps'] = tuple(sorted(
+ other_var
+ for other_var in live_ranges.keys()
+ if other_var != var and
+ (live_ranges[other_var]['start'] <= live_ranges[var]['start'] <= live_ranges[other_var]['end']
+ or live_ranges[var]['start'] <= live_ranges[other_var]['start'] <= live_ranges[var]['end'])
+ ))
+ return live_ranges
+
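+    # One register-file snapshot per cycle; two extra slots because the _high halves of
+    # 128-bit values are looked up to two cycles past their definition.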
+ def make_initial_register_ranges(RET_loc):
+ return dict((reg, [None] * (RET_loc + 2)) for reg in REGISTERS)
+
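+    # MULX reads one source implicitly from RDX, so pin RDX to that variable wherever the
+    # schedule performs such a read.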
+ def force_mulx_args(live_ranges, register_ranges):
+ for var in live_ranges.keys():
+ for loc in live_ranges[var]['mul_accessed']:
+ assert register_ranges['RDX'][loc] is None
+ register_ranges['RDX'][loc] = var
+ return register_ranges
+
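+    # Find which register holds var at loc; 128-bit values are kept as a _low/_high pair
+    # whose halves may land in consecutive cycles, reported as 'high_reg:low_reg'.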
+ def lookup_var_in_reg(register_ranges, loc, var):
+ var_dict = dict((register_ranges[reg][loc], reg) for reg in register_ranges.keys())
+ next_var_dict = dict((register_ranges[reg][loc+1], reg) for reg in register_ranges.keys())
+ next_next_var_dict = dict((register_ranges[reg][loc+2], reg) for reg in register_ranges.keys())
+ if var in var_dict.keys():
+ return [var_dict[var]]
+ elif var + '_low' in var_dict.keys() and var + '_high' in next_var_dict.keys():
+ return ['%s:%s' % (next_var_dict[var + '_high'], var_dict[var + '_low'])]
+ elif var + '_low' in next_var_dict.keys() and var + '_high' in next_next_var_dict.keys():
+ return ['%s:%s' % (next_next_var_dict[var + '_high'], next_var_dict[var + '_low'])]
+ else:
+ return []
+
+ def update_source(line, loc, register_ranges):
+ source = line['source']
+ for var in sorted([line['out']] + list(line['args']), key=len):
+ for reg in lookup_var_in_reg(register_ranges, loc, var):
+ source = source.replace(var, reg)
+ return source
+
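+    # For each register, the first cycle after loc at which its current value is read
+    # again (None if it never is).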
+ def get_next_registers_use(loc, register_ranges, live_ranges):
+ def gen():
+ for reg in REGISTERS:
+ last_vals = [val for val in register_ranges[reg][:loc] if val is not None]
+ if len(last_vals) == 0:
+ yield (reg, None)
+ else:
+ val = last_vals[-1].replace('_low', '').replace('_high', '')
+ accessed = [aloc for aloc in live_ranges[val]['accessed'] if aloc > loc]
+ if len(accessed) == 0:
+ yield (reg, None)
+ else:
+ yield (reg, accessed[0])
+ return dict(gen())
+
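+    # Registers free at loc, longest-free first; ties are broken by when the register's
+    # previous occupant is next needed.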
+ def free_registers(loc, register_ranges, live_ranges):
+ all_free_registers = [reg for reg in REGISTERS if register_ranges[reg][loc] is None]
+ if loc == 0: return tuple(all_free_registers)
+ if loc == 1:
+ return tuple([reg for reg in all_free_registers if register_ranges[reg][loc-1] is None]
+ + [reg for reg in all_free_registers if register_ranges[reg][loc-1] is not None])
+ next_uses = get_next_registers_use(loc, register_ranges, live_ranges)
+ def get_key(reg):
+ if register_ranges[reg][loc-1] is None and register_ranges[reg][loc-2] is None:
+ return (0, next_uses[reg])
+ if register_ranges[reg][loc-1] is None and register_ranges[reg][loc-2] is not None:
+ return (1, next_uses[reg])
+ return (2, next_uses[reg])
+ return tuple(sorted(all_free_registers, key=get_key))
+
+ def spill_register(loc, register_ranges, exclude_vals):
+ # kick out the value that will disappear soonest
+ empties = dict((reg, [new_loc for new_loc, val in enumerate(register_ranges[reg][loc:]) if val is None])
+ for reg in register_ranges.keys()
+                       if reg != 'RDX' and register_ranges[reg][loc] is not None
+                       and register_ranges[reg][loc].replace('_low', '').replace('_high', '') not in exclude_vals)
+ sorted_empties = sorted((new_locs[0], reg) for reg, new_locs in empties.items() if len(new_locs) > 0)
+ cycles, reg = sorted_empties[0]
+ print('Spilling %s from %s!' % (register_ranges[reg][loc], reg))
+ for new_loc in range(loc, loc + cycles):
+ register_ranges[reg][new_loc] = None
+ return reg, register_ranges
+
+
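+    # Greedy linear scan over the schedule: give every operand and output a register,
+    # splitting 128-bit adds into _low/_high halves and spilling when nothing is free.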
+ def linear_allocate(var_to_line, schedule_with_cycles, live_ranges, register_ranges):
+ for (var, locs, core, args) in schedule_with_cycles:
+ registers = free_registers(locs['start'], register_ranges, live_ranges)
+ line = var_to_line[var]
+# print((len(registers), line['source'], tuple((reg, register_ranges[reg][locs['start']]) for reg in REGISTERS)))
+ if line['type'] == 'uint128_t' and line['op'] == '+':
+ for latency, bits in ((0, '_low'), (1, '_high')):
+ for argi, arg in enumerate(line['args']):
+ found = list(lookup_var_in_reg(register_ranges, locs['start'] + latency, arg + bits))
+ if len(found) == 0:
+ if len(registers) == 0:
+ reg, register_ranges = spill_register(locs['start'] + latency, register_ranges, line['args'])
+ else:
+ reg, registers = registers[0], registers[1:]
+ assert register_ranges[reg][locs['start'] + latency] is None
+ register_ranges[reg][locs['start'] + latency] = arg + bits
+ else:
+ reg = found[0]
+ if argi == 0:
+ register_ranges[reg][locs['start'] + latency + 1] = line['out'] + bits
+ else:
+ if line['type'] == 'uint128_t':
+ out_args = [line['out'] + '_high', line['out'] + '_low']
+ else:
+ out_args = [line['out']]
+ for arg in sorted(out_args + list(line['args']), key=len):
+ if arg[:2] == '0x': continue
+ if len(list(lookup_var_in_reg(register_ranges, locs['start'], arg))) == 0:
+ if len(registers) == 0:
+ reg, register_ranges = spill_register(locs['start'], register_ranges, out_args + list(line['args']))
+ else:
+ reg, registers = registers[0], registers[1:]
+ assert register_ranges[reg][locs['start']] is None
+ for latency in range(max(c['latency'] for c in core['core'])): # handle instructions that need data for multiple cycles, like add;adcx
+ register_ranges[reg][locs['start'] + latency] = arg
+ if arg in out_args:
+ for latency in range(core['latency']+1):
+ register_ranges[reg][locs['start'] + latency] = arg
+ print(var_to_line[var]['source'] + ' // ' + update_source(var_to_line[var], locs['start'], register_ranges))
+# sys.exit(0)
+
+# def insert_possible_registers(live_ranges):
+# live_ranges = dict(live_ranges)
+# for var in live_ranges.keys():
+# live_ranges[var] = dict(live_ranges[var])
+# live_ranges[var]['possible registers'] = tuple(sorted(
+# other_var
+# for other_var in live_ranges.keys()
+# if other_var != var and
+# (live_ranges[other_var]['start'] <= live_ranges[var]['start'] <= live_ranges[other_var]['end']
+## or live_ranges[var]['start'] <= live_ranges[other_var]['start'] <= live_ranges[var]['end'])
+## ))
+## return live_ranges
+
+# def register_allocate(live_ranges):
+# allocated = {}
+# remaining_registers = list(REGISTERS)
+
+
core_state = freeze(make_initial_core_state())
cost, schedule = schedule_remaining(get_initial_indices(data), core_state) #, freeze_core_state(make_initial_core_state()))
schedule_with_cycle_info = add_cycle_info(schedule)
- for var, cycles, core in schedule_with_cycle_info:
+ live_ranges = remake_overlaps(get_live_ranges(lines, schedule_with_cycle_info, cost))
+ register_ranges = make_initial_register_ranges(cost)
+ register_ranges = force_mulx_args(live_ranges, register_ranges)
+ print(linear_allocate(lines, schedule_with_cycle_info, live_ranges, register_ranges))
+# print(live_ranges)
+# print(register_ranges)
+# sys.exit(0)
+ for var, cycles, core, args in schedule_with_cycle_info:
if var in lines.keys():
do_print(lines[var]['source'], ' // %s, start: %s, end: %s' % (core['instruction'], basepoint + cycles['start'], basepoint + cycles['finish']))
else: