diff options
Diffstat (limited to 'src/Specific/Framework/make_curve.py')
-rwxr-xr-x | src/Specific/Framework/make_curve.py | 52 |
1 files changed, 35 insertions, 17 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py index b530a72ee..6c155efb9 100755 --- a/src/Specific/Framework/make_curve.py +++ b/src/Specific/Framework/make_curve.py @@ -6,9 +6,8 @@ def compute_bitwidth(base): return 2**int(math.ceil(math.log(base, 2))) def compute_sz(modulus, base): return 1 + int(math.ceil(math.log(modulus, 2) / base)) -def default_carry_chain(cc): - assert(cc == 'carry_chain1') - return 'Some (seq 0 (pred sz))' +def default_carry_chains(): + return ('seq 0 (pred sz)', '[0; 1]') def compute_s(modulus_str): base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', '([0-9^ +-]*)$')), modulus_str).groups() return '%s^%s' % (base, exp) @@ -132,7 +131,23 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent=' ret += '\n%s)' % closing_indent return ret +def nested_list_to_string(v): + if isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode): + return str(v) + elif isinstance(v, list): + return '[%s]' % '; '.join(map(nested_list_to_string, v)) + else: + print('ERROR: Invalid type in nested_list_to_string: %s' % str(type(v))) + assert(False) + def make_curve_parameters(parameters): + def fix_option(term): + if not isinstance(term, str) and not isinstance(term, unicode): + return term + if term[:len('Some ')] != 'Some ' and term != 'None': + if ' ' in term: return 'Some (%s)' % term + return 'Some %s' % term + return term replacements = dict(parameters) assert(all(ch in '0123456789^+- ' for ch in parameters['modulus'])) modulus = eval(parameters['modulus'].replace('^', '**')) @@ -141,17 +156,16 @@ def make_curve_parameters(parameters): bitwidth = int(replacements['bitwidth']) replacements['sz'] = parameters.get('sz', str(compute_sz(modulus, base))) sz = int(replacements['sz']) - for cc in ('carry_chain1', 'carry_chain2'): - if cc in replacements.keys() and isinstance(replacements[cc], list): - replacements[cc] = 'Some [%s]%%nat' % '; '.join(map(str, replacements[cc])) - elif replacements[cc] == 'default': - replacements[cc] = default_carry_chain(cc) - elif isinstance(replacements[cc], str): - if replacements[cc][:len('Some ')] != 'Some ' and replacements[cc][:len('None')] != 'None': - if ' ' in replacements[cc]: replacements[cc] = '(%s)' % replacements[cc] - replacements[cc] = 'Some %s' % replacements[cc] - elif cc not in replacements.keys(): - replacements[cc] = 'None' + replacements['a24'] = fix_option(parameters.get('a24', 'None')) + replacements['carry_chains'] = fix_option(parameters.get('carry_chains', 'None')) + if isinstance(replacements['carry_chains'], list): + defaults = default_carry_chains() + replacements['carry_chains'] \ + = ('Some %s%%nat' + % nested_list_to_string([(v if v != 'default' else defaults[i]) + for i, v in enumerate(replacements['carry_chains'])])) + elif replacements['carry_chains'] == 'default': + replacements['carry_chains'] = 'Some %s%%nat' % nested_list_to_string(default_carry_chains()) replacements['s'] = parameters.get('s', compute_s(parameters['modulus'])) replacements['c'] = parameters.get('c', compute_c(parameters['modulus'])) if isinstance(replacements['c'], list): @@ -164,6 +178,9 @@ def make_curve_parameters(parameters): for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'): if k not in replacements.keys(): replacements[k] = 'None' + for k in ('goldilocks', ): + if k not in replacements.keys(): + replacements[k] = 'false' for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'): if k not in replacements.keys(): replacements[k] = 'idtac' @@ -180,12 +197,13 @@ Module Curve <: CurveParameters. Definition bitwidth : Z := %(bitwidth)s. Definition s : Z := %(s)s. Definition c : list limb := %(c)s. - Definition carry_chain1 : option (list nat) := Eval vm_compute in %(carry_chain1)s. - Definition carry_chain2 : option (list nat) := Eval vm_compute in %(carry_chain2)s. + Definition carry_chains : option (list (list nat)) := Eval vm_compute in %(carry_chains)s. - Definition a24 : Z := %(a24)s. + Definition a24 : option Z := %(a24)s. Definition coef_div_modulus : nat := %(coef_div_modulus)s%%nat. (* add %(coef_div_modulus)s*modulus before subtracting *) + Definition goldilocks : bool := %(goldilocks)s. + Definition mul_code : option (Z^sz -> Z^sz -> Z^sz) := %(mul)s. |