aboutsummaryrefslogtreecommitdiff
path: root/register-allocate.py
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@mit.edu>2019-01-08 04:21:38 -0500
committerGravatar Andres Erbsen <andreser@mit.edu>2019-01-09 22:49:02 -0500
commit3ca227f1137e6a3b65bc33f5689e1c230d591595 (patch)
treee1e5a2dd2a2f34f239d3276227ddbdc69eeeb667 /register-allocate.py
parent3ec21c64b3682465ca8e159a187689b207c71de4 (diff)
remove old pipeline
Diffstat (limited to 'register-allocate.py')
-rwxr-xr-xregister-allocate.py970
1 files changed, 0 insertions, 970 deletions
diff --git a/register-allocate.py b/register-allocate.py
deleted file mode 100755
index 276deb46c..000000000
--- a/register-allocate.py
+++ /dev/null
@@ -1,970 +0,0 @@
-#!/usr/bin/env python
-from __future__ import with_statement
-import codecs, re, sys, os
-
-LAMBDA = u'\u03bb'
-
-NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RBP', 'RSI', 'RDI')
-NUMBERED_REGISTERS = tuple('r%d' % i for i in range(16))
-RESERVED_REGISTERS = ('RBP', )
-TO_BE_RESTORED_REGISTERS = ('RSP', )
-NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS))
-REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + list(NUMBERED_REGISTERS))
-REGISTERS = ['reg%d' % i for i in range(13)]
-DEFAULT_DIALECT = 'att'
-
-def get_lines(filename):
- with codecs.open(filename, 'r', encoding='utf8') as f:
- lines = f.read().replace('\r\n', '\n')
- return [line.strip() for line in re.findall("%s '.*?[Rr]eturn [^\r\n]*" % LAMBDA, lines, flags=re.DOTALL)[0].split('\n')]
-
-def strip_casts(text):
- return re.sub(r'\(u?int[0-9]*_t\)\s*\(?([^\)]*)\)?', r'\1', text)
-
-def parse_lines(lines):
- lines = list(map(strip_casts, lines))
- assert lines[0][:len(LAMBDA + ' ')] == LAMBDA + ' '
- assert lines[0][-1] == ','
- ret = {}
- ret['header'] = lines[0]
- ret['footer'] = lines[-1]
- ret['vars'] = lines[0][len(LAMBDA + ' '):-1]
- assert lines[-1][-1] == ')'
- ret['return'] = lines[-1][:-1].replace('return ', '').replace('Return ', '')
- ret['lines'] = []
- for line in lines[1:-1]:
- match0 = re.findall('^(u?int[0-9]*_t) ([^ ]*), (u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line)
- match1 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^\(]*)\(([^ ]*), ([^ ]*), ([^ ]*)\);$', line)
- match2 = re.findall('^(u?int[0-9]*_t) ([^ ]*) = ([^ ]*) ([^ ]*) ([^ ]*);$', line)
- if len(match0) > 0:
- datatype1, varname1, datatype2, varname2, op, arg1, arg2, arg3 = match0[0]
- print('XXX FIXME %s' % line)
- ret['lines'].append({'type':datatype1, 'out':varname1, 'op':op, 'args':(arg1, arg2, arg3), 'source':line, 'out2':varname2, 'type2':datatype2})
- elif len(match1) > 0:
- datatype, varname, op, arg1, arg2, arg3 = match1[0]
- ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2, arg3), 'source':line})
- elif len(match2) > 0:
- datatype, varname, arg1, op, arg2 = match2[0]
- ret['lines'].append({'type':datatype, 'out':varname, 'op':op, 'args':(arg1, arg2), 'source':line})
- else:
- print(line)
- assert(False)
- ret['lines'] = tuple(ret['lines'])
- return ret
-
-def get_var_names(input_data):
- return tuple(line['out'] for line in input_data['lines'])
-
-def get_input_var_names(input_data):
- return tuple(i for i in data['vars'].replace('%core', '').replace(',', ' ').replace('(', ' ').replace(')', ' ').replace("'", ' ').split(' ')
- if i != '')
-
-def get_output_var_names(input_data):
- return tuple(i for i in data['return'].replace(',', ' ').replace('(', ' ').replace(')', ' ').split(' ')
- if i != '')
-
-def line_of_var(input_data, var):
- retv = [line for line in input_data['lines'] if line['out'] == var]
- if len(retv) > 0: return retv[0]
- return {'out': var, 'args':tuple(), 'op': 'INPUT', 'type':'uint64_t'}
-
-def make_data_dependencies(input_data):
- input_var_names = get_input_var_names(input_data)
- dependencies = dict((var, tuple()) for var in input_var_names)
- for line in input_data['lines']:
- dependencies[line['out']] = tuple(arg for arg in line['args']
- if arg[0] not in '0123456789')
- return dependencies
-def make_reverse_data_dependencies(dependencies):
- reverse_dependencies = dict((k, []) for k in dependencies.keys())
- for k, v in dependencies.items():
- for arg in v:
- reverse_dependencies[arg].append(k)
- return reverse_dependencies
-
-def get_low_or_high(obj, low_or_high):
- assert(low_or_high in ('low', 'high'))
- if obj['op'] == 'COMBINE':
- if low_or_high == 'low': return obj['deps'][0]
- if low_or_high == 'high': return obj['deps'][1]
- else:
- return {'out':obj['out'] + '_' + low_or_high, 'style':'', 'deps':(obj,), 'op':'GET_' + low_or_high.upper(), 'type':'uint64_t', 'extra_out':tuple(o + '_' + low_or_high for o in obj['extra_out']), 'rev_deps':tuple()}
-
-def add_combine_low_high(objs):
- for obj in objs:
- if obj['type'] == 'uint128_t':
- obj_low = get_low_or_high(obj, 'low')
- obj_high = get_low_or_high(obj, 'high')
- obj_new = {'out':obj['out'], 'style':'', 'deps':(obj_low, obj_high), 'op':'COMBINE', 'type':'uint128_t', 'extra_out':obj['extra_out'], 'rev_deps':obj['rev_deps']}
- obj['out'] += '_tmp'
- obj['rev_deps'] = (obj_low, obj_high)
- obj_high['rev_deps'] = (obj_new,)
- obj_low['rev_deps'] = (obj_new,)
- for rdep in obj_new['rev_deps']:
- rdep['deps'] = tuple(d if d is not obj else obj_new
- for d in rdep['deps'])
-
-
-def split_graph(objs):
- for obj in objs:
- if obj['op'] == '&' and obj['type'] == 'uint64_t' and len(obj['deps']) == 1 and obj['deps'][0]['type'] == 'uint128_t' and obj['deps'][0]['op'] == 'COMBINE':
- combine_node = obj['deps'][0]
- low = combine_node['deps'][0]
- obj['deps'] = (low,)
- low['rev_deps'] = tuple(list(low['rev_deps']) + [obj])
- if obj['op'] == '+' and obj['type'] == 'uint128_t' and len(obj['rev_deps']) == 2 and obj['rev_deps'][0]['op'] == 'GET_LOW' and obj['rev_deps'][1]['op'] == 'GET_HIGH':
- for tmp in ('_tmp', '_temp'):
- if obj['out'][-len(tmp):] == tmp:
- obj['out'] = obj['out'][:-len(tmp)]
- obj_low, obj_high = obj['rev_deps']
- obj_carry = {'out':'c' + obj['out'], 'style':'', 'deps':(obj_low,), 'op':'GET_CARRY', 'type':'CARRY', 'extra_out':tuple(), 'rev_deps':(obj_high,)}
- assert(len(obj_low['deps']) == 1)
- assert(len(obj_high['deps']) == 1)
- assert(obj_low['type'] == 'uint64_t')
- assert(obj_high['type'] == 'uint64_t')
- obj_low['deps'], obj_high['deps'] = [], [obj_carry]
- obj_low['op'] = '+'
- obj_high['op'] = '+'
- for dep in obj['deps']:
- if dep['type'] == 'uint64_t':
- obj_low['deps'].append(dep)
- dep['rev_deps'] = tuple(d if d is not obj else obj_low
- for d in dep['rev_deps'])
- elif dep['type'] == 'uint128_t':
- dep_low, dep_high = get_low_or_high(dep, 'low'), get_low_or_high(dep, 'high')
- obj_low['deps'].append(dep_low)
- obj_high['deps'].append(dep_high)
- dep_low['rev_deps'] = tuple(list(dep_low['rev_deps']) + [obj_low])
- dep_high['rev_deps'] = tuple(list(dep_high['rev_deps']) + [obj_high])
- else:
- assert(False)
- obj_low['deps'], obj_high['deps'] = tuple(obj_low['deps']), tuple(obj_high['deps'])
- obj_low['rev_deps'] = list(obj_low['rev_deps']) + [obj_carry]
- obj['deps'] = tuple()
- obj['rev_deps'] = tuple()
-
-def collect_ac_buckets(graph):
- seen = set()
- to_process = list(graph['out'].values())
- while len(to_process) > 0:
- line, to_process = to_process[0], to_process[1:]
- if line['out'] in seen: continue
- seen.add(line['out'])
- if line['op'] == '+':
- args = list(line['deps'])
- new_args = []
- while len(args) > 0:
- arg, args = args[0], args[1:]
- if arg['op'] == '+' and len(arg['rev_deps']) == 1 and line['type'] == 'uint128_t':
- line['extra_out'] = tuple(sorted(list(line['extra_out']) + [arg['out']] + list(arg['extra_out'])))
- for arg_arg in arg['deps']:
- arg_arg['rev_deps'] = (line,)
- args.append(arg_arg)
- else:
- new_args.append(arg)
- line['deps'] = tuple(new_args)
- to_process += list(line['deps'])
-
-def get_objects(start, ret=None):
- if ret is None: ret = {}
- for node in start:
- if node['out'] in ret.keys(): continue
- ret[node['out']] = node
- get_objects(node['deps'], ret=ret)
- return ret
-
-def int_or_zero_key(v):
- orig = v
- v = v.strip('abcdefghijklmnopqrstuvwxyz')
- if v.isdigit(): return (int(v), orig)
- return (0, orig)
-
-def prune(start):
- objs = get_objects(start)
- for var in objs.keys():
- objs[var]['rev_deps'] = tuple(obj for obj in objs[var]['rev_deps']
- if obj['out'] in objs.keys() and any(node['out'] == var for node in obj['deps']))
-
-def to_graph(input_data):
- objs = dict((var, {'out':var, 'style':'', 'rev_deps':[]}) for var in list(get_input_var_names(input_data)) + list(get_var_names(input_data)))
- for var in get_input_var_names(input_data):
- objs[var]['deps'] = tuple()
- objs[var]['op'] = 'INPUT'
- objs[var]['type'] = 'uint64_t'
- objs[var]['extra_out'] = tuple()
- for line in input_data['lines']:
- var = line['out']
- objs[var]['extra_out'] = tuple()
- objs[var]['op'] = line['op']
- objs[var]['type'] = line['type']
- objs[var]['deps'] = tuple(objs[arg] for arg in line['args'] if arg in objs.keys())
- for node in objs[var]['deps']:
- node['rev_deps'].append(objs[var])
- for var in objs.keys():
- objs[var]['rev_deps'] = tuple(sorted(objs[var]['rev_deps'], key=(lambda n: int_or_zero_key(n['out']))))
- graph = {'out':dict((var, objs[var]) for var in get_output_var_names(input_data)),
- 'in':dict((var, objs[var]) for var in get_input_var_names(input_data)) }
- collect_ac_buckets(graph)
- add_combine_low_high(objs.values())
- split_graph(objs.values())
- prune(tuple(graph['out'].values()))
- #split_graph(objs)
- return graph
-
-
-def adjust_bits(input_data, graph):
- for line in input_data['lines']:
- if line['type'] == 'uint128_t':
- graph = graph.replace(line['out'], line['out'] + '_128')
- return graph
-
-def fill_node(node, color='red'):
- node['style'] = ', style="filled", fillcolor="%s"' % color
-
-def fill_deps(node, color='red'):
- fill_node(node)
- for dep in node['deps']:
- fill_deps(dep, color=color)
-
-def fill_subgraph(in_node, color='red'):
- #print((in_node['out'], in_node['op'], [d['out'] for d in in_node['rev_deps']]))
- fill_node(in_node, color=color)
- if in_node['op'] != '+':
- fill_deps(in_node, color=color)
- for rdep in in_node['rev_deps']:
- fill_subgraph(rdep, color=color)
-
-def is_temp(node):
- for tmp in ('_tmp', '_temp'):
- if node['out'][-len(tmp):] == tmp:
- return True
- return False
-
-def is_allocated_to_reg(full_map, node):
- return node['out'] in full_map.keys() and all(reg in REGISTERS for reg in full_map[node['out']].split(':'))
-
-def deps_allocated(full_map, node):
- if node['op'] == 'INPUT': return True
- if node['out'] not in full_map.keys(): return False
- return all(deps_allocated(full_map, dep) for dep in node['deps'])
-
-# returns {cur_map with new_name->reg}, still_free_temps, still_free_list, all_temps, freed, new_buckets, emit_vars
-def allocate_node(existing, node, *args):
- cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = args
- free_temps = list(free_temps)
- free_list = list(free_list)
- all_temps = list(all_temps)
- full_map = dict(existing)
- cur_map = dict(cur_map)
- freed = list(freed)
- new_buckets = list(new_buckets)
- emit_vars = list(emit_vars)
- full_map.update(cur_map)
- def do_ret():
- return cur_map, tuple(free_temps), tuple(free_list), tuple(all_temps), tuple(freed), tuple(new_buckets), tuple(emit_vars)
- def do_free(var):
- for reg in full_map[var].split(':'):
- if reg in all_temps:
- if reg not in free_temps:
- free_temps.append(reg)
- elif reg in REGISTERS:
- if reg not in free_list:
-# print('freeing %s from %s' % (reg, var))
- free_list.append(reg)
- def do_free_deps(node):
- full_map.update(cur_map)
- if node['out'] in full_map.keys():
- for dep in node['deps']:
- if dep['out'] in freed or dep['out'] not in full_map.keys(): continue
- if not is_allocated_to_reg(full_map, dep): continue
- if (all(deps_allocated(full_map, rdep) for rdep in dep['rev_deps']) or
- all(reg in all_temps for reg in full_map[dep['out']].split(':'))):
- do_free(dep['out'])
- freed.append(dep['out'])
- if node['out'] in full_map.keys():
- do_free_deps(node)
- return do_ret()
- if node['op'] in ('GET_HIGH', 'GET_LOW') and len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) <= 2 and all(n['op'] in ('GET_HIGH', 'GET_LOW') for n in node['deps'][0]['rev_deps']) and is_allocated_to_reg(full_map, node['deps'][0]):
- reg_idx = {'GET_LOW':0, 'GET_HIGH':1}[node['op']]
- cur_map[node['out']] = full_map[node['deps'][0]['out']].split(':')[reg_idx]
- emit_vars.append(node)
- return do_ret()
- if len(node['deps']) == 1 and len(node['deps'][0]['rev_deps']) == 1 and is_allocated_to_reg(full_map, node['deps'][0]) and node['type'] == node['deps'][0]['type']:
- cur_map[node['out']] = full_map[node['deps'][0]['out']]
- emit_vars.append(node)
- return do_ret()
- if len(node['deps']) == 0 and node['op'] == 'INPUT':
- assert(node['type'] == 'uint64_t')
- cur_map[node['out']] = 'm' + node['out'] # free_list.pop()
- emit_vars.append(node)
- return do_ret()
- if is_temp(node):
- num_reg = {'uint64_t':1, 'uint128_t':2}[node['type']]
- # TODO: make this more efficient by allowing re-use of
- # dependnecies which are no longer needed (which are currently
- # only reaped after this node is assigned)
- while len(free_temps) < num_reg:
- reg = free_list.pop()
- free_temps.append(reg)
- all_temps.append(reg)
- cur_map[node['out']] = ':'.join(free_temps[:num_reg])
- free_temps = free_temps[num_reg:]
- emit_vars.append(node)
- do_free_deps(node)
- return do_ret()
- if node['op'] == '+' and node['type'] == 'uint64_t' and len(node['extra_out']) > 0:
- cur_map[node['out']] = free_list.pop()
- emit_vars.append(node)
- new_buckets.append(node)
- do_free_deps(node)
- return do_ret()
- if node['op'] == '*' and node['type'] == 'uint64_t' and len(node['deps']) == 1:
- dep = node['deps'][0]
- assert(dep['out'] in full_map.keys())
- if is_allocated_to_reg(full_map, dep) and \
- all(rdep is node or (is_allocated_to_reg(full_map, rdep) and full_map[rdep['out']] != full_map[dep['out']])
- for rdep in dep['rev_deps']):
- cur_map[node['out']] = full_map[dep['out']]
- freed += [dep['out']]
- else:
- cur_map[node['out']] = free_list.pop()
- emit_vars.append(node)
- return do_ret()
- raw_input([node['out'], node['op'], node['type'], [(dep['out'], full_map.get(dep['out'])) for dep in node['deps']]])
- return do_ret()
-
-def allocate_deps(existing, node, *args):
- for dep in node['deps']:
- args = allocate_deps(existing, dep, *args)
- return allocate_node(existing, node, *args)
-
-def allocate_subgraph(existing, node, *args):
- if node['op'] != '+':
- args = allocate_deps(existing, node, *args)
- else:
- args = allocate_node(existing, node, *args)
- if node['op'] != '+':
- for rdep in node['rev_deps']:
- args = allocate_subgraph(existing, rdep, *args)
- return args
-
-def annotate_with_alloc(objs, mapping):
- for obj in objs:
- if obj['out'] in mapping.keys():
- obj['reg'] = ' (' + mapping[obj['out']] + ')'
- else:
- obj['reg'] = ''
-
-def get_plus_deps(nodes, ops=('+',), types=('uint64_t',), seen=None):
- if seen is None: seen = set()
- for node in nodes:
- for dep in node['deps']:
- if dep['out'] in seen: continue
- seen.add(dep['out'])
- if dep['op'] in ops and dep['type'] in types:
- yield dep
- for dep in get_plus_deps([dep], ops=ops, types=types, seen=seen):
- yield dep
-
-deps_table_memo = {}
-def all_deps_of(node):
- if node['out'] in deps_table_memo.keys(): return deps_table_memo[node['out']]
- ret = set()
- for dep in node['deps']:
- ret.add(dep['out'])
- ret.update(all_deps_of(dep))
- deps_table_memo[node['out']] = tuple(sorted(ret, key=int_or_zero_key))
- return deps_table_memo[node['out']]
-
-def transitively_depends_on(node, maybe_dep):
- return (node['out'] == maybe_dep['out']) or (maybe_dep['out'] in all_deps_of(node))
-
-def cmp_node_by_dep(x, y):
- default = cmp(x['out'], y['out'])
- if x['out'] == y['out']: return default
- if transitively_depends_on(x, y): ret = 1
- elif transitively_depends_on(y, x): ret = -1
- else: ret = default
- return ret
-
-
-
-def print_nodes(objs):
- for var in sorted(objs.keys(), key=(lambda s:(int(s.strip('cx_lowhightmp')), s))):
- yield ' %s [label="%s%s" %s];\n' % (objs[var]['out'], ' + '.join(sorted([objs[var]['out']] + list(objs[var]['extra_out']))), objs[var]['reg'], objs[var]['style'])
-def print_deps(objs):
- for var in sorted(objs.keys()):
- for dep in objs[var]['deps']:
- yield ' %s -> %s [ label="%s" ] ;\n' % (dep['out'], objs[var]['out'], objs[var]['op'])
-
-def push_allocate(existing, nodes, *args, **kwargs):
- if 'seen' not in kwargs.keys(): kwargs['seen'] = set()
- full_map = dict(existing)
- for node in nodes:
- if node['out'] in kwargs['seen']: continue
- cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = args
- free_temps = list(free_temps)
- free_list = list(free_list)
- all_temps = list(all_temps)
- cur_map = dict(cur_map)
- freed = list(freed)
- new_buckets = list(new_buckets)
- emit_vars = list(emit_vars)
- full_map.update(cur_map)
- if node['out'] in full_map.keys() and node['op'] == '+' and all(d['out'] not in full_map.keys() for d in node['rev_deps']) and set(d['op'] for d in node['rev_deps']) == set(('&', 'COMBINE', 'GET_CARRY')):
- and_node = [d for d in node['rev_deps'] if d['op'] == '&'][0]
- carry_node = [d for d in node['rev_deps'] if d['op'] == 'GET_CARRY'][0]
- combine_node = [d for d in node['rev_deps'] if d['op'] == 'COMBINE'][0]
- high_node = [d for d in combine_node['deps'] if d is not node][0]
- assert(len(combine_node['rev_deps']) == 1)
- shr_node = combine_node['rev_deps'][0]
- assert(shr_node['op'] == '>>')
- assert(shr_node['out'] not in full_map.keys())
- assert(len(combine_node['deps']) == 2)
- assert(all(d['out'] in full_map.keys() for d in combine_node['deps']))
- cur_map[carry_node['out']] = 'c0'
- emit_vars.append(carry_node)
- cur_map[combine_node['out']] = ':'.join(full_map[d['out']] for d in combine_node['deps'])
- emit_vars.append(combine_node)
- assert(high_node['out'] in full_map.keys())
- cur_map[shr_node['out']] = full_map[high_node['out']]
- emit_vars.append(shr_node)
- cur_map[and_node['out']] = full_map[node['out']]
- emit_vars.append(and_node)
- fill_node(combine_node)
- fill_node(carry_node)
- fill_node(shr_node)
- fill_node(and_node)
- freed += [node['out'], carry_node['out'], high_node['out'], combine_node['out']]
- elif node['out'] in full_map.keys() and len(node['rev_deps']) == 1 and all(d['out'] not in full_map.keys() for d in node['rev_deps']) and len(node['rev_deps'][0]['deps']) == 1 and node['type'] == node['rev_deps'][0]['type']:
- next_node = node['rev_deps'][0]
- cur_map[next_node['out']] = full_map[node['out']]
- emit_vars.append(next_node)
- fill_node(next_node)
- full_map.update(cur_map)
- freed += [node['out']]
- elif node['out'] not in full_map.keys() and len(node['rev_deps']) == 2 and len(node['deps']) == 2 and all(d['out'] not in full_map.keys() for d in node['rev_deps']) and all(d['out'] in full_map.keys() for d in node['deps']) and node['type'] == 'uint64_t' and all(d['type'] == 'uint64_t' for d in node['rev_deps']) and all(d['type'] == 'uint64_t' for d in node['deps']):
- from1, from2 = node['deps']
- to1, to2 = node['rev_deps']
- assert(full_map[from1['out']] != full_map[from2['out']])
- cur_map[node['out']] = full_map[from1['out']]
- emit_vars.append(node)
- cur_map[to1['out']] = full_map[from1['out']]
- emit_vars.append(to1)
- cur_map[to2['out']] = full_map[from2['out']]
- emit_vars.append(to2)
- fill_node(node)
- fill_node(to1)
- fill_node(to2)
- full_map.update(cur_map)
- freed += [node['out'], from1['out'], from2['out']]
- elif node['out'] not in full_map.keys() and len(node['rev_deps']) == 0 and len(node['deps']) == 2 and all(d['out'] not in full_map.keys() for d in node['rev_deps']) and all(d['out'] in full_map.keys() for d in node['deps']) and node['type'] == 'uint64_t' and all(d['type'] == 'uint64_t' for d in node['rev_deps']) and all(d['type'] == 'uint64_t' for d in node['deps']):
- from1, from2 = node['deps']
- assert(full_map[from1['out']] != full_map[from2['out']])
- cur_map[node['out']] = full_map[from1['out']]
- emit_vars.append(node)
- fill_node(node)
- full_map.update(cur_map)
- freed += [from1['out'], from2['out']]
- full_map.update(cur_map)
- args = (cur_map, tuple(free_temps), tuple(free_list), tuple(all_temps), tuple(freed), tuple(new_buckets), tuple(emit_vars))
- kwargs['seen'].add(node['out'])
- args = push_allocate(existing, node['rev_deps'], *args, **kwargs)
- return args
-
-def allocate_one_subtree(in_nodes, possible_nodes, existing, *args):
- cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = args
- existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars \
- = dict(existing), dict(cur_map), list(free_temps), list(free_list), list(all_temps), tuple(freed), tuple(new_buckets), tuple(emit_vars)
- args = (cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars)
- sorted_nodes = []
- for node in possible_nodes:
- try:
- lens = [len([rd for rd in d['rev_deps'] if rd['out'] not in existing.keys()]) for d in node['deps']]
- temp_cur_map, temp_free_temps, temp_free_list, temp_all_temps, temp_freed, temp_new_buckets, temp_emit_vars = allocate_subgraph(existing, node, *args)
- if set(temp_free_temps) != set(temp_all_temps):
- print(('BAD', node['out'], temp_cur_map, temp_free_temps, temp_free_list, temp_all_temps, temp_freed))
- sorted_nodes.append(((len(temp_free_list),
- -min(lens),
- -max(lens),
- -len(temp_new_buckets),
- len(temp_free_temps),
- -int(node['out'].strip('x_lowhightemp'))),
- node))
- except IndexError:
- print('Too many reg: %s' % node['out'])
- sorted_nodes = tuple(reversed(sorted(sorted_nodes, key=(lambda v: v[0]))))
-# print([(n[0], n[1]['out']) for n in sorted_nodes])
- node = sorted_nodes[0][1]
- possible_nodes = [n for n in possible_nodes if n is not node]
-# print('Allocating for %s' % node['out'])
- args = allocate_subgraph(existing, node, *args)
- fill_subgraph(node)
- args = push_allocate(existing, in_nodes, *args)
- cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = args
- return possible_nodes, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars
-
-
-def print_graph(graph, allocs):
- objs = get_objects(graph['out'].values())
- annotate_with_alloc(objs.values(), allocs)
- body = ''.join(print_nodes(objs))
- body += ''.join(print_deps(objs))
- body += ''.join(' in -> %s ;\n' % node['out'] for node in graph['in'].values())
- body += ''.join(' %s -> out ;\n' % node['out'] for node in graph['out'].values())
- return ('digraph G {\n' + body + '}\n')
-
-def fix_emit_vars(emit_vars):
- ret = []
- waiting = []
- seen = set()
- get_high_waiting = None
- for node in emit_vars:
- waiting.append(node)
- early_new_waiting = []
- new_waiting = []
- for wnode in waiting:
- if wnode['out'] in seen:
- continue
- elif wnode['op'] == 'GET_HIGH' and wnode['deps'][0]['out'] == get_high_waiting:
- ret.append(wnode)
- seen.add(wnode['out'])
- get_high_waiting = None
- elif wnode['op'] == 'GET_HIGH' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+':
- new_waiting.append(wnode)
- elif get_high_waiting is None and wnode['op'] == 'GET_LOW' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+':
- ret.append(wnode)
- seen.add(wnode['out'])
- assert(len(wnode['deps']) == 1)
- get_high_waiting = wnode['deps'][0]['out']
- elif get_high_waiting is not None:
- new_waiting.append(wnode)
- elif all(dep['out'] in seen for dep in wnode['deps']):
- ret.append(wnode)
- seen.add(wnode['out'])
- else:
- new_waiting.append(wnode)
- waiting = early_new_waiting + new_waiting
- while len(waiting) > 0:
-# print('Waiting on...')
-# print(list(sorted(node['out'] for node in waiting)))
- new_waiting = []
- for wnode in waiting:
- if wnode['out'] in seen:
- continue
- elif all(dep['out'] in seen for dep in wnode['deps']):
- ret.append(wnode)
- seen.add(wnode['out'])
- else:
- new_waiting.append(wnode)
- waiting = new_waiting
- return tuple(ret)
-
-def print_input(reg_out, mem_in):
- #return '%s <- LOAD %s;\n' % (reg_out, mem_in)
- #return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out)
- return ""
-
-def print_val(reg, dialect=DEFAULT_DIALECT, numbered_registers=False, final_pass=False):
- assert(dialect in ('intel', 'att'))
- if reg.upper() in NAMED_REGISTERS or (numbered_registers and reg.lower() in NUMBERED_REGISTERS):
- if dialect == 'intel':
- if final_pass:
- return reg
- else:
- return '%%%s' % reg
- elif dialect == 'att':
- return '%%%%%s' % reg
- if reg[:2] == '0x':
- if dialect == 'intel':
- return '%s' % reg
- elif dialect == 'att':
- return '$%s' % reg
- return '%%[%s]' % reg
-
-# args should be (outputs, inputs), as in intel syntax, regardless of what dialect says
-def print_instr(instr, args, comment=None, dialect=DEFAULT_DIALECT, do_print_val=True):
- if do_print_val:
- args = tuple(print_val(arg, dialect=dialect) for arg in args)
- if dialect == 'att':
- args = tuple(reversed(args))
- ret ='"%s %s\\t\\n"' % (instr, ', '.join(args))
- if comment is not None:
- ret += ' // %s' % comment
- ret += '\n'
- return ret
-
-def print_mov_no_adjust(reg_out, reg_in, comment=None, do_print_val=False):
- #return '%s <- MOV %s;\n' % (reg_out, reg_in)
- #ret, reg_in = print_load(reg_in)
- return print_instr('mov', (reg_out, reg_in), comment=comment, do_print_val=do_print_val)
-
-def print_mov(reg_out, reg_in):
- #return '%s <- MOV %s;\n' % (reg_out, reg_in)
- #ret, reg_in = print_load(reg_in)
- return print_mov_no_adjust(reg_out, reg_in, do_print_val=True)
-
-def print_load_constant(reg_out, imm):
- assert(imm[:2] == '0x')
- return print_mov_no_adjust(reg_out, imm, do_print_val=True)
-
-def print_load_specific_reg(reg, specific_reg='rdx'):
- ret = ''
- #ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
- if reg != specific_reg:
- ret += print_mov_no_adjust(specific_reg, reg, do_print_val=True)
- return ret, specific_reg
-def print_unload_specific_reg(specific_reg='rdx'):
- ret = ''
- #ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
- return ret
-#def get_arg_reg(d):
-# return 'arg%d' % d
-def print_load(reg, can_clobber=tuple(), dont_clobber=tuple()):
- assert(not isinstance(can_clobber, str))
- assert(not isinstance(dont_clobber, str))
- can_clobber = [i for i in reversed(can_clobber) if i not in dont_clobber]
- if reg in REGISTERS:
- return ('', reg)
- else:
- cur_reg = can_clobber.pop()
- ret = print_mov_no_adjust(print_val(cur_reg), print_val(reg))
- return (ret, cur_reg)
-
-def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src):
- #return '%s:%s <- MULX %s, %s; // %s\n' % (reg_out_low, reg_out_high, rx1, rx2, src)
- ret = ''
- ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx')
- assert(rx2 != actual_rx1)
- ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1])
- ret += ret2 + ret3 + print_instr('mulx', (reg_out_high, reg_out_low, actual_rx2), comment=src)
- ret += print_unload_specific_reg('rdx')
- return ret
-
-def print_mov_bucket(reg_out, reg_in, bucket):
- #return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket)
- #ret, reg_in = print_load(reg_in, can_clobber=[reg_out])
- return print_mov_no_adjust(print_val(reg_out), print_val(reg_in), 'bucket: ' + bucket)
-
-LAST_CARRY = None
-
-def print_imul_constant(reg_out, reg_in, imm, src):
- global LAST_CARRY
- LAST_CARRY = None
- ret = ''
- assert(imm[:2] == '0x')
- ret2, reg_in = print_load(reg_in, can_clobber=[reg_out])
- ret += ret2 + print_instr('imul', (reg_out, reg_in, imm), comment=src)
- return ret
-
-
-def print_mul_by_constant(reg_out, reg_in, constant, src):
- #return '%s <- MULX %s, %s; // %s\n' % (ret_out, reg_in, constant, src)
- ret = ''
- #if constant == '0x13':
- # ret += ('// FIXME: lea for %s\n' % src)
- assert(constant[:2] == '0x')
- #return ret + \
- # print_load_constant('rdx', constant) + \
- # print_mulx(reg_out, 'rdx', 'rdx', reg_in, src)
- return ret + \
- print_imul_constant(reg_out, reg_in, constant, src)
-
-def print_adx(reg_out, rx1, rx2, bucket):
- #return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
- assert(rx1 == reg_out)
- ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + print_instr('adx', (reg_out, rx2), 'bucket: ' + bucket)
-
-def print_adc(reg_out, carry_out, carry_in, rx1, rx2, bucket):
- #return '%s <- ADCX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
- global LAST_CARRY
- assert(LAST_CARRY == carry_in)
- LAST_CARRY = carry_out
- assert(rx1 == reg_out)
- ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket)
-
-def print_add(reg_out, cf, rx1, rx2, bucket):
- #return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket)
- global LAST_CARRY
- assert(reg_out == rx1)
- #assert(LAST_CARRY is None or LAST_CARRY == cf)
- LAST_CARRY = cf
- ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + print_instr('add', (reg_out, rx2), 'bucket: ' + bucket)
-
-def print_adc(reg_out, cf_out, cf_in, rx1, rx2, bucket):
- #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf_out, cf_in, rx1, rx2, bucket)
- assert(reg_out == rx1)
- ret = ''
- global LAST_CARRY
- if LAST_CARRY != cf_in:
- ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf_in)
- LAST_CARRY = cf_out
- ret2, rx2 = print_load(rx2, dont_clobber=[rx1])
- ret += ret2
- return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket)
-
-def print_adcx(reg_out, cf, bucket):
- #return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket)
- assert(LAST_CARRY == cf)
- return print_instr('adcx', (reg_out, '0x0'), 'bucket: ' + bucket)
-
-def print_and(reg_out, rx1, rx2, src):
- #return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src)
- global LAST_CARRY
- LAST_CARRY = None
- if reg_out != rx1:
- return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src)
- else:
- ret, rx2 = print_load(rx2, can_clobber=[reg_out, 'rdx'], dont_clobber=[rx1])
- return ret + print_instr('and', (reg_out, rx2), src)
-
-
-def print_shr(reg_out, rx1, imm, src):
- #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
- global LAST_CARRY
- LAST_CARRY = None
- assert(rx1 == reg_out)
- assert(imm[:2] == '0x')
- return print_instr('shr', (reg_out, imm), src)
-
-def print_shrd(reg_out, rx_low, rx_high, imm, src):
- #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
- global LAST_CARRY
- LAST_CARRY = None
- if rx_low != reg_out and rx_high == reg_out:
- return print_mov('rdx', rx_low) + \
- print_mov(rx_high, rx_low) + \
- print_mov(rx_low, 'rdx') + \
- print_shrd(reg_out, rx_high, rx_low, imm, src)
- assert(rx_low == reg_out)
- assert(imm[:2] == '0x')
- return print_instr('shrd', (rx_low, rx_high, imm), src)
-
-
-def schedule(input_data, existing, emit_vars):
- ret = ''
- buckets_seen = set()
- emit_vars = fix_emit_vars(emit_vars)
- ret += ('// Convention is low_reg:high_reg\n')
- for node in emit_vars:
- if node['op'] == 'INPUT':
- ret += print_input(existing[node['out']], node['out'])
- elif node['op'] == '*' and len(node['deps']) == 2:
- assert(len(existing[node['out']].split(':')) == 2)
- out_low, out_high = existing[node['out']].split(':')
- ret += print_mulx(out_low, out_high,
- existing[node['deps'][0]['out']],
- existing[node['deps'][1]['out']],
- '%s = %s * %s'
- % (node['out'],
- node['deps'][0]['out'],
- node['deps'][1]['out']))
- elif node['op'] == '*' and len(node['deps']) == 1:
- extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += print_mul_by_constant(existing[node['out']],
- existing[node['deps'][0]['out']],
- extra_arg,
- '%s = %s * %s'
- % (node['out'],
- node['deps'][0]['out'],
- extra_arg))
- elif node['op'] == '&' and len(node['deps']) == 1:
- extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += print_and(existing[node['out']],
- existing[node['deps'][0]['out']],
- extra_arg,
- '%s = %s & %s'
- % (node['out'],
- node['deps'][0]['out'],
- extra_arg))
- elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['op'] == 'COMBINE':
- extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += print_shrd(existing[node['out']],
- existing[node['deps'][0]['deps'][0]['out']],
- existing[node['deps'][0]['deps'][1]['out']],
- extra_arg,
- '%s = %s:%s >> %s'
- % (node['out'],
- node['deps'][0]['deps'][0]['out'],
- node['deps'][0]['deps'][1]['out'],
- extra_arg))
- elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['type'] == 'uint64_t':
- extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += print_shr(existing[node['out']],
- existing[node['deps'][0]['deps'][0]['out']],
- extra_arg,
- '%s = %s >> %s'
- % (node['out'],
- node['deps'][0]['deps'][0]['out'],
- extra_arg))
- elif node['op'] in ('GET_HIGH', 'GET_LOW'):
- if node['rev_deps'][0]['out'] not in buckets_seen:
- ret += print_mov_bucket(existing[node['rev_deps'][0]['out']],
- existing[node['out']],
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
- buckets_seen.add(node['rev_deps'][0]['out'])
- elif node['op'] == 'GET_HIGH':
- carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')]
- ret += print_adc(existing[node['rev_deps'][0]['out']],
- None,
- carry,
- existing[node['rev_deps'][0]['out']],
- existing[node['out']],
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
- elif node['op'] == 'GET_LOW':
- carry = 'c' + node['rev_deps'][0]['out'][:-len('_low')]
- ret += print_add(existing[node['rev_deps'][0]['out']],
- carry,
- existing[node['rev_deps'][0]['out']],
- existing[node['out']],
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
- elif node['op'] in ('GET_CARRY',):
- #carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')]
- #ret += print_adc(existing[node['rev_deps'][0]['out']],
- # carry,
- # ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
- pass
- elif node['op'] == '+' and len(node['extra_out']) > 0:
- pass
- elif node['op'] == '+' and len(node['deps']) == 2 and node['type'] == 'uint64_t':
- ret += print_add(existing[node['out']],
- None,
- existing[node['deps'][0]['out']],
- existing[node['deps'][1]['out']],
- '%s = %s + %s'
- % (node['out'],
- node['deps'][0]['out'],
- node['deps'][1]['out']))
- elif node['op'] in ('COMBINE',):
- pass
- else:
- raw_input((node['out'], node['op']))
- if node['op'] not in ('GET_HIGH', 'GET_LOW', 'COMBINE', 'GET_CARRY'):
- for rdep in node['rev_deps']:
- if len(rdep['extra_out']) > 0 and rdep['op'] == '+':
- if rdep['out'] not in buckets_seen:
- ret += print_mov_bucket(existing[rdep['out']],
- existing[node['out']],
- ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
- buckets_seen.add(rdep['out'])
- elif 'high' in rdep['out']:
- carry = 'c' + rdep['out'][:-len('_high')]
- ret += print_adc(existing[rdep['out']],
- None,
- carry,
- existing[rdep['out']],
- existing[node['out']],
- ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
- elif 'low' in rdep['out']:
- carry = 'c' + rdep['out'][:-len('_low')]
- ret += print_add(existing[rdep['out']],
- carry,
- existing[rdep['out']],
- existing[node['out']],
- ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
- else:
- assert(False)
- return ret
-
-def inline_schedule(sched, input_vars, output_vars):
- KNOWN_CONSTRAINTS = dict(('r%sx' % l, l) for l in 'abcd')
- variables = list(sorted(set(list(re.findall('%\[([a-zA-Z0-9_]*)\]', sched)) +
- list(re.findall('%([a-zA-Z0-9_]+)', sched))),
- key=int_or_zero_key))
- mems, variables = [i for i in variables if i[:2] == 'mx'], [i for i in variables if i[:2] != 'mx']
- special_reg, variables = [i for i in variables if i.upper() in NAMED_REGISTERS], [i for i in variables if i.upper() not in NAMED_REGISTERS]
- transient_regs, output_regs = [i for i in variables if i not in output_vars.values()], [i for i in variables if i in output_vars.keys()]
- available_registers = [NAMED_REGISTER_MAPPING.get('r%d' % i, 'r%d' % i).lower() for i in range(16)
- if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or (NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg
- and NAMED_REGISTER_MAPPING['r%d' % i] not in RESERVED_REGISTERS)]
- assert(len(available_registers) >= len(transient_regs))
- for reg in output_regs:
- sched = sched.replace('%%[%s]' % reg, '%%[r%s]' % output_vars[reg])
- available_registers = available_registers[-len(transient_regs):]
- assert(len(available_registers) > len(TO_BE_RESTORED_REGISTERS)) # makes the replacement of low registers with ones we have to handle specially easier
- count = len([reg for reg in TO_BE_RESTORED_REGISTERS if reg.lower() not in available_registers])
- available_registers = [reg.lower() for reg in TO_BE_RESTORED_REGISTERS] + \
- [reg for reg in available_registers[count:] if reg.upper() not in TO_BE_RESTORED_REGISTERS]
- renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):]))
- for from_reg, to_reg in renaming.items():
- sched = sched.replace('%%[%s]' % from_reg, print_val(to_reg, numbered_registers=True))
- transient_regs = [renaming[reg] for reg in transient_regs]
- for reg in REAL_REGISTERS:
- sched = sched.replace(print_val(reg.lower(), numbered_registers=True),
- print_val(reg.lower(), numbered_registers=True, final_pass=True))
- ret = ''
- ret += 'uint64_t %s;\n' % ', '.join(output_vars[reg] for reg in output_regs)
- ret += 'uint64_t %s;\n\n' % ', '.join(reg.lower() for reg in TO_BE_RESTORED_REGISTERS)
- ret += 'asm (\n'
- for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
- ret += print_mov_no_adjust('%%[%s]' % reg, print_val(reg, numbered_registers=True, final_pass=True))
- ret += sched
- for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
- ret += print_mov_no_adjust(print_val(reg, final_pass=True), '%%[%s]' % reg)
- ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n'
- ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars] +
- ['[%s] "m" (%s)' % (reg, reg) for reg in map(str.lower, TO_BE_RESTORED_REGISTERS)]) + '\n'
- ret += ': ' + ', '.join(['"cc"'] +
- ['"%s"' % reg for reg in special_reg] +
- ['"%s"' % reg for reg in transient_regs if reg.upper() not in TO_BE_RESTORED_REGISTERS]) + '\n'
- ret += ');\n'
- return ret
-
-if __name__ == '__main__':
- if len(sys.argv) != 3:
- print('USAGE: %s INPUT OUTPUT' % os.path.basename(__file__))
- sys.exit(1)
- in_file, out_file = sys.argv[1], sys.argv[2]
- data = parse_lines(get_lines(in_file))
- graph = to_graph(data)
- #possible_nodes = dict((n['out'], n)
- # for in_obj in graph['in'].values()
- # for n in in_obj['rev_deps']
- # if n['op'] == '*')
- #for var, node in list(possible_nodes.items()):
- # possible_nodes.update(dict((n['out'], n)
- # for n in node['rev_deps']
- # if n['op'] == '*'))
- #possible_nodes = list(sorted(possible_nodes.items()))
- #possible_nodes = [n for v, n in possible_nodes]
- in_nodes = tuple(graph['in'].values())
- existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars = {}, {}, tuple(), tuple(REGISTERS), tuple(), tuple(), tuple(), tuple()
- objs = get_objects(graph['out'].values())
- def vars_for(var, rec=True):
- pre_ret = [n['out'] for n in objs[var]['rev_deps']]
- ret = [v for v in pre_ret if 'tmp' in v]
- if rec:
- for v in pre_ret:
- if 'tmp' not in v:
- ret += list(vars_for(v, rec=False))
- return tuple(ret)
- def vars_for_bucket(var):
- if '_' not in var:
- return tuple(list(vars_for_bucket(var + '_low')) + list(vars_for_bucket(var + '_high')))
- ret = []
- for dep in objs[var]['deps']:
- if dep['op'] in ('GET_HIGH', 'GET_LOW'):
- assert(len(dep['deps']) == 1)
- assert('tmp' in dep['deps'][0]['out'])
- ret.append(dep['deps'][0]['out'])
- return tuple(ret)
- plus_deps = tuple(n for n in get_plus_deps(objs.values())
- if len(n['extra_out']) > 0)
- plus_deps = tuple(sorted(plus_deps, cmp=cmp_node_by_dep))
- for var in [v
- for n in plus_deps
- for v in vars_for_bucket(n['out'])]:
- cur_possible_nodes = [objs[var]] # [n for n in possible_nodes if n['out'] == var]
- cur_possible_nodes, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars \
- = allocate_one_subtree(in_nodes, cur_possible_nodes, existing, cur_map, free_temps, free_list, all_temps, freed, new_buckets, emit_vars)
- existing.update(cur_map)
- cur_map = {}
- sched = inline_schedule(schedule(data, existing, emit_vars),
- dict((existing[n['out']], n['out']) for n in graph['in'].values()),
- dict((existing[n['out']], n['out']) for n in graph['out'].values()))
- deps = adjust_bits(data, print_graph(graph, existing))
- with codecs.open(out_file, 'w', encoding='utf8') as f:
- f.write(data['header'] + '\n\n' + sched + '\n\n' + data['footer'] + '\n')