diff options
author | Jason Gross <jgross@mit.edu> | 2017-10-11 15:30:11 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-10-18 23:01:29 -0400 |
commit | 0234d56d299e34651b287a42b777d1f644fc56d5 (patch) | |
tree | 864c212fffc8d42095bfe81a92e6870b02122da4 | |
parent | 95ac9ae063e7af1b9e9023be2b601a44b72b9e01 (diff) |
Add default computation for goldilocks
As per
https://github.com/mit-plv/fiat-crypto/pull/248#discussion_r144016387,
we turn on goldilocks by default if the prime is of the form 2^2k - 2^k
- 1.
-rwxr-xr-x | src/Specific/Framework/make_curve.py | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py index ab81358c1..26a95487d 100755 --- a/src/Specific/Framework/make_curve.py +++ b/src/Specific/Framework/make_curve.py @@ -13,23 +13,39 @@ def compute_s(modulus_str): return '%s^%s' % (base, exp) def compute_c(modulus_str): base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', r'([0-9\^ +\*-]*)$')), modulus_str).groups() - if rest.strip() == '': return '[]' + if rest.strip() == '': return [] assert(rest.strip()[0] == '-') rest = negate_numexpr(rest.strip()[1:]) ret = [] for part in re.findall(r'(-?[0-9\^\*]+)', rest.replace(' ', '')): if part.isdigit(): - ret.append('(1, %s)' % part) + ret.append(('1', part)) elif part[:2] == '2^' and part[2:].isdigit(): - ret.append('(%s, 1)' % part) + ret.append((part, '1')) else: raw_input('Unhandled part: %s' % part) ret = None break if ret is not None: - return '[%s]' % '; '.join(reversed(ret)) + return list(reversed(ret)) # XXX FIXME: Is this the right way to extract c? - return '[(1, %s)]' % rest + return [('1', rest)] +def compute_goldilocks(s, c): + ms = re.match(r'^2\^([0-9]+)$', s) + if ms is None: return False + two_k = int(ms.groups()[0]) + assert(isinstance(c, list)) + if len(c) != 2: return False + one_vs = [str(v) for k, v in c if str(k) == '1'] + others = [(str(k), str(v)) for k, v in c if str(k) != '1'] + if len(one_vs) != 1 or len(others) != 1 or one_vs[0] != '1' or others[0][1] != '1': return False + mk = re.match(r'^2\^([0-9]+)$', others[0][0]) + if mk is None: return False + k = int(mk.groups()[0]) + if two_k != 2 * k: return False + return True + + def negate_numexpr(expr): remap = dict([(d, d) for d in '0123456789^ '] + [('-', '+'), ('+', '-')]) @@ -144,10 +160,14 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent=' return ret def nested_list_to_string(v): - if isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode): + if isinstance(v, bool): + return {True:'true', False:'false'}[v] + elif isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode): return str(v) elif isinstance(v, list): return '[%s]' % '; '.join(map(nested_list_to_string, v)) + elif isinstance(v, tuple): + return '(%s)' % ', '.join(map(nested_list_to_string, v)) else: print('ERROR: Invalid type in nested_list_to_string: %s' % str(type(v))) assert(False) @@ -180,8 +200,7 @@ def make_curve_parameters(parameters): replacements['carry_chains'] = 'Some %s%%nat' % nested_list_to_string(default_carry_chains()) replacements['s'] = parameters.get('s', compute_s(parameters['modulus'])) replacements['c'] = parameters.get('c', compute_c(parameters['modulus'])) - if isinstance(replacements['c'], list): - replacements['c'] = '[%s]' % '; '.join('(%s, %s)' % (str(w), str(v)) for w, v in replacements['c']) + replacements['goldilocks'] = parameters.get('goldilocks', compute_goldilocks(replacements['s'], replacements['c'])) for op, nargs in (('mul', 2), ('square', 1)): replacements[op] = format_c_code(parameters.get(op + '_header', None), parameters.get(op + '_code', None), @@ -190,9 +209,8 @@ def make_curve_parameters(parameters): for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'): if k not in replacements.keys(): replacements[k] = 'None' - for k in ('goldilocks', ): - if k not in replacements.keys(): - replacements[k] = 'false' + for k in ('s', 'c', 'goldilocks'): + replacements[k] = nested_list_to_string(replacements[k]) for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'): if k not in replacements.keys(): replacements[k] = 'idtac' |