aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/make_curve.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/Specific/Framework/make_curve.py')
-rwxr-xr-xsrc/Specific/Framework/make_curve.py52
1 files changed, 35 insertions, 17 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py
index b530a72ee..6c155efb9 100755
--- a/src/Specific/Framework/make_curve.py
+++ b/src/Specific/Framework/make_curve.py
@@ -6,9 +6,8 @@ def compute_bitwidth(base):
return 2**int(math.ceil(math.log(base, 2)))
def compute_sz(modulus, base):
return 1 + int(math.ceil(math.log(modulus, 2) / base))
-def default_carry_chain(cc):
- assert(cc == 'carry_chain1')
- return 'Some (seq 0 (pred sz))'
+def default_carry_chains():
+ return ('seq 0 (pred sz)', '[0; 1]')
def compute_s(modulus_str):
base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', '([0-9^ +-]*)$')), modulus_str).groups()
return '%s^%s' % (base, exp)
@@ -132,7 +131,23 @@ def format_c_code(header, code, numargs, sz, indent=' ', closing_indent='
ret += '\n%s)' % closing_indent
return ret
+def nested_list_to_string(v):
+ if isinstance(v, str) or isinstance(v, int) or isinstance(v, unicode):
+ return str(v)
+ elif isinstance(v, list):
+ return '[%s]' % '; '.join(map(nested_list_to_string, v))
+ else:
+ print('ERROR: Invalid type in nested_list_to_string: %s' % str(type(v)))
+ assert(False)
+
def make_curve_parameters(parameters):
+ def fix_option(term):
+ if not isinstance(term, str) and not isinstance(term, unicode):
+ return term
+ if term[:len('Some ')] != 'Some ' and term != 'None':
+ if ' ' in term: return 'Some (%s)' % term
+ return 'Some %s' % term
+ return term
replacements = dict(parameters)
assert(all(ch in '0123456789^+- ' for ch in parameters['modulus']))
modulus = eval(parameters['modulus'].replace('^', '**'))
@@ -141,17 +156,16 @@ def make_curve_parameters(parameters):
bitwidth = int(replacements['bitwidth'])
replacements['sz'] = parameters.get('sz', str(compute_sz(modulus, base)))
sz = int(replacements['sz'])
- for cc in ('carry_chain1', 'carry_chain2'):
- if cc in replacements.keys() and isinstance(replacements[cc], list):
- replacements[cc] = 'Some [%s]%%nat' % '; '.join(map(str, replacements[cc]))
- elif replacements[cc] == 'default':
- replacements[cc] = default_carry_chain(cc)
- elif isinstance(replacements[cc], str):
- if replacements[cc][:len('Some ')] != 'Some ' and replacements[cc][:len('None')] != 'None':
- if ' ' in replacements[cc]: replacements[cc] = '(%s)' % replacements[cc]
- replacements[cc] = 'Some %s' % replacements[cc]
- elif cc not in replacements.keys():
- replacements[cc] = 'None'
+ replacements['a24'] = fix_option(parameters.get('a24', 'None'))
+ replacements['carry_chains'] = fix_option(parameters.get('carry_chains', 'None'))
+ if isinstance(replacements['carry_chains'], list):
+ defaults = default_carry_chains()
+ replacements['carry_chains'] \
+ = ('Some %s%%nat'
+ % nested_list_to_string([(v if v != 'default' else defaults[i])
+ for i, v in enumerate(replacements['carry_chains'])]))
+ elif replacements['carry_chains'] == 'default':
+ replacements['carry_chains'] = 'Some %s%%nat' % nested_list_to_string(default_carry_chains())
replacements['s'] = parameters.get('s', compute_s(parameters['modulus']))
replacements['c'] = parameters.get('c', compute_c(parameters['modulus']))
if isinstance(replacements['c'], list):
@@ -164,6 +178,9 @@ def make_curve_parameters(parameters):
for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'):
if k not in replacements.keys():
replacements[k] = 'None'
+ for k in ('goldilocks', ):
+ if k not in replacements.keys():
+ replacements[k] = 'false'
for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'):
if k not in replacements.keys():
replacements[k] = 'idtac'
@@ -180,12 +197,13 @@ Module Curve <: CurveParameters.
Definition bitwidth : Z := %(bitwidth)s.
Definition s : Z := %(s)s.
Definition c : list limb := %(c)s.
- Definition carry_chain1 : option (list nat) := Eval vm_compute in %(carry_chain1)s.
- Definition carry_chain2 : option (list nat) := Eval vm_compute in %(carry_chain2)s.
+ Definition carry_chains : option (list (list nat)) := Eval vm_compute in %(carry_chains)s.
- Definition a24 : Z := %(a24)s.
+ Definition a24 : option Z := %(a24)s.
Definition coef_div_modulus : nat := %(coef_div_modulus)s%%nat. (* add %(coef_div_modulus)s*modulus before subtracting *)
+ Definition goldilocks : bool := %(goldilocks)s.
+
Definition mul_code : option (Z^sz -> Z^sz -> Z^sz)
:= %(mul)s.