aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-11 15:30:11 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2017-10-18 23:01:29 -0400
commit0234d56d299e34651b287a42b777d1f644fc56d5 (patch)
tree864c212fffc8d42095bfe81a92e6870b02122da4
parent95ac9ae063e7af1b9e9023be2b601a44b72b9e01 (diff)
Add default computation for goldilocks
As per https://github.com/mit-plv/fiat-crypto/pull/248#discussion_r144016387, we turn on goldilocks by default if the prime is of the form 2^2k - 2^k - 1.
-rwxr-xr-xsrc/Specific/Framework/make_curve.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py
index ab81358c1..26a95487d 100755
--- a/src/Specific/Framework/make_curve.py
+++ b/src/Specific/Framework/make_curve.py
@@ -13,23 +13,39 @@ def compute_s(modulus_str):
return '%s^%s' % (base, exp)
def compute_c(modulus_str):
base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', r'([0-9\^ +\*-]*)$')), modulus_str).groups()
- if rest.strip() == '': return '[]'
+ if rest.strip() == '': return []
assert(rest.strip()[0] == '-')
rest = negate_numexpr(rest.strip()[1:])
ret = []
for part in re.findall(r'(-?[0-9\^\*]+)', rest.replace(' ', '')):
if part.isdigit():
- ret.append('(1, %s)' % part)
+ ret.append(('1', part))
elif part[:2] == '2^' and part[2:].isdigit():
- ret.append('(%s, 1)' % part)
+ ret.append((part, '1'))
else:
raw_input('Unhandled part: %s' % part)
ret = None
break
if ret is not None:
- return '[%s]' % '; '.join(reversed(ret))
+ return list(reversed(ret))
# XXX FIXME: Is this the right way to extract c?
- return '[(1, %s)]' % rest
+ return [('1', rest)]
+def compute_goldilocks(s, c):
+ ms = re.match(r'^2\^([0-9]+)$', s)
+ if ms is None: return False
+ two_k = int(ms.groups()[0])
+ assert(isinstance(c, list))
+ if len(c) != 2: return False
+ one_vs = [str(v) for k, v in c if str(k) == '1']
+ others = [(str(k), str(v)) for k, v in c if str(k) != '1']
+ if len(one_vs) != 1 or len(others) != 1 or one_vs[0] != '1' or others[0][1] != '1': return False
+ mk = re.match(r'^2\^([0-9]+)$', others[0][0])
+ if mk is None: return False
+ k = int(mk.groups()[0])
+ if two_k != 2 * k: return False
+ return True
+
+
def negate_numexpr(expr):
remap = dict([(d, d) for d in '0123456789^ '] + [('-', '+'), ('+', '-')])
@@ -144,10 +160,14 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent='
return ret
def nested_list_to_string(v):
- if isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode):
+ if isinstance(v, bool):
+ return {True:'true', False:'false'}[v]
+ elif isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode):
return str(v)
elif isinstance(v, list):
return '[%s]' % '; '.join(map(nested_list_to_string, v))
+ elif isinstance(v, tuple):
+ return '(%s)' % ', '.join(map(nested_list_to_string, v))
else:
print('ERROR: Invalid type in nested_list_to_string: %s' % str(type(v)))
assert(False)
@@ -180,8 +200,7 @@ def make_curve_parameters(parameters):
replacements['carry_chains'] = 'Some %s%%nat' % nested_list_to_string(default_carry_chains())
replacements['s'] = parameters.get('s', compute_s(parameters['modulus']))
replacements['c'] = parameters.get('c', compute_c(parameters['modulus']))
- if isinstance(replacements['c'], list):
- replacements['c'] = '[%s]' % '; '.join('(%s, %s)' % (str(w), str(v)) for w, v in replacements['c'])
+ replacements['goldilocks'] = parameters.get('goldilocks', compute_goldilocks(replacements['s'], replacements['c']))
for op, nargs in (('mul', 2), ('square', 1)):
replacements[op] = format_c_code(parameters.get(op + '_header', None),
parameters.get(op + '_code', None),
@@ -190,9 +209,8 @@ def make_curve_parameters(parameters):
for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'):
if k not in replacements.keys():
replacements[k] = 'None'
- for k in ('goldilocks', ):
- if k not in replacements.keys():
- replacements[k] = 'false'
+ for k in ('s', 'c', 'goldilocks'):
+ replacements[k] = nested_list_to_string(replacements[k])
for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'):
if k not in replacements.keys():
replacements[k] = 'idtac'