about | summary | refs | log | tree | commit | diff
path: root/register-allocate.py
diff options
context:
space:
mode:
author: Jason Gross <jgross@mit.edu> 2017-09-12 20:06:44 -0400
committer: Jason Gross <jgross@mit.edu> 2017-09-12 20:06:44 -0400
commit: 2bbbfed14c2d45fe5a1be6e079408b7be7c33587 (patch)
tree: d9ebd2edacad5535efc4ad1667eead6ed9475a0e /register-allocate.py
parent: 98daac5a86136cf0f1018b292baf8a676f6fd579 (diff)
Be better about asm syntax dialects
Diffstat (limited to 'register-allocate.py')
-rwxr-xr-x register-allocate.py 168
1 files changed, 101 insertions, 67 deletions
diff --git a/register-allocate.py b/register-allocate.py
index 16faee062..e05e346cd 100755
--- a/register-allocate.py
+++ b/register-allocate.py
@@ -5,11 +5,13 @@ import codecs, re, sys, os
LAMBDA = u'\u03bb'
NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RBP', 'RSI', 'RDI')
+NUMBERED_REGISTERS = tuple('r%d' % i for i in range(16))
RESERVED_REGISTERS = ('RSP', )
TO_BE_RESTORED_REGISTERS = ('RBP', )
NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS))
-REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + ['r%d' % i for i in range(8, 16)])
+REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + list(NUMBERED_REGISTERS))
REGISTERS = ['reg%d' % i for i in range(13)]
+DEFAULT_DIALECT = 'att'
def get_lines(filename):
with codecs.open(filename, 'r', encoding='utf8') as f:
@@ -516,18 +518,33 @@ def fix_emit_vars(emit_vars):
ret = []
waiting = []
seen = set()
+ get_high_waiting = None
for node in emit_vars:
waiting.append(node)
+ early_new_waiting = []
new_waiting = []
for wnode in waiting:
if wnode['out'] in seen:
continue
+ elif wnode['op'] == 'GET_HIGH' and wnode['deps'][0]['out'] == get_high_waiting:
+ ret.append(wnode)
+ seen.add(wnode['out'])
+ get_high_waiting = None
+ elif wnode['op'] == 'GET_HIGH' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+':
+ new_waiting.append(wnode)
+ elif get_high_waiting is None and wnode['op'] == 'GET_LOW' and len(wnode['rev_deps']) > 0 and wnode['rev_deps'][0]['op'] == '+':
+ ret.append(wnode)
+ seen.add(wnode['out'])
+ assert(len(wnode['deps']) == 1)
+ get_high_waiting = wnode['deps'][0]['out']
+ elif get_high_waiting is not None:
+ new_waiting.append(wnode)
elif all(dep['out'] in seen for dep in wnode['deps']):
ret.append(wnode)
seen.add(wnode['out'])
else:
new_waiting.append(wnode)
- waiting = new_waiting
+ waiting = early_new_waiting + new_waiting
while len(waiting) > 0:
# print('Waiting on...')
# print(list(sorted(node['out'] for node in waiting)))
@@ -548,36 +565,54 @@ def print_input(reg_out, mem_in):
#return '"mov %%[%s], %%[%s]\\n\\t"\n' % (mem_in, reg_out)
return ""
-def print_val(reg):
- if reg.upper() in NAMED_REGISTERS:
- return '%%%s' % reg
+def print_val(reg, dialect=DEFAULT_DIALECT, numbered_registers=False, final_pass=False):
+ assert(dialect in ('intel', 'att'))
+ if reg.upper() in NAMED_REGISTERS or (numbered_registers and reg.lower() in NUMBERED_REGISTERS):
+ if dialect == 'intel':
+ if final_pass:
+ return reg
+ else:
+ return '%%%s' % reg
+ elif dialect == 'att':
+ return '%%%%%s' % reg
if reg[:2] == '0x':
- return '$%s' % reg
+ if dialect == 'intel':
+ return '%s' % reg
+ elif dialect == 'att':
+ return '$%s' % reg
return '%%[%s]' % reg
-def print_mov_no_adjust(reg_out, reg_in, comment=None):
- #return '%s <- MOV %s;\n' % (reg_out, reg_in)
- #ret, reg_in = print_load(reg_in)
- ret = '"mov %s, %s\\t\\n"' % (reg_out, reg_in)
+# args should be (outputs, inputs), as in intel syntax, regardless of what dialect says
+def print_instr(instr, args, comment=None, dialect=DEFAULT_DIALECT, do_print_val=True):
+ if do_print_val:
+ args = tuple(print_val(arg, dialect=dialect) for arg in args)
+ if dialect == 'att':
+ args = tuple(reversed(args))
+ ret ='"%s %s\\t\\n"' % (instr, ', '.join(args))
if comment is not None:
ret += ' // %s' % comment
ret += '\n'
return ret
+def print_mov_no_adjust(reg_out, reg_in, comment=None, do_print_val=False):
+ #return '%s <- MOV %s;\n' % (reg_out, reg_in)
+ #ret, reg_in = print_load(reg_in)
+ return print_instr('mov', (reg_out, reg_in), comment=comment, do_print_val=do_print_val)
+
def print_mov(reg_out, reg_in):
#return '%s <- MOV %s;\n' % (reg_out, reg_in)
#ret, reg_in = print_load(reg_in)
- return print_mov_no_adjust(print_val(reg_out), print_val(reg_in))
+ return print_mov_no_adjust(reg_out, reg_in, do_print_val=True)
def print_load_constant(reg_out, imm):
assert(imm[:2] == '0x')
- return print_mov_no_adjust(print_val(reg_out), print_val(imm))
+ return print_mov_no_adjust(reg_out, imm, do_print_val=True)
def print_load_specific_reg(reg, specific_reg='rdx'):
ret = ''
#ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
if reg != specific_reg:
- ret += print_mov_no_adjust(print_val(specific_reg), print_val(reg))
+ ret += print_mov_no_adjust(specific_reg, reg, do_print_val=True)
return ret, specific_reg
def print_unload_specific_reg(specific_reg='rdx'):
ret = ''
@@ -602,7 +637,7 @@ def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src):
ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx')
assert(rx2 != actual_rx1)
ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1])
- ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(reg_out_high), print_val(reg_out_low), print_val(actual_rx2), src))
+ ret += ret2 + ret3 + print_instr('mulx', (reg_out_high, reg_out_low, actual_rx2), comment=src)
ret += print_unload_specific_reg('rdx')
return ret
@@ -627,7 +662,16 @@ def print_adx(reg_out, rx1, rx2, bucket):
#return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
assert(rx1 == reg_out)
ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
+ return ret + print_instr('adx', (reg_out, rx2), 'bucket: ' + bucket)
+
+def print_adc(reg_out, carry_out, carry_in, rx1, rx2, bucket):
+ #return '%s <- ADCX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
+ global LAST_CARRY
+ assert(LAST_CARRY == carry_in)
+ LAST_CARRY = carry_out
+ assert(rx1 == reg_out)
+ ret, rx2 = print_load(rx2, dont_clobber=[rx1])
+ return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket)
def print_add(reg_out, cf, rx1, rx2, bucket):
#return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket)
@@ -636,24 +680,24 @@ def print_add(reg_out, cf, rx1, rx2, bucket):
#assert(LAST_CARRY is None or LAST_CARRY == cf)
LAST_CARRY = cf
ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
+ return ret + print_instr('add', (reg_out, rx2), 'bucket: ' + bucket)
-def print_adc(reg_out, cf, rx1, rx2, bucket):
- #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf, cf, rx1, rx2, bucket)
+def print_adc(reg_out, cf_out, cf_in, rx1, rx2, bucket):
+ #return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf_out, cf_in, rx1, rx2, bucket)
assert(reg_out == rx1)
ret = ''
global LAST_CARRY
- if LAST_CARRY != cf:
- ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf)
- LAST_CARRY = cf
+ if LAST_CARRY != cf_in:
+ ret += 'ERRRRRRROR: %s != %s\n' % (LAST_CARRY, cf_in)
+ LAST_CARRY = cf_out
ret2, rx2 = print_load(rx2, dont_clobber=[rx1])
ret += ret2
- return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
+ return ret + print_instr('adc', (reg_out, rx2), 'bucket: ' + bucket)
def print_adcx(reg_out, cf, bucket):
#return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket)
assert(LAST_CARRY == cf)
- return ('"adcx %%[%s], $0\\t\\n" // bucket: %s\n' % (reg_out, bucket))
+ return print_instr('adcx', (reg_out, '0x0'), 'bucket: ' + bucket)
def print_and(reg_out, rx1, rx2, src):
#return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src)
@@ -662,10 +706,8 @@ def print_and(reg_out, rx1, rx2, src):
if reg_out != rx1:
return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src)
else:
- ret = ''
- if rx2[:2] != '0x':
- ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1])
- return ret + ('"and %s, %s\\t\\n" // %s\n' % (print_val(reg_out), print_val(rx2), src))
+ ret, rx2 = print_load(rx2, can_clobber=[reg_out, 'rdx'], dont_clobber=[rx1])
+ return ret + print_instr('and', (reg_out, rx2), src)
def print_shr(reg_out, rx1, imm, src):
@@ -674,7 +716,7 @@ def print_shr(reg_out, rx1, imm, src):
LAST_CARRY = None
assert(rx1 == reg_out)
assert(imm[:2] == '0x')
- return ('"shr %%[%s], $%s\\t\\n" // %s\n' % (reg_out, imm, src))
+ return print_instr('shr', (reg_out, imm), src)
def print_shrd(reg_out, rx_low, rx_high, imm, src):
#return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
@@ -687,13 +729,12 @@ def print_shrd(reg_out, rx_low, rx_high, imm, src):
print_shrd(reg_out, rx_high, rx_low, imm, src)
assert(rx_low == reg_out)
assert(imm[:2] == '0x')
- return ('"shrd %%[%s], %%[%s], $%s\\t\\n" // %s\n' % (rx_low, rx_high, imm, src))
+ return print_instr('shrd', (rx_low, rx_high, imm), src)
def schedule(input_data, existing, emit_vars):
ret = ''
buckets_seen = set()
- buckets_carried = set()
emit_vars = fix_emit_vars(emit_vars)
ret += ('// Convention is low_reg:high_reg\n')
for node in emit_vars:
@@ -754,34 +795,31 @@ def schedule(input_data, existing, emit_vars):
' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
buckets_seen.add(node['rev_deps'][0]['out'])
elif node['op'] == 'GET_HIGH':
- ret += print_adx(existing[node['rev_deps'][0]['out']],
+ carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')]
+ ret += print_adc(existing[node['rev_deps'][0]['out']],
+ None,
+ carry,
existing[node['rev_deps'][0]['out']],
existing[node['out']],
' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
elif node['op'] == 'GET_LOW':
carry = 'c' + node['rev_deps'][0]['out'][:-len('_low')]
- if node['rev_deps'][0]['out'] not in buckets_carried:
- ret += print_add(existing[node['rev_deps'][0]['out']],
- carry,
- existing[node['rev_deps'][0]['out']],
- existing[node['out']],
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
- buckets_carried.add(node['rev_deps'][0]['out'])
- else:
- ret += print_adc(existing[node['rev_deps'][0]['out']],
- carry,
- existing[node['rev_deps'][0]['out']],
- existing[node['out']],
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
+ ret += print_add(existing[node['rev_deps'][0]['out']],
+ carry,
+ existing[node['rev_deps'][0]['out']],
+ existing[node['out']],
+ ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
elif node['op'] in ('GET_CARRY',):
- carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')]
- ret += print_adcx(existing[node['rev_deps'][0]['out']],
- carry,
- ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
+ #carry = 'c' + node['rev_deps'][0]['out'][:-len('_high')]
+ #ret += print_adc(existing[node['rev_deps'][0]['out']],
+ # carry,
+ # ' + '.join(sorted([node['rev_deps'][0]['out']] + list(node['rev_deps'][0]['extra_out']))))
+ pass
elif node['op'] == '+' and len(node['extra_out']) > 0:
pass
elif node['op'] == '+' and len(node['deps']) == 2 and node['type'] == 'uint64_t':
- ret += print_adx(existing[node['out']],
+ ret += print_add(existing[node['out']],
+ None,
existing[node['deps'][0]['out']],
existing[node['deps'][1]['out']],
'%s = %s + %s'
@@ -801,25 +839,20 @@ def schedule(input_data, existing, emit_vars):
' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
buckets_seen.add(rdep['out'])
elif 'high' in rdep['out']:
- ret += print_adx(existing[rdep['out']],
+ carry = 'c' + rdep['out'][:-len('_high')]
+ ret += print_adc(existing[rdep['out']],
+ None,
+ carry,
existing[rdep['out']],
existing[node['out']],
' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
elif 'low' in rdep['out']:
carry = 'c' + rdep['out'][:-len('_low')]
- if rdep['out'] not in buckets_carried:
- ret += print_add(existing[rdep['out']],
- carry,
- existing[rdep['out']],
- existing[node['out']],
- ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
- buckets_carried.add(rdep['out'])
- else:
- ret += print_adc(existing[rdep['out']],
- carry,
- existing[rdep['out']],
- existing[node['out']],
- ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
+ ret += print_add(existing[rdep['out']],
+ carry,
+ existing[rdep['out']],
+ existing[node['out']],
+ ' + '.join(sorted([rdep['out']] + list(rdep['extra_out']))))
else:
assert(False)
return ret
@@ -845,19 +878,20 @@ def inline_schedule(sched, input_vars, output_vars):
[reg for reg in available_registers[count:] if reg.upper() not in TO_BE_RESTORED_REGISTERS]
renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):]))
for from_reg, to_reg in renaming.items():
- sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg)
+ sched = sched.replace('%%[%s]' % from_reg, print_val(to_reg, numbered_registers=True))
transient_regs = [renaming[reg] for reg in transient_regs]
for reg in REAL_REGISTERS:
- sched = sched.replace('%' + reg.lower(), reg.lower())
+ sched = sched.replace(print_val(reg.lower(), numbered_registers=True),
+ print_val(reg.lower(), numbered_registers=True, final_pass=True))
ret = ''
ret += 'uint64_t %s;\n' % ', '.join(output_vars[reg] for reg in output_regs)
ret += 'uint64_t %s;\n\n' % ', '.join(reg.lower() for reg in TO_BE_RESTORED_REGISTERS)
ret += 'asm (\n'
for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
- ret += print_mov_no_adjust('%%[%s]' % reg, reg)
+ ret += print_mov_no_adjust('%%[%s]' % reg, print_val(reg, numbered_registers=True, final_pass=True))
ret += sched
for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
- ret += print_mov_no_adjust(reg, '%%[%s]' % reg)
+ ret += print_mov_no_adjust(print_val(reg, final_pass=True), '%%[%s]' % reg)
ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n'
ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars] +
['[%s] "m" (%s)' % (reg, reg) for reg in map(str.lower, TO_BE_RESTORED_REGISTERS)]) + '\n'