diff options
Diffstat (limited to 'src/Specific/Framework/make_curve.py')
-rwxr-xr-x | src/Specific/Framework/make_curve.py | 71 |
1 files changed, 52 insertions, 19 deletions
diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py index 56c91e577..aeaef91db 100755 --- a/src/Specific/Framework/make_curve.py +++ b/src/Specific/Framework/make_curve.py @@ -1,6 +1,6 @@ #!/usr/bin/env python from __future__ import with_statement -import json, sys, os, math, re, shutil +import json, sys, os, math, re, shutil, io def compute_bitwidth(base): return 2**int(math.ceil(math.log(base, 2))) @@ -175,12 +175,13 @@ def nested_list_to_string(v): assert(False) def make_curve_parameters(parameters): - def fix_option(term): + def fix_option(term, scope_string=''): if not isinstance(term, str) and not isinstance(term, unicode): return term if term[:len('Some ')] != 'Some ' and term != 'None': - if ' ' in term: return 'Some (%s)' % term - return 'Some %s' % term + if ' ' in term and (term[0] + term[-1]) not in ('()', '[]'): + return 'Some (%s)%s' % (term, scope_string) + return 'Some %s%s' % (term, scope_string) return term replacements = dict(parameters) assert(all(ch in '0123456789^+- ' for ch in parameters['modulus'])) @@ -208,9 +209,13 @@ def make_curve_parameters(parameters): parameters.get(op + '_code', None), nargs, sz) - for k in ('upper_bound_of_exponent', 'allowable_bit_widths', 'freeze_extra_allowable_bit_widths'): - if k not in replacements.keys(): - replacements[k] = 'None' + replacements['coef_div_modulus_raw'] = replacements.get('coef_div_modulus', '0') + for k, scope_string in (('upper_bound_of_exponent', ''), + ('allowable_bit_widths', '%nat'), + ('freeze_extra_allowable_bit_widths', '%nat'), + ('coef_div_modulus', '%nat'), + ('modinv_fuel', '%nat')): + replacements[k] = fix_option(nested_list_to_string(replacements.get(k, 'None')), scope_string=scope_string) for k in ('s', 'c', 'goldilocks'): replacements[k] = nested_list_to_string(replacements[k]) for k in ('extra_prove_mul_eq', 'extra_prove_square_eq'): @@ -232,7 +237,7 @@ Module Curve <: CurveParameters. Definition carry_chains : option (list (list nat)) := Eval vm_compute in %(carry_chains)s. Definition a24 : option Z := %(a24)s. - Definition coef_div_modulus : nat := %(coef_div_modulus)s%%nat. (* add %(coef_div_modulus)s*modulus before subtracting *) + Definition coef_div_modulus : option nat := %(coef_div_modulus)s. (* add %(coef_div_modulus_raw)s*modulus before subtracting *) Definition goldilocks : bool := %(goldilocks)s. @@ -245,6 +250,7 @@ Module Curve <: CurveParameters. Definition upper_bound_of_exponent : option (Z -> Z) := %(upper_bound_of_exponent)s. Definition allowable_bit_widths : option (list nat) := %(allowable_bit_widths)s. Definition freeze_extra_allowable_bit_widths : option (list nat) := %(freeze_extra_allowable_bit_widths)s. + Definition modinv_fuel : option nat := %(modinv_fuel)s. Ltac extra_prove_mul_eq := %(extra_prove_mul_eq)s. Ltac extra_prove_square_eq := %(extra_prove_square_eq)s. End Curve. @@ -257,12 +263,9 @@ Require Import %s.CurveParameters. Module Import T := MakeSynthesisTactics Curve. -Module P <: SynthesisPrePackage. - Definition Synthesis_package' : Synthesis_package'_Type. - Proof. make_synthesis_package (). Defined. - - Definition Synthesis_package - := Eval cbv [Synthesis_package' projT2] in projT2 Synthesis_package'. +Module P <: PrePackage. + Definition package : Tag.Context. + Proof. make_Synthesis_package (). Defined. End P. Module Export S := PackageSynthesis Curve P. @@ -282,6 +285,8 @@ Proof. Time synthesize_%(arg)s (). Show Ltac Profile. Time Defined. + +Print Assumptions %(arg)s. """ % {'prefix':prefix, 'arg':fearg[2:]} elif fearg in ('fesquare',): return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems. @@ -296,6 +301,8 @@ Proof. Time synthesize_square (). Show Ltac Profile. Time Defined. + +Print Assumptions square. """ % {'prefix':prefix} elif fearg in ('freeze',): return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems. @@ -310,11 +317,29 @@ Proof. Time synthesize_freeze (). Show Ltac Profile. Time Defined. + +Print Assumptions freeze. """ % {'prefix':prefix} + elif fearg in ('feopp',): + return r"""Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import %(prefix)s.Synthesis. + +(* TODO : change this to field once field isomorphism happens *) +Definition %(arg)s : + { %(arg)s : feBW -> feBW + | forall a, phiBW (%(arg)s a) = F.%(arg)s (phiBW a) }. +Proof. + Set Ltac Profiling. + Time synthesize_%(arg)s (). + Show Ltac Profile. +Time Defined. + +Print Assumptions %(arg)s. +""" % {'prefix':prefix, 'arg':fearg[2:]} elif fearg in ('ladderstep', 'xzladderstep'): return r"""Require Import Crypto.Arithmetic.Core. Require Import Crypto.Arithmetic.PrimeFieldTheorems. -Require Import Crypto.Specific.Framework.LadderstepSynthesisFramework. +Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Ladderstep. Require Import %(prefix)s.Synthesis. (* TODO : change this to field once field isomorphism happens *) @@ -334,6 +359,8 @@ Proof. synthesize_xzladderstep (). Show Ltac Profile. Time Defined. + +Print Assumptions xzladderstep. """ % {'prefix':prefix} else: print('ERROR: Unsupported operation: %s' % fearg) @@ -343,10 +370,12 @@ Time Defined. def make_display_arg(fearg, prefix): file_name = fearg function_name = fearg - if fearg in ('femul', 'fesub', 'feadd', 'fesquare'): + if fearg in ('femul', 'fesub', 'feadd', 'fesquare', 'feopp'): function_name = fearg[2:] elif fearg in ('freeze', 'xzladderstep'): pass + elif fearg in ('fenz',): + function_name = 'nonzero' elif fearg in ('ladderstep', ): function_name = 'xzladderstep' else: @@ -404,12 +433,16 @@ def main(*args): if open(fname, 'r').read() == outputs[k]: continue new_files.append(fname) - with open(fname, 'w') as f: - f.write(outputs[k]) + with io.open(fname, 'w', newline='\n') as f: + f.write(unicode(outputs[k])) if fname[-len('compiler.sh'):] == 'compiler.sh': mode = os.fstat(f.fileno()).st_mode mode |= 0o111 - os.fchmod(f.fileno(), mode & 0o7777) + mode &= 0o7777 + if 'fchmod' in os.__dict__.keys(): + os.fchmod(f.fileno(), mode) + else: + os.chmod(f.name, mode) if len(new_files) > 0: print('git add ' + ' '.join('"%s"' % i for i in new_files)) |