aboutsummaryrefslogtreecommitdiff
path: root/src/Specific/Framework/make_curve.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/Specific/Framework/make_curve.py')
-rwxr-xr-xsrc/Specific/Framework/make_curve.py71
1 files changed, 52 insertions, 19 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py
index 56c91e577..aeaef91db 100755
--- a/src/Specific/Framework/make_curve.py
+++ b/src/Specific/Framework/make_curve.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
from __future__ import with_statement
-import json, sys, os, math, re, shutil
+import json, sys, os, math, re, shutil, io
def compute_bitwidth(base):
return 2**int(math.ceil(math.log(base, 2)))
@@ -175,12 +175,13 @@ def nested_list_to_string(v):
assert(False)
def make_curve_parameters(parameters):
- def fix_option(term):
+ def fix_option(term, scope_string=''):
if not isinstance(term, str) and not isinstance(term, unicode):
return term
if term[:len('Some ')] != 'Some ' and term != 'None':
- if ' ' in term: return 'Some (%s)' % term
- return 'Some %s' % term
+ if ' ' in term and (term[0] + term[-1]) not in ('()', '[]'):
+ return 'Some (%s)%s' % (term, scope_string)
+ return 'Some %s%s' % (term, scope_string)
return term
replacements = dict(parameters)
assert(all(ch in '0123456789^+- ' for ch in parameters['modulus']))
@@ -208,9 +209,13 @@ def make_curve_parameters(parameters):
parameters.get(op + '_code', None),
nargs,
sz)
- for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'):
- if k not in replacements.keys():
- replacements[k] = 'None'
+ replacements['coef_div_modulus_raw'] = replacements.get('coef_div_modulus', '0')
+ for k, scope_string in (('upper_bound_of_exponent', ''),
+ ('allowable_bit_widths', '%nat'),
+ ('freeze_extra_allowable_bit_widths', '%nat'),
+ ('coef_div_modulus', '%nat'),
+ ('modinv_fuel', '%nat')):
+ replacements[k] = fix_option(nested_list_to_string(replacements.get(k, 'None')), scope_string=scope_string)
for k in ('s', 'c', 'goldilocks'):
replacements[k] = nested_list_to_string(replacements[k])
for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'):
@@ -232,7 +237,7 @@ Module Curve <: CurveParameters.
Definition carry_chains : option (list (list nat)) := Eval vm_compute in %(carry_chains)s.
Definition a24 : option Z := %(a24)s.
- Definition coef_div_modulus : nat := %(coef_div_modulus)s%%nat. (* add %(coef_div_modulus)s*modulus before subtracting *)
+ Definition coef_div_modulus : option nat := %(coef_div_modulus)s. (* add %(coef_div_modulus_raw)s*modulus before subtracting *)
Definition goldilocks : bool := %(goldilocks)s.
@@ -245,6 +250,7 @@ Module Curve <: CurveParameters.
Definition upper_bound_of_exponent : option (Z -> Z) := %(upper_bound_of_exponent)s.
Definition allowable_bit_widths : option (list nat) := %(allowable_bit_widths)s.
Definition freeze_extra_allowable_bit_widths : option (list nat) := %(freeze_extra_allowable_bit_widths)s.
+ Definition modinv_fuel : option nat := %(modinv_fuel)s.
Ltac extra_prove_mul_eq := %(extra_prove_mul_eq)s.
Ltac extra_prove_square_eq := %(extra_prove_square_eq)s.
End Curve.
@@ -257,12 +263,9 @@ Require Import %s.CurveParameters.
Module Import T := MakeSynthesisTactics Curve.
-Module P <: SynthesisPrePackage.
- Definition Synthesis_package' : Synthesis_package'_Type.
- Proof. make_synthesis_package (). Defined.
-
- Definition Synthesis_package
- := Eval cbv [Synthesis_package' projT2] in projT2 Synthesis_package'.
+Module P <: PrePackage.
+ Definition package : Tag.Context.
+ Proof. make_Synthesis_package (). Defined.
End P.
Module Export S := PackageSynthesis Curve P.
@@ -282,6 +285,8 @@ Proof.
Time synthesize_%(arg)s ().
Show Ltac Profile.
Time Defined.
+
+Print Assumptions %(arg)s.
""" % {'prefix':prefix, 'arg':fearg[2:]}
elif fearg in ('fesquare',):
return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems.
@@ -296,6 +301,8 @@ Proof.
Time synthesize_square ().
Show Ltac Profile.
Time Defined.
+
+Print Assumptions square.
""" % {'prefix':prefix}
elif fearg in ('freeze',):
return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems.
@@ -310,11 +317,29 @@ Proof.
Time synthesize_freeze ().
Show Ltac Profile.
Time Defined.
+
+Print Assumptions freeze.
""" % {'prefix':prefix}
+ elif fearg in ('feopp',):
+ return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems.
+Require Import %(prefix)s.Synthesis.
+
+(* TODO : change this to field once field isomorphism happens *)
+Definition %(arg)s :
+ { %(arg)s : feBW -> feBW
+ | forall a, phiBW (%(arg)s a) = F.%(arg)s (phiBW a) }.
+Proof.
+ Set Ltac Profiling.
+ Time synthesize_%(arg)s ().
+ Show Ltac Profile.
+Time Defined.
+
+Print Assumptions %(arg)s.
+""" % {'prefix':prefix, 'arg':fearg[2:]}
elif fearg in ('ladderstep', 'xzladderstep'):
return r"""Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Arithmetic.PrimeFieldTheorems.
-Require Import Crypto.Specific.Framework.LadderstepSynthesisFramework.
+Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Ladderstep.
Require Import %(prefix)s.Synthesis.
(* TODO : change this to field once field isomorphism happens *)
@@ -334,6 +359,8 @@ Proof.
synthesize_xzladderstep ().
Show Ltac Profile.
Time Defined.
+
+Print Assumptions xzladderstep.
""" % {'prefix':prefix}
else:
print('ERROR: Unsupported operation: %s' % fearg)
@@ -343,10 +370,12 @@ Time Defined.
def make_display_arg(fearg, prefix):
file_name = fearg
function_name = fearg
- if fearg in ('femul', 'fesub', 'feadd', 'fesquare'):
+ if fearg in ('femul', 'fesub', 'feadd', 'fesquare', 'feopp'):
function_name = fearg[2:]
elif fearg in ('freeze', 'xzladderstep'):
pass
+ elif fearg in ('fenz',):
+ function_name = 'nonzero'
elif fearg in ('ladderstep', ):
function_name = 'xzladderstep'
else:
@@ -404,12 +433,16 @@ def main(*args):
if open(fname, 'r').read() == outputs[k]:
continue
new_files.append(fname)
- with open(fname, 'w') as f:
- f.write(outputs[k])
+ with io.open(fname, 'w', newline='\n') as f:
+ f.write(unicode(outputs[k]))
if fname[-len('compiler.sh'):] == 'compiler.sh':
mode = os.fstat(f.fileno()).st_mode
mode |= 0o111
- os.fchmod(f.fileno(), mode & 0o7777)
+ mode &= 0o7777
+ if 'fchmod' in os.__dict__.keys():
+ os.fchmod(f.fileno(), mode)
+ else:
+ os.chmod(f.name, mode)
if len(new_files) > 0:
print('git add ' + ' '.join('"%s"' % i for i in new_files))