aboutsummaryrefslogtreecommitdiff
path: root/register-allocate.py
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-09-12 16:33:09 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-09-12 16:33:09 -0400
commit98daac5a86136cf0f1018b292baf8a676f6fd579 (patch)
tree51b0f170062345236e568f9c4667bb5625bf527d /register-allocate.py
parent3c10ad879925d3d6410e090c3b0606be8a9c4a2d (diff)
Update register allocation more
Switch over to intel syntax, because I can't figure out how to name registers in AT&T / GAS.
Diffstat (limited to 'register-allocate.py')
-rwxr-xr-xregister-allocate.py84
1 files changed, 56 insertions, 28 deletions
diff --git a/register-allocate.py b/register-allocate.py
index 69f7ef7e9..16faee062 100755
--- a/register-allocate.py
+++ b/register-allocate.py
@@ -4,8 +4,11 @@ import codecs, re, sys, os
LAMBDA = u'\u03bb'
-NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RSI', 'RDI')
+NAMED_REGISTERS = ('RAX', 'RCX', 'RDX', 'RBX', 'RSP', 'RBP', 'RSI', 'RDI')
+RESERVED_REGISTERS = ('RSP', )
+TO_BE_RESTORED_REGISTERS = ('RBP', )
NAMED_REGISTER_MAPPING = dict(('r%d' % i, reg) for i, reg in enumerate(NAMED_REGISTERS))
+REAL_REGISTERS = tuple(list(NAMED_REGISTERS) + ['r%d' % i for i in range(8, 16)])
REGISTERS = ['reg%d' % i for i in range(13)]
def get_lines(filename):
@@ -552,11 +555,29 @@ def print_val(reg):
return '$%s' % reg
return '%%[%s]' % reg
+def print_mov_no_adjust(reg_out, reg_in, comment=None):
+ #return '%s <- MOV %s;\n' % (reg_out, reg_in)
+ #ret, reg_in = print_load(reg_in)
+ ret = '"mov %s, %s\\t\\n"' % (reg_out, reg_in)
+ if comment is not None:
+ ret += ' // %s' % comment
+ ret += '\n'
+ return ret
+
+def print_mov(reg_out, reg_in):
+ #return '%s <- MOV %s;\n' % (reg_out, reg_in)
+ #ret, reg_in = print_load(reg_in)
+ return print_mov_no_adjust(print_val(reg_out), print_val(reg_in))
+
+def print_load_constant(reg_out, imm):
+ assert(imm[:2] == '0x')
+ return print_mov_no_adjust(print_val(reg_out), print_val(imm))
+
def print_load_specific_reg(reg, specific_reg='rdx'):
ret = ''
#ret += '"mov %%%s, %%[%s_backup]\\t\\n" // XXX: How do I specify that a particular register should be %s?\n' % (specific_reg, specific_reg, specific_reg)
if reg != specific_reg:
- ret += '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(specific_reg))
+ ret += print_mov_no_adjust(print_val(specific_reg), print_val(reg))
return ret, specific_reg
def print_unload_specific_reg(specific_reg='rdx'):
ret = ''
@@ -572,7 +593,7 @@ def print_load(reg, can_clobber=tuple(), dont_clobber=tuple()):
return ('', reg)
else:
cur_reg = can_clobber.pop()
- ret = '"mov %s, %s\\t\\n"\n' % (print_val(reg), print_val(cur_reg))
+ ret = print_mov_no_adjust(print_val(cur_reg), print_val(reg))
return (ret, cur_reg)
def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src):
@@ -581,23 +602,14 @@ def print_mulx(reg_out_low, reg_out_high, rx1, rx2, src):
ret2, actual_rx1 = print_load_specific_reg(rx1, 'rdx')
assert(rx2 != actual_rx1)
ret3, actual_rx2 = print_load(rx2, can_clobber=[reg_out_high, reg_out_low], dont_clobber=[actual_rx1])
- ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(actual_rx2), print_val(reg_out_high), print_val(reg_out_low), src))
+ ret += ret2 + ret3 + ('"mulx %s, %s, %s\\t\\n" // %s\n' % (print_val(reg_out_high), print_val(reg_out_low), print_val(actual_rx2), src))
ret += print_unload_specific_reg('rdx')
return ret
def print_mov_bucket(reg_out, reg_in, bucket):
#return '%s <- MOV %s; // bucket: %s\n' % (reg_out, reg_in, bucket)
#ret, reg_in = print_load(reg_in, can_clobber=[reg_out])
- return ('"mov %s, %s\\t\\n" // bucket: %s\n' % (print_val(reg_in), print_val(reg_out), bucket))
-
-def print_mov(reg_out, reg_in):
- #return '%s <- MOV %s;\n' % (reg_out, reg_in)
- #ret, reg_in = print_load(reg_in)
- return ('"mov %s, %s\\t\\n"\n' % (print_val(reg_in), print_val(reg_out)))
-
-def print_load_constant(reg_out, imm):
- assert(imm[:2] == '0x')
- return ('"mov $%s, %s\\t\\n"\n' % (imm, print_val(reg_out)))
+ return print_mov_no_adjust(print_val(reg_out), print_val(reg_in), 'bucket: ' + bucket)
LAST_CARRY = None
@@ -615,7 +627,7 @@ def print_adx(reg_out, rx1, rx2, bucket):
#return '%s <- ADX %s, %s; // bucket: %s\n' % (reg_out, rx1, rx2, bucket)
assert(rx1 == reg_out)
ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
+ return ret + ('"adx %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
def print_add(reg_out, cf, rx1, rx2, bucket):
#return '%s, (%s) <- ADD %s, %s; // bucket: %s\n' % (reg_out, cf, rx1, rx2, bucket)
@@ -624,7 +636,7 @@ def print_add(reg_out, cf, rx1, rx2, bucket):
#assert(LAST_CARRY is None or LAST_CARRY == cf)
LAST_CARRY = cf
ret, rx2 = print_load(rx2, dont_clobber=[rx1])
- return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
+ return ret + ('"add %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
def print_adc(reg_out, cf, rx1, rx2, bucket):
#return '%s, (%s) <- ADC (%s), %s, %s; // bucket: %s\n' % (reg_out, cf, cf, rx1, rx2, bucket)
@@ -636,12 +648,12 @@ def print_adc(reg_out, cf, rx1, rx2, bucket):
LAST_CARRY = cf
ret2, rx2 = print_load(rx2, dont_clobber=[rx1])
ret += ret2
- return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (rx2, reg_out, bucket))
+ return ret + ('"adc %%[%s], %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, rx2, bucket))
def print_adcx(reg_out, cf, bucket):
#return '%s <- ADCX (%s), %s, 0x0; // bucket: %s\n' % (reg_out, cf, reg_out, bucket)
assert(LAST_CARRY == cf)
- return ('"adcx $0, %%[%s]\\t\\n" // bucket: %s\n' % (reg_out, bucket))
+ return ('"adcx %%[%s], $0\\t\\n" // bucket: %s\n' % (reg_out, bucket))
def print_and(reg_out, rx1, rx2, src):
#return '%s <- AND %s, %s; // %s\n' % (reg_out, rx1, rx2, src)
@@ -650,11 +662,11 @@ def print_and(reg_out, rx1, rx2, src):
if reg_out != rx1:
return print_mov(reg_out, rx1) + print_and(reg_out, reg_out, rx2, src)
else:
- if rx2[:2] == '0x':
- return ('"and $%s, %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src))
- else:
+ ret = ''
+ if rx2[:2] != '0x':
ret, rx2 = print_load(rx2, can_clobber=[reg_out], dont_clobber=[rx1])
- return ret + ('"and %%[%s], %%[%s]\\t\\n" // %s\n' % (rx2, reg_out, src))
+ return ret + ('"and %s, %s\\t\\n" // %s\n' % (print_val(reg_out), print_val(rx2), src))
+
def print_shr(reg_out, rx1, imm, src):
#return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
@@ -662,7 +674,7 @@ def print_shr(reg_out, rx1, imm, src):
LAST_CARRY = None
assert(rx1 == reg_out)
assert(imm[:2] == '0x')
- return ('"shr $%s, %%[%s]\\t\\n" // %s\n' % (imm, reg_out, src))
+ return ('"shr %%[%s], $%s\\t\\n" // %s\n' % (reg_out, imm, src))
def print_shrd(reg_out, rx_low, rx_high, imm, src):
#return '%s <- SHR %s, %s; // %s\n' % (reg_out, rx1, imm, src)
@@ -675,7 +687,7 @@ def print_shrd(reg_out, rx_low, rx_high, imm, src):
print_shrd(reg_out, rx_high, rx_low, imm, src)
assert(rx_low == reg_out)
assert(imm[:2] == '0x')
- return ('"shrd $%s, %%[%s], %%[%s]\\t\\n" // %s\n' % (imm, rx_low, rx_high, src))
+ return ('"shrd %%[%s], %%[%s], $%s\\t\\n" // %s\n' % (rx_low, rx_high, imm, src))
def schedule(input_data, existing, emit_vars):
@@ -820,19 +832,35 @@ def inline_schedule(sched, input_vars, output_vars):
mems, variables = [i for i in variables if i[:2] == 'mx'], [i for i in variables if i[:2] != 'mx']
special_reg, variables = [i for i in variables if i.upper() in NAMED_REGISTERS], [i for i in variables if i.upper() not in NAMED_REGISTERS]
transient_regs, output_regs = [i for i in variables if i not in output_vars.values()], [i for i in variables if i in output_vars.keys()]
- available_registers = ['r%d' % i for i in range(16)
- if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg]
+ available_registers = [NAMED_REGISTER_MAPPING.get('r%d' % i, 'r%d' % i).lower() for i in range(16)
+ if ('r%d' % i) not in NAMED_REGISTER_MAPPING.keys() or (NAMED_REGISTER_MAPPING['r%d' % i].lower() not in special_reg
+ and NAMED_REGISTER_MAPPING['r%d' % i] not in RESERVED_REGISTERS)]
+ assert(len(available_registers) >= len(transient_regs))
for reg in output_regs:
sched = sched.replace('%%[%s]' % reg, '%%[r%s]' % output_vars[reg])
+ available_registers = available_registers[-len(transient_regs):]
+ assert(len(available_registers) > len(TO_BE_RESTORED_REGISTERS)) # makes the replacement of low registers with ones we have to handle specially easier
+ count = len([reg for reg in TO_BE_RESTORED_REGISTERS if reg.lower() not in available_registers])
+ available_registers = [reg.lower() for reg in TO_BE_RESTORED_REGISTERS] + \
+ [reg for reg in available_registers[count:] if reg.upper() not in TO_BE_RESTORED_REGISTERS]
renaming = dict((from_reg, to_reg) for from_reg, to_reg in zip(transient_regs, available_registers[-len(transient_regs):]))
for from_reg, to_reg in renaming.items():
sched = sched.replace('%%[%s]' % from_reg, '%%%s' % to_reg)
transient_regs = [renaming[reg] for reg in transient_regs]
+ for reg in REAL_REGISTERS:
+ sched = sched.replace('%' + reg.lower(), reg.lower())
ret = ''
- ret += '__asm__ (\n'
+ ret += 'uint64_t %s;\n' % ', '.join(output_vars[reg] for reg in output_regs)
+ ret += 'uint64_t %s;\n\n' % ', '.join(reg.lower() for reg in TO_BE_RESTORED_REGISTERS)
+ ret += 'asm (\n'
+ for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
+ ret += print_mov_no_adjust('%%[%s]' % reg, reg)
ret += sched
+ for reg in map(str.lower, TO_BE_RESTORED_REGISTERS):
+ ret += print_mov_no_adjust(reg, '%%[%s]' % reg)
ret += ': ' + ', '.join(['[r%s] "=&r" (%s)' % (output_vars[reg], output_vars[reg]) for reg in output_regs]) + '\n'
- ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars]) + '\n'
+ ret += ': ' + ', '.join(['[%s] "m" (%s)' % (reg, input_vars[reg]) for reg in input_vars] +
+ ['[%s] "m" (%s)' % (reg, reg) for reg in map(str.lower, TO_BE_RESTORED_REGISTERS)]) + '\n'
ret += ': ' + ', '.join(['"cc"'] +
['"%s"' % reg for reg in special_reg] +
['"%s"' % reg for reg in transient_regs]) + '\n'