aboutsummaryrefslogtreecommitdiff
path: root/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py
diff options
context:
space:
mode:
Diffstat (limited to 'etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py')
-rwxr-xr-xetc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py180
1 files changed, 123 insertions, 57 deletions
diff --git a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py
index f37bc3ff1..1083846d8 100755
--- a/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py
+++ b/etc/compile-by-zinc/make-graph-with-reg-by-ac-buckets.py
@@ -6,10 +6,12 @@ import subprocess
LAMBDA = u'\u03bb'
-OP_NAMES = {'*':'MUL', '+':'ADD', '>>':'SHL', '<<':'SHR', '|':'OR', '&':'AND'}
-
+NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RSI', 'RDI')
+NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS))
REGISTERS = tuple(#['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] + #, 'RSP'] # RSP is stack pointer?
- ['r%d' % i for i in range(13)])
+ ['reg%d' % i for i in range(13)])
+#REAL_REGISTERS = tuple(['RAX', 'RBX', 'RCX', 'RDX', 'RSI', 'RDI', 'RBP'] + #, 'RSP'] # RSP is stack pointer?
+# ['reg%d' % i for i in range(13)])
REGISTER_COLORS = ['color="black"', 'color="white",fillcolor="black"', 'color="maroon"', 'color="green"', 'fillcolor="olive"',
'color="navy"', 'color="purple"', 'fillcolor="teal"', 'fillcolor="silver"', 'fillcolor="gray"', 'fillcolor="red"',
'fillcolor="lime"', 'fillcolor="yellow"', 'fillcolor="blue"', 'fillcolor="fuschia"', 'fillcolor="aqua"']
@@ -281,7 +283,7 @@ def allocate_node(existing, node, *args):
return do_ret()
if len(node['deps']) == 0 and node['op'] == 'INPUT':
assert(node['type'] == 'uint64_t')
- cur_map[node['out']] = 'r' + node['out'] # free_list.pop()
+ cur_map[node['out']] = 'm' + node['out'] # free_list.pop()
emit_vars.append(node)
return do_ret()
if is_temp(node):
@@ -513,65 +515,76 @@ def print_input(reg_out, mem_in):
#return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out)
return ""
+def print_val(reg):
+ if reg.upper() in NAMED_REGISTERS:
+ return '%%%s' % reg
+ if reg[:2] == '0x':
+ return '$%s' % reg
+ return '%%[%s]' % reg
+
def print_load_specific_reg(reg, specific_reg='rdx'):
ret = ''
- ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
- ret += '"mov %%[%s], %%%s\\t\\n"\n' % (reg, specific_reg)
- return ret, (specific_reg,)
+ #ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
+ if reg != specific_reg:
+ ret += '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(specific_reg))
+ return ret, specific_reg
def print_unload_specific_reg(specific_reg='rdx'):
ret = ''
- ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
+ #ret += '"mov %%[%s_backup], %%%s\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
return ret
-def print_load(*regs):
- TEMP_REG = ['arg%d' % d for d in reversed(range(15))]
- ret, out_reg = '', []
- for reg in regs:
- if reg in REGISTERS:
- out_reg.append(reg)
- continue
- else:
- cur_reg = TEMP_REG.pop()
- ret += '"mov %%[%s], %%[%s]\\t\\n"\n' % (reg, cur_reg)
- out_reg.append(cur_reg)
- if len(out_reg) == 1: return ret, out_reg[0]
- return ret, tuple(out_reg)
+#def get_arg_reg(d):
+# return 'arg%d' % d
+def print_load(reg, can_clobber=tuple(), dont_clobber=tuple()):
+ assert(not isinstance(can_clobber, str))
+ assert(not isinstance(dont_clobber, str))
+ can_clobber = [i for i in reversed(can_clobber) if i not in dont_clobber]
+ if reg in REGISTERS:
+ return ('', reg)
+ else:
+ cur_reg = can_clobber.pop()
+ ret = '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(cur_reg))
+ return (ret, cur_reg)
def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src):
#return '%s:%s <- MULX %s, %s; // %s\n' % (reg_out_low, reg_out_high, rx1, rx2, src)
ret = ''
ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx')
- ret3, actual_rx2 = print_load(rx2)
- ret += ret2 + ret3 + ('"mulx %%[%s], %%[%s], %%[%s]\\t\\n" // %s\n' % (actual_rx2, reg_out_high, reg_out_low, src))
+ assert(rx2 != actual_rx1)
+ ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1])
+ ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(actual_rx2), print_val(reg_out_high), print_val(reg_out_low), src))
ret += print_unload_specific_reg('rdx')
return ret
def print_mov_bucket(reg_out, reg_in, bucket):
#return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket)
- ret, reg_in = print_load(reg_in)
- return ret + ('"mov %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_in, reg_out, bucket))
+ #ret, reg_in = print_load(reg_in, can_clobber=[reg_out])
+ return ('"mov %s, %s\\t\\n" // bucket: %s\n' % (print_val(reg_in), print_val(reg_out), bucket))
def print_mov(reg_out, reg_in):
#return '%s <- MOV %s;\n' % (reg_out, reg_in)
- ret, reg_in = print_load(reg_in)
- return ret + ('"mov %%[%s], %%[%s]\\t\\n"\n' % (reg_in, reg_out))
+ #ret, reg_in = print_load(reg_in)
+ return ('"mov %s, %s\\t\\n"\n' % (print_val(reg_in), print_val(reg_out)))
+
+def print_load_constant(reg_out, imm):
+ assert(imm[:2] == '0x')
+ return ('"mov $%s, %s\\t\\n"\n' % (imm, print_val(reg_out)))
LAST_CARRY = None
def print_mul_by_constant(reg_out, reg_in, constant, src):
#return '%s <- MULX %s, %s; // %s\n' % (ret_out, reg_in, constant, src)
- #assert(LAST_CARRY is None)
- global LAST_CARRY
- ret, reg_in = print_load(reg_in)
+ ret = ''
if constant == '0x13':
- return ret + ('FIXME: lea for %s\n' % src)
- else:
- LAST_CARRY = None
- return ret + ('"imul %%[%s], $%s, %%[%s]\\t\\n" // %s\n' % (reg_in, constant, reg_out, src))
+ ret += ('// FIXME: lea for %s\n' % src)
+ assert(constant[:2] == '0x')
+ return ret + \
+ print_load_constant('rdx', constant) + \
+ print_mulx(reg_out, 'rdx', 'rdx', reg_in, src)
def print_adx(reg_out, rx1, rx2, bucket):
#return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
assert(rx1 == reg_out)
- ret, rx2 = print_load(rx2)
+ ret, rx2 = print_load(rx2, dont_clobber=[rx1])
return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
def print_add(reg_out, cf, rx1, rx2, bucket):
@@ -580,7 +593,7 @@ def print_add(reg_out, cf, rx1, rx2, bucket):
assert(reg_out == rx1)
#assert(LAST_CARRY is None or LAST_CARRY == cf)
LAST_CARRY = cf
- ret, rx2 = print_load(rx2)
+ ret, rx2 = print_load(rx2, dont_clobber=[rx1])
return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
def print_adc(reg_out, cf, rx1, rx2, bucket):
@@ -591,7 +604,7 @@ def print_adc(reg_out, cf, rx1, rx2, bucket):
if LAST_CARRY != cf:
ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf)
LAST_CARRY = cf
- ret2, rx2 = print_load(rx2)
+ ret2, rx2 = print_load(rx2, dont_clobber=[rx1])
ret += ret2
return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
@@ -610,11 +623,30 @@ def print_and(reg_out, rx1, rx2, src):
if rx2[:2] == '0x':
return ('"and $%s, %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src))
else:
- ret, rx2 = print_load(rx2)
+ ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1])
return ret + ('"and %%[%s], %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src))
-#def print_shr(reg_out, rx1, imm, src):
- #return '%s <- SHR %s, %s;\n' %
+def print_shr(reg_out, rx1, imm, src):
+ #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
+ global LAST_CARRY
+ LAST_CARRY = None
+ assert(rx1 == reg_out)
+ assert(imm[:2] == '0x')
+ return ('"shr $%s, %%[%s]\\t\\n" // %s\n' % (imm, reg_out, src))
+
+def print_shrd(reg_out, rx_low, rx_high, imm, src):
+ #return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
+ global LAST_CARRY
+ LAST_CARRY = None
+ if rx_low != reg_out and rx_high == reg_out:
+ return print_mov('rdx', rx_low) + \
+ print_mov(rx_high, rx_low) + \
+ print_mov(rx_low, 'rdx') + \
+ print_shrd(reg_out, rx_high, rx_low, imm, src)
+ assert(rx_low == reg_out)
+ assert(imm[:2] == '0x')
+ return ('"shrd $%s, %%[%s], %%[%s]\\t\\n" // %s\n' % (imm, rx_low, rx_high, src))
+
def schedule(input_data, existing, emit_vars):
ret = ''
@@ -655,24 +687,24 @@ def schedule(input_data, existing, emit_vars):
extra_arg))
elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['op'] == 'COMBINE':
extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += ('%s <- SHR %s:%s, %s; // %s = %s:%s >> %s\n'
- % (existing[node['out']],
- existing[node['deps'][0]['deps'][0]['out']],
- existing[node['deps'][0]['deps'][1]['out']],
- extra_arg,
- node['out'],
- node['deps'][0]['deps'][0]['out'],
- node['deps'][0]['deps'][1]['out'],
- extra_arg))
+ ret += print_shrd(existing[node['out']],
+ existing[node['deps'][0]['deps'][0]['out']],
+ existing[node['deps'][0]['deps'][1]['out']],
+ extra_arg,
+ '%s = %s:%s >> %s'
+ % (node['out'],
+ node['deps'][0]['deps'][0]['out'],
+ node['deps'][0]['deps'][1]['out'],
+ extra_arg))
elif node['op'] == '>>' and len(node['deps']) == 1 and node['deps'][0]['type'] == 'uint64_t':
extra_arg = [arg for arg in line_of_var(data, node['out'])['args'] if arg[:2] == '0x'][0]
- ret += ('%s <- SHR %s, %s; // %s = %s >> %s\n'
- % (existing[node['out']],
- existing[node['deps'][0]['deps'][0]['out']],
- extra_arg,
- node['out'],
- node['deps'][0]['deps'][0]['out'],
- extra_arg))
+ ret += print_shr(existing[node['out']],
+ existing[node['deps'][0]['deps'][0]['out']],
+ extra_arg,
+ '%s = %s >> %s'
+ % (node['out'],
+ node['deps'][0]['deps'][0]['out'],
+ extra_arg))
elif node['op'] in ('GET_HIGH', 'GET_LOW'):
if node['rev_deps'][0]['out'] not in buckets_seen:
ret += print_mov_bucket(existing[node['rev_deps'][0]['out']],
@@ -750,6 +782,38 @@ def schedule(input_data, existing, emit_vars):
assert(False)
return ret
+def inline_schedule(sched, input_vars, output_vars):
+ KNOWN_CONSTRAINTS = dict(('r%sx' % l, l) for l in 'abcd')
+ def int_or_zero_key(v):
+ orig = v
+ v = v.strip('abcdefghijklmnopqrstuvwxyz')
+ if v.isdigit(): return (int(v), orig)
+ return (0, orig)
+ variables = list(sorted(set(list(re.findall('%\[([a-zA-Z0-9_]*)\]', sched)) +
+ list(re.findall('%([a-zA-Z0-9_]+)', sched))),
+ key=int_or_zero_key))
+ mems, variables = [i for i in variables if i[:2] == 'mx'], [i for i in variables if i[:2] != 'mx']
+ special_reg, variables = [i for i in variables if i.upper() in NAMED_REGISTERS], [i for i in variables if i.upper() not in NAMED_REGISTERS]
+ transient_regs, output_regs = [i for i in variables if i not in output_vars.values()], [i for i in variables if i in output_vars.keys()]
+ available_registers = ['r%d' % i for i in range(16)
+ if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg]
+ for reg in output_regs:
+ sched = sched.replace('%%[%s]' % reg, '%%[r%s]' % output_vars[reg])
+ renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):]))
+ for from_reg, to_reg in renaming.items():
+ sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg)
+ transient_regs = [renaming[reg] for reg in transient_regs]
+ ret = ''
+ ret += 'asm (\n'
+ ret += sched
+ ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n'
+ ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n'
+ ret += ': ' + ', '.join(['"cc"'] +
+ ['"%s"' % reg for reg in special_reg] +
+ ['"%s"' % reg for reg in transient_regs]) + '\n'
+ ret += ');\n'
+ return ret
+
data_list = parse_lines(get_lines('femulDisplay.log'))
for i, data in enumerate(data_list):
graph = to_graph(data)
@@ -807,7 +871,9 @@ for i, data in enumerate(data_list):
#mul_node = possible_nodes[0]
#print([n['out'] for n in mul_node['deps']])
#cur_map, free_temps, free_list, all_temps = allocate_subgraph(existing, mul_node, cur_map, free_temps, free_list, all_temps)
- sched = schedule(data, existing, emit_vars)
+ sched = inline_schedule(schedule(data, existing, emit_vars),
+ dict((existing[n['out']], n['out']) for n in graph['in'].values()),
+ dict((existing[n['out']], n['out']) for n in graph['out'].values()))
#fill_deps(buckets[0])
deps = adjust_bits(data, print_graph(graph, existing))
with codecs.open('femulData%d.dot' % i, 'w', encoding='utf8') as f: