author    Jason Gross <jgross@mit.edu>  2017-09-12 15:28:59 -0400
committer Jason Gross <jgross@mit.edu>  2017-09-12 15:28:59 -0400
commit    3c10ad879925d3d6410e090c3b0606be8a9c4a2d (patch)
tree      76302e0474fc0eac46f992631ba9974c871aa992 /register-allocate.py
parent    284a3d5aa23aa17f9fcdcccc378c350ea3a88a24 (diff)
Update reg alloc
Diffstat (limited to 'register-allocate.py')
-rwxr-xr-x  register-allocate.py  102
1 file changed, 74 insertions(+), 28 deletions(-)
diff --git a/register-allocate.py b/register-allocate.py
index f91382613..69f7ef7e9 100755
--- a/register-allocate.py
+++ b/register-allocate.py
@@ -28,8 +28,22 @@ def parse_lines(lines):
ret['return'] = lines[-1][:-1].replace('return ', '').replace('Return ', '')
ret['lines'] = []
for line in lines[1:-1]:
- datatype, varname, arg1, op, arg2 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$', line)[0]
- ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2), 'source':line})
+ match0 = re.findall('^(u?int[0-9]*_t) ([^ ]*), (u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line)
+ match1 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line)
+ match2 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$', line)
+ if len(match0) > 0:
+ datatype1, varname1, datatype2, varname2, op, arg1, arg2, arg3 = match0[0]
+ print('XXX FIXME %s' % line)
+ ret['lines'].append({'type':datatype1, 'out':varname1, 'op':op, 'args':(arg1, arg2, arg3), 'source':line, 'out2':varname2, 'type2':datatype2})
+ elif len(match1) > 0:
+ datatype, varname, op, arg1, arg2, arg3 = match1[0]
+ ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2, arg3), 'source':line})
+ elif len(match2) > 0:
+ datatype, varname, arg1, op, arg2 = match2[0]
+ ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2), 'source':line})
+ else:
+ print(line)
+ assert(False)
ret['lines'] = tuple(ret['lines'])
return ret
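
Aside: parse_lines now tries three patterns in order -- a two-output three-argument call, a one-output three-argument call, and the original binary-operator form. Below is a minimal standalone sketch of that classification (not part of this commit); the sample lines are hypothetical stand-ins for the generated C the script actually reads.

    import re

    PATTERNS = [
        # two outputs assigned from a three-argument call
        r'^(u?int[0-9]*_t) ([^ ]*), (u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$',
        # one output assigned from a three-argument call
        r'^(u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$',
        # one output assigned from a binary operator
        r'^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$',
    ]

    samples = ['uint64_t x5, uint8_t c2 = addcarryx_u64(c1, x3, x4);',
               'uint64_t x7 = cmovznz(c2, x5, x6);',
               'uint64_t x9 = x7 * x8;']

    for line in samples:
        for i, pattern in enumerate(PATTERNS):
            groups = re.findall(pattern, line)
            if groups:
                print('pattern %d matched: %r' % (i, groups[0]))
                break
        else:
            print('no pattern matched: %r' % line)
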
@@ -125,9 +139,12 @@ def split_graph(objs):
obj['rev_deps'] = tuple()
def collect_ac_buckets(graph):
+ seen = set()
to_process = list(graph['out'].values())
while len(to_process) > 0:
line, to_process = to_process[0], to_process[1:]
+ if line['out'] in seen: continue
+ seen.add(line['out'])
if line['op'] == '+':
args = list(line['deps'])
new_args = []
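
Aside: the new `seen` set makes the worklist idempotent per node, so a node reachable from several outputs (or pushed more than once) has its '+' bucket collected only once. A minimal sketch of that worklist pattern on a toy graph (node names made up, per-node work reduced to a print):

    def walk_once(roots):
        seen = set()
        to_process = list(roots)
        while len(to_process) > 0:
            node, to_process = to_process[0], to_process[1:]
            if node['out'] in seen: continue
            seen.add(node['out'])
            print('processing %s' % node['out'])   # per-node work happens at most once
            to_process = to_process + list(node['deps'])

    shared = {'out': 'shared', 'deps': ()}
    walk_once([{'out': 'a', 'deps': (shared,)},
               {'out': 'b', 'deps': (shared,)}])
    # prints a, b, shared -- 'shared' is visited once despite having two parents
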
@@ -151,14 +168,20 @@ def get_objects(start, ret=None):
get_objects(node['deps'], ret=ret)
return ret
+def int_or_zero_key(v):
+ orig = v
+ v = v.strip('abcdefghijklmnopqrstuvwxyz')
+ if v.isdigit(): return (int(v), orig)
+ return (0, orig)
+
def prune(start):
objs = get_objects(start)
for var in objs.keys():
- objs[var]['rev_deps'] = tuple(objs[arg] for arg in sorted(objs.keys())
- if any(node['out'] == var for node in objs[arg]['deps']))
+ objs[var]['rev_deps'] = tuple(obj for obj in objs[var]['rev_deps']
+ if obj['out'] in objs.keys() and any(node['out'] == var for node in obj['deps']))
def to_graph(input_data):
- objs = dict((var, {'out':var, 'style':''}) for var in list(get_input_var_names(input_data)) + list(get_var_names(input_data)))
+ objs = dict((var, {'out':var, 'style':'', 'rev_deps':[]}) for var in list(get_input_var_names(input_data)) + list(get_var_names(input_data)))
for var in get_input_var_names(input_data):
objs[var]['deps'] = tuple()
objs[var]['op'] = 'INPUT'
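
Aside: int_or_zero_key, previously a local helper inside inline_schedule and now hoisted to module scope (see the hunk above), sorts names by their numeric core rather than lexically; names with no digits fall into the zero bucket. A small demonstration with hypothetical names:

    def int_or_zero_key(v):
        orig = v
        v = v.strip('abcdefghijklmnopqrstuvwxyz')
        if v.isdigit(): return (int(v), orig)
        return (0, orig)

    print(sorted(['x10', 'x2', 'x1', 'tmp'], key=int_or_zero_key))
    # -> ['tmp', 'x1', 'x2', 'x10']  (a plain lexical sort would give ['tmp', 'x1', 'x10', 'x2'])
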
@@ -170,9 +193,10 @@ def to_graph(input_data):
objs[var]['op'] = line['op']
objs[var]['type'] = line['type']
objs[var]['deps'] = tuple(objs[arg] for arg in line['args'] if arg in objs.keys())
+ for node in objs[var]['deps']:
+ node['rev_deps'].append(objs[var])
for var in objs.keys():
- objs[var]['rev_deps'] = tuple(objs[arg] for arg in sorted(objs.keys())
- if any(node['out'] == var for node in objs[arg]['deps']))
+ objs[var]['rev_deps'] = tuple(sorted(objs[var]['rev_deps'], key=(lambda n: int_or_zero_key(n['out']))))
graph = {'out':dict((var, objs[var]) for var in get_output_var_names(input_data)),
'in':dict((var, objs[var]) for var in get_input_var_names(input_data)) }
collect_ac_buckets(graph)
@@ -326,7 +350,7 @@ def annotate_with_alloc(objs, mapping):
else:
obj['reg'] = ''
-def get_plus_deps(nodes, ops=('+',), types=('uint128_t',), seen=None):
+def get_plus_deps(nodes, ops=('+',), types=('uint64_t',), seen=None):
if seen is None: seen = set()
for node in nodes:
for dep in node['deps']:
@@ -337,6 +361,29 @@ def get_plus_deps(nodes, ops=('+',), types=('uint128_t',), seen=None):
for dep in get_plus_deps([dep], ops=ops, types=types, seen=seen):
yield dep
+deps_table_memo = {}
+def all_deps_of(node):
+ if node['out'] in deps_table_memo.keys(): return deps_table_memo[node['out']]
+ ret = set()
+ for dep in node['deps']:
+ ret.add(dep['out'])
+ ret.update(all_deps_of(dep))
+ deps_table_memo[node['out']] = tuple(sorted(ret, key=int_or_zero_key))
+ return deps_table_memo[node['out']]
+
+def transitively_depends_on(node, maybe_dep):
+ return (node['out'] == maybe_dep['out']) or (maybe_dep['out'] in all_deps_of(node))
+
+def cmp_node_by_dep(x, y):
+ default = cmp(x['out'], y['out'])
+ if x['out'] == y['out']: return default
+ if transitively_depends_on(x, y): ret = 1
+ elif transitively_depends_on(y, x): ret = -1
+ else: ret = default
+ return ret
+
+
+
def print_nodes(objs):
for var in sorted(objs.keys(), key=(lambda s:(int(s.strip('cx_lowhightmp')), s))):
yield ' %s [label="%s%s" %s];\n' % (objs[var]['out'], ' + '.join(sorted([objs[var]['out']] + list(objs[var]['extra_out']))), objs[var]['reg'], objs[var]['style'])
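
Aside: all_deps_of memoizes each node's transitive dependency set by output name, and cmp_node_by_dep uses it to order nodes so that dependencies sort before their dependents, falling back to name comparison for unrelated nodes. The comparator relies on Python 2's cmp and is passed to sorted(..., cmp=...) further down; the following is a rough sketch of the same idea that also works under Python 3 via functools.cmp_to_key, exercised on a toy chain a <- b <- c:

    import functools

    def all_deps_of(node, memo={}):
        # transitive closure of 'deps', memoised by output name
        if node['out'] in memo: return memo[node['out']]
        ret = set()
        for dep in node['deps']:
            ret.add(dep['out'])
            ret.update(all_deps_of(dep, memo))
        memo[node['out']] = frozenset(ret)
        return memo[node['out']]

    def cmp_node_by_dep(x, y):
        if x['out'] == y['out']: return 0
        if x['out'] in all_deps_of(y): return -1   # y depends on x, so x sorts first
        if y['out'] in all_deps_of(x): return 1    # x depends on y, so y sorts first
        return (x['out'] > y['out']) - (x['out'] < y['out'])

    a = {'out': 'a', 'deps': ()}
    b = {'out': 'b', 'deps': (a,)}
    c = {'out': 'c', 'deps': (b,)}
    print([n['out'] for n in sorted([c, a, b], key=functools.cmp_to_key(cmp_node_by_dep))])
    # -> ['a', 'b', 'c']
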
@@ -767,11 +814,6 @@ def schedule(input_data, existing, emit_vars):
def inline_schedule(sched, input_vars, output_vars):
KNOWN_CONSTRAINTS = dict(('r%sx' % l, l) for l in 'abcd')
- def int_or_zero_key(v):
- orig = v
- v = v.strip('abcdefghijklmnopqrstuvwxyz')
- if v.isdigit(): return (int(v), orig)
- return (0, orig)
variables = list(sorted(set(list(re.findall('%\[([a-zA-Z0-9_]*)\]', sched)) +
list(re.findall('%([a-zA-Z0-9_]+)', sched))),
key=int_or_zero_key))
@@ -787,7 +829,7 @@ def inline_schedule(sched, input_vars, output_vars):
sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg)
transient_regs = [renaming[reg] for reg in transient_regs]
ret = ''
- ret += 'asm (\n'
+ ret += '__asm__ (\n'
ret += sched
ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n'
ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n'
@@ -804,16 +846,16 @@ if __name__ == '__main__':
in_file, out_file = sys.argv[1], sys.argv[2]
data = parse_lines(get_lines(in_file))
graph = to_graph(data)
- possible_nodes = dict((n['out'], n)
- for in_obj in graph['in'].values()
- for n in in_obj['rev_deps']
- if n['op'] == '*')
- for var, node in list(possible_nodes.items()):
- possible_nodes.update(dict((n['out'], n)
- for n in node['rev_deps']
- if n['op'] == '*'))
- possible_nodes = list(sorted(possible_nodes.items()))
- possible_nodes = [n for v, n in possible_nodes]
+ #possible_nodes = dict((n['out'], n)
+ # for in_obj in graph['in'].values()
+ # for n in in_obj['rev_deps']
+ # if n['op'] == '*')
+ #for var, node in list(possible_nodes.items()):
+ # possible_nodes.update(dict((n['out'], n)
+ # for n in node['rev_deps']
+ # if n['op'] == '*'))
+ #possible_nodes = list(sorted(possible_nodes.items()))
+ #possible_nodes = [n for v, n in possible_nodes]
in_nodes = tuple(graph['in'].values())
existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = {}, {}, tuple(), tuple(REGISTERS), tuple(), tuple(), tuple(), tuple()
objs = get_objects(graph['out'].values())
@@ -835,9 +877,13 @@ if __name__ == '__main__':
assert('tmp' in dep['deps'][0]['out'])
ret.append(dep['deps'][0]['out'])
return tuple(ret)
- for var in list(vars_for_bucket('x56')) + list(vars_for_bucket('x71')) + list(vars_for_bucket('x74')) + list(vars_for_bucket('x77')) + list(vars_for_bucket('x80')):
- #print(var)
- cur_possible_nodes = [n for n in possible_nodes if n['out'] == var]
+ plus_deps = tuple(n for n in get_plus_deps(objs.values())
+ if len(n['extra_out']) > 0)
+ plus_deps = tuple(sorted(plus_deps, cmp=cmp_node_by_dep))
+ for var in [v
+ for n in plus_deps
+ for v in vars_for_bucket(n['out'])]:
+ cur_possible_nodes = [objs[var]] # [n for n in possible_nodes if n['out'] == var]
cur_possible_nodes, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars \
= allocate_one_subtree(in_nodes, cur_possible_nodes, existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars)
existing.update(cur_map)
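
Aside: the hard-coded bucket list (x56, x71, ...) is replaced by discovering the bucket roots from the graph itself: every node reached by get_plus_deps whose extra_out is non-empty becomes a bucket, and the buckets are handled in dependency order so each bucket sees the register mapping accumulated by the buckets it depends on. Schematically (allocate_bucket below is a hypothetical stand-in for allocate_one_subtree and its threaded state):

    def allocate_in_dependency_order(buckets, order_key, allocate_bucket):
        # buckets whose inputs are already allocated come first; the mapping
        # accumulated so far is visible to every later allocation
        existing = {}
        for bucket in sorted(buckets, key=order_key):
            cur_map = allocate_bucket(bucket, existing)
            existing.update(cur_map)
        return existing

    # toy demo with hypothetical bucket names and a trivial allocator
    demo = allocate_in_dependency_order(
        ['x2', 'x10'],
        order_key=lambda name: int(name.strip('x')),
        allocate_bucket=lambda name, existing: {name: 'r%d' % len(existing)})
    print(demo)   # {'x2': 'r0', 'x10': 'r1'}
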
@@ -847,4 +893,4 @@ if __name__ == '__main__':
dict((existing[n['out']], n['out']) for n in graph['out'].values()))
deps = adjust_bits(data, print_graph(graph, existing))
with codecs.open(out_file, 'w', encoding='utf8') as f:
- f.write(data['header'] + '\n\n' + sched + '\n\n' + data['footer'])
+ f.write(data['header'] + '\n\n' + sched + '\n\n' + data['footer'] + '\n')