From 3c10ad879925d3d6410e090c3b0606be8a9c4a2d Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 12 Sep 2017 15:28:59 -0400 Subject: Update reg alloc --- register-allocate.py | 102 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 28 deletions(-) (limited to 'register-allocate.py') diff --git a/register-allocate.py b/register-allocate.py index f91382613..69f7ef7e9 100755 --- a/register-allocate.py +++ b/register-allocate.py @@ -28,8 +28,22 @@ def parse_lines(lines): ret['return'] = lines[-1][:-1].replace('return ', '').replace('Return ', '') ret['lines'] = [] for line in lines[1:-1]: - datatype, varname, arg1, op, arg2 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$', line)[0] - ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2), 'source':line}) + match0 = re.findall('^(u?int[0-9]*_t) ([^ ]*), (u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line) + match1 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line) + match2 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$', line) + if len(match0) > 0: + datatype1, varname1, datatype2, varname2, op, arg1, arg2, arg3 = match0[0] + print('XXX FIXME %s' % line) + ret['lines'].append({'type':datatype1, 'out':varname1, 'op':op, 'args':(arg1, arg2, arg3), 'source':line, 'out2':varname2, 'type2':datatype2}) + elif len(match1) > 0: + datatype, varname, op, arg1, arg2, arg3 = match1[0] + ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2, arg3), 'source':line}) + elif len(match2) > 0: + datatype, varname, arg1, op, arg2 = match2[0] + ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2), 'source':line}) + else: + print(line) + assert(False) ret['lines'] = tuple(ret['lines']) return ret @@ -125,9 +139,12 @@ def split_graph(objs): obj['rev_deps'] = tuple() def collect_ac_buckets(graph): + seen = set() to_process = list(graph['out'].values()) while len(to_process) > 0: line, to_process = to_process[0], to_process[1:] + if line['out'] in seen: continue + seen.add(line['out']) if line['op'] == '+': args = list(line['deps']) new_args = [] @@ -151,14 +168,20 @@ def get_objects(start, ret=None): get_objects(node['deps'], ret=ret) return ret +def int_or_zero_key(v): + orig = v + v = v.strip('abcdefghijklmnopqrstuvwxyz') + if v.isdigit(): return (int(v), orig) + return (0, orig) + def prune(start): objs = get_objects(start) for var in objs.keys(): - objs[var]['rev_deps'] = tuple(objs[arg] for arg in sorted(objs.keys()) - if any(node['out'] == var for node in objs[arg]['deps'])) + objs[var]['rev_deps'] = tuple(obj for obj in objs[var]['rev_deps'] + if obj['out'] in objs.keys() and any(node['out'] == var for node in obj['deps'])) def to_graph(input_data): - objs = dict((var, {'out':var, 'style':''}) for var in list(get_input_var_names(input_data)) + list(get_var_names(input_data))) + objs = dict((var, {'out':var, 'style':'', 'rev_deps':[]}) for var in list(get_input_var_names(input_data)) + list(get_var_names(input_data))) for var in get_input_var_names(input_data): objs[var]['deps'] = tuple() objs[var]['op'] = 'INPUT' @@ -170,9 +193,10 @@ def to_graph(input_data): objs[var]['op'] = line['op'] objs[var]['type'] = line['type'] objs[var]['deps'] = tuple(objs[arg] for arg in line['args'] if arg in objs.keys()) + for node in objs[var]['deps']: + node['rev_deps'].append(objs[var]) for var in objs.keys(): - objs[var]['rev_deps'] = tuple(objs[arg] for arg in sorted(objs.keys()) - if any(node['out'] == var for node in objs[arg]['deps'])) + objs[var]['rev_deps'] = tuple(sorted(objs[var]['rev_deps'], key=(lambda n: int_or_zero_key(n['out'])))) graph = {'out':dict((var, objs[var]) for var in get_output_var_names(input_data)), 'in':dict((var, objs[var]) for var in get_input_var_names(input_data)) } collect_ac_buckets(graph) @@ -326,7 +350,7 @@ def annotate_with_alloc(objs, mapping): else: obj['reg'] = '' -def get_plus_deps(nodes, ops=('+',), types=('uint128_t',), seen=None): +def get_plus_deps(nodes, ops=('+',), types=('uint64_t',), seen=None): if seen is None: seen = set() for node in nodes: for dep in node['deps']: @@ -337,6 +361,29 @@ def get_plus_deps(nodes, ops=('+',), types=('uint128_t',), seen=None): for dep in get_plus_deps([dep], ops=ops, types=types, seen=seen): yield dep +deps_table_memo = {} +def all_deps_of(node): + if node['out'] in deps_table_memo.keys(): return deps_table_memo[node['out']] + ret = set() + for dep in node['deps']: + ret.add(dep['out']) + ret.update(all_deps_of(dep)) + deps_table_memo[node['out']] = tuple(sorted(ret, key=int_or_zero_key)) + return deps_table_memo[node['out']] + +def transitively_depends_on(node, maybe_dep): + return (node['out'] == maybe_dep['out']) or (maybe_dep['out'] in all_deps_of(node)) + +def cmp_node_by_dep(x, y): + default = cmp(x['out'], y['out']) + if x['out'] == y['out']: return default + if transitively_depends_on(x, y): ret = 1 + elif transitively_depends_on(y, x): ret = -1 + else: ret = default + return ret + + + def print_nodes(objs): for var in sorted(objs.keys(), key=(lambda s:(int(s.strip('cx_lowhightmp')), s))): yield ' %s [label="%s%s" %s];\n' % (objs[var]['out'], ' + '.join(sorted([objs[var]['out']] + list(objs[var]['extra_out']))), objs[var]['reg'], objs[var]['style']) @@ -767,11 +814,6 @@ def schedule(input_data, existing, emit_vars): def inline_schedule(sched, input_vars, output_vars): KNOWN_CONSTRAINTS = dict(('r%sx' % l, l) for l in 'abcd') - def int_or_zero_key(v): - orig = v - v = v.strip('abcdefghijklmnopqrstuvwxyz') - if v.isdigit(): return (int(v), orig) - return (0, orig) variables = list(sorted(set(list(re.findall('%\[([a-zA-Z0-9_]*)\]', sched)) + list(re.findall('%([a-zA-Z0-9_]+)', sched))), key=int_or_zero_key)) @@ -787,7 +829,7 @@ def inline_schedule(sched, input_vars, output_vars): sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg) transient_regs = [renaming[reg] for reg in transient_regs] ret = '' - ret += 'asm (\n' + ret += '__asm__ (\n' ret += sched ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n' ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n' @@ -804,16 +846,16 @@ if __name__ == '__main__': in_file, out_file = sys.argv[1], sys.argv[2] data = parse_lines(get_lines(in_file)) graph = to_graph(data) - possible_nodes = dict((n['out'], n) - for in_obj in graph['in'].values() - for n in in_obj['rev_deps'] - if n['op'] == '*') - for var, node in list(possible_nodes.items()): - possible_nodes.update(dict((n['out'], n) - for n in node['rev_deps'] - if n['op'] == '*')) - possible_nodes = list(sorted(possible_nodes.items())) - possible_nodes = [n for v, n in possible_nodes] + #possible_nodes = dict((n['out'], n) + # for in_obj in graph['in'].values() + # for n in in_obj['rev_deps'] + # if n['op'] == '*') + #for var, node in list(possible_nodes.items()): + # possible_nodes.update(dict((n['out'], n) + # for n in node['rev_deps'] + # if n['op'] == '*')) + #possible_nodes = list(sorted(possible_nodes.items())) + #possible_nodes = [n for v, n in possible_nodes] in_nodes = tuple(graph['in'].values()) existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = {}, {}, tuple(), tuple(REGISTERS), tuple(), tuple(), tuple(), tuple() objs = get_objects(graph['out'].values()) @@ -835,9 +877,13 @@ if __name__ == '__main__': assert('tmp' in dep['deps'][0]['out']) ret.append(dep['deps'][0]['out']) return tuple(ret) - for var in list(vars_for_bucket('x56')) + list(vars_for_bucket('x71')) + list(vars_for_bucket('x74')) + list(vars_for_bucket('x77')) + list(vars_for_bucket('x80')): - #print(var) - cur_possible_nodes = [n for n in possible_nodes if n['out'] == var] + plus_deps = tuple(n for n in get_plus_deps(objs.values()) + if len(n['extra_out']) > 0) + plus_deps = tuple(sorted(plus_deps, cmp=cmp_node_by_dep)) + for var in [v + for n in plus_deps + for v in vars_for_bucket(n['out'])]: + cur_possible_nodes = [objs[var]] # [n for n in possible_nodes if n['out'] == var] cur_possible_nodes, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars \ = allocate_one_subtree(in_nodes, cur_possible_nodes, existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars) existing.update(cur_map) @@ -847,4 +893,4 @@ if __name__ == '__main__': dict((existing[n['out']], n['out']) for n in graph['out'].values())) deps = adjust_bits(data, print_graph(graph, existing)) with codecs.open(out_file, 'w', encoding='utf8') as f: - f.write(data['header'] + '\n\n' + sched + '\n\n' + data['footer']) + f.write(data['header'] + '\n\n' + sched + '\n\n' + data['footer'] + '\n') -- cgit v1.2.3