From a17f3ea44a09638f0d78428290a96ecce613ad65 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sun, 15 Oct 2017 22:12:34 -0400 Subject: Explicitly specify base This allows it to be something other than log2(m)/sz. After | File Name | Before || Change ------------------------------------------------------------------------------------------- 8m20.82s | Total | 8m37.82s || -0m17.00s ------------------------------------------------------------------------------------------- 1m59.42s | Specific/NISTP256/AMD64/femul | 2m19.09s || -0m19.67s 3m28.66s | Specific/X25519/C64/ladderstep | 3m28.02s || +0m00.63s 0m24.97s | Specific/X25519/C64/femul | 0m24.60s || +0m00.36s 0m24.08s | Specific/NISTP256/AMD64/fesub | 0m23.48s || +0m00.59s 0m22.00s | Specific/NISTP256/AMD64/feadd | 0m21.34s || +0m00.66s 0m20.34s | Specific/X25519/C64/freeze | 0m19.76s || +0m00.57s 0m19.85s | Specific/X25519/C64/fesquare | 0m19.93s || -0m00.07s 0m18.04s | Specific/NISTP256/AMD64/feopp | 0m17.69s || +0m00.34s 0m15.10s | Specific/NISTP256/AMD64/fenz | 0m15.37s || -0m00.26s 0m08.31s | Specific/NISTP256/AMD64/Synthesis | 0m08.24s || +0m00.07s 0m05.96s | Specific/X25519/C64/Synthesis | 0m06.25s || -0m00.29s 0m02.10s | Specific/Framework/ArithmeticSynthesis/Defaults | 0m02.14s || -0m00.04s 0m01.00s | Specific/Framework/SynthesisFramework | 0m01.03s || -0m00.03s 0m00.97s | Specific/Framework/ArithmeticSynthesis/Base | 0m01.02s || -0m00.05s 0m00.89s | Specific/Framework/ArithmeticSynthesis/Freeze | 0m00.84s || +0m00.05s 0m00.80s | Specific/Framework/ArithmeticSynthesis/Karatsuba | 0m00.81s || -0m00.01s 0m00.79s | Specific/Framework/ArithmeticSynthesis/MontgomeryPackage | 0m00.80s || -0m00.01s 0m00.76s | Specific/Framework/MontgomeryReificationTypesPackage | 0m00.75s || +0m00.01s 0m00.74s | Specific/Framework/ReificationTypesPackage | 0m00.77s || -0m00.03s 0m00.74s | Specific/Framework/ArithmeticSynthesis/BasePackage | 0m00.74s || +0m00.00s 0m00.73s | Specific/Framework/ArithmeticSynthesis/SquareFromMul | 0m00.70s || +0m00.03s 0m00.72s | Specific/Framework/ArithmeticSynthesis/DefaultsPackage | 0m00.68s || +0m00.03s 0m00.70s | Specific/Framework/ArithmeticSynthesis/LadderstepPackage | 0m00.72s || -0m00.02s 0m00.70s | Specific/Framework/ArithmeticSynthesis/FreezePackage | 0m00.77s || -0m00.07s 0m00.69s | Specific/Framework/ArithmeticSynthesis/KaratsubaPackage | 0m00.70s || -0m00.01s 0m00.42s | Specific/X25519/C64/CurveParameters | 0m00.38s || +0m00.03s 0m00.36s | Specific/Framework/CurveParameters | 0m00.32s || +0m00.03s 0m00.33s | Specific/Framework/RawCurveParameters | 0m00.29s || +0m00.04s 0m00.33s | Specific/Framework/CurveParametersPackage | 0m00.30s || +0m00.03s 0m00.32s | Specific/NISTP256/AMD64/CurveParameters | 0m00.30s || +0m00.02s --- src/Specific/Framework/ArithmeticSynthesis/Base.v | 82 ++++++++++++---------- .../Framework/ArithmeticSynthesis/BasePackage.v | 35 +++++---- .../Framework/ArithmeticSynthesis/Defaults.v | 61 ++++++++-------- .../ArithmeticSynthesis/DefaultsPackage.v | 63 ++++++++++------- .../Framework/ArithmeticSynthesis/Freeze.v | 25 +++---- .../Framework/ArithmeticSynthesis/FreezePackage.v | 5 +- src/Specific/Framework/CurveParameters.v | 18 ++++- src/Specific/Framework/CurveParametersPackage.v | 2 + src/Specific/Framework/RawCurveParameters.v | 5 ++ src/Specific/Framework/make_curve.py | 7 ++ 10 files changed, 181 insertions(+), 122 deletions(-) (limited to 'src/Specific/Framework') diff --git a/src/Specific/Framework/ArithmeticSynthesis/Base.v b/src/Specific/Framework/ArithmeticSynthesis/Base.v index 2bf9719f0..541069d94 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/Base.v +++ b/src/Specific/Framework/ArithmeticSynthesis/Base.v @@ -1,5 +1,7 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef. Require Import Coq.Lists.List. Import ListNotations. +Require Import Coq.micromega.Lia. +Require Import Coq.QArith.QArith_base. Require Import Crypto.Arithmetic.Core. Import B. Require Import Crypto.Specific.Framework.CurveParameters. Require Import Crypto.Specific.Framework.ArithmeticSynthesis.HelperTactics. @@ -30,15 +32,17 @@ Section wt. Local Coercion QArith_base.inject_Z : Z >-> Q. Local Coercion Z.of_nat : nat >-> Z. Local Coercion Z.pos : positive >-> Z. - Definition wt_gen (m : positive) (sz : nat) (i:nat) : Z := 2^Qceiling((Z.log2_up m/sz)*i). + Definition wt_gen (base : Q) (i:nat) : Z := 2^Qceiling(base*i). End wt. Section gen. - Context (m : positive) + Context (base : Q) + (m : positive) (sz : nat) - (coef_div_modulus : nat). + (coef_div_modulus : nat) + (base_pos : (1 <= base)%Q). - Local Notation wt := (wt_gen m sz). + Local Notation wt := (wt_gen base). Definition sz2' := ((sz * 2) - 1)%nat. @@ -50,69 +54,62 @@ Section gen. Lemma sz2'_nonzero (sz_nonzero : sz <> 0%nat) : sz2' <> 0%nat. - Proof. clear -sz_nonzero; cbv [sz2']; omega. Qed. + Proof using Type. clear -sz_nonzero; cbv [sz2']; omega. Qed. Local Ltac Q_cbv := - cbv [wt_gen Qround.Qceiling QArith_base.Qmult QArith_base.Qdiv QArith_base.inject_Z QArith_base.Qden QArith_base.Qnum QArith_base.Qopp Qround.Qfloor QArith_base.Qinv QArith_base.Qle Z.of_nat]. + cbv [wt_gen Qround.Qceiling QArith_base.Qmult QArith_base.Qdiv QArith_base.inject_Z QArith_base.Qden QArith_base.Qnum QArith_base.Qopp Qround.Qfloor QArith_base.Qinv QArith_base.Qle QArith_base.Qeq Z.of_nat] in *. Lemma wt_gen0_1 : wt 0 = 1. - Proof. + Proof using Type. Q_cbv; simpl. autorewrite with zsimplify_const; reflexivity. Qed. Lemma wt_gen_nonzero : forall i, wt i <> 0. - Proof. + Proof using base_pos. eapply pow_ceil_mul_nat_nonzero; [ omega | ]. - destruct sz; Q_cbv; - autorewrite with zsimplify_const; [ omega | ]. - apply Z.log2_up_nonneg. + destruct base; Q_cbv; lia. Qed. Lemma wt_gen_nonneg : forall i, 0 <= wt i. - Proof. apply pow_ceil_mul_nat_nonneg; omega. Qed. + Proof using Type. apply pow_ceil_mul_nat_nonneg; omega. Qed. Lemma wt_gen_pos : forall i, wt i > 0. - Proof. + Proof using base_pos. intro i; pose proof (wt_gen_nonzero i); pose proof (wt_gen_nonneg i). omega. Qed. Lemma wt_gen_multiples : forall i, wt (S i) mod (wt i) = 0. - Proof. - apply pow_ceil_mul_nat_multiples. - destruct sz; Q_cbv; autorewrite with zsimplify_const; - auto using Z.log2_up_nonneg, Z.le_refl. + Proof using base_pos. + apply pow_ceil_mul_nat_multiples; destruct base; Q_cbv; lia. Qed. Section divides. - Context (sz_nonzero : sz <> 0%nat) - (sz_small : Z.of_nat sz <= Z.log2_up (Z.pos m)). - Lemma wt_gen_divides : forall i, wt (S i) / wt i > 0. - Proof. + Proof using base_pos. apply pow_ceil_mul_nat_divide; [ omega | ]. - destruct sz; Q_cbv; autorewrite with zsimplify_const; [ congruence | ]. - rewrite Pos.mul_1_l; assumption. + destruct base; Q_cbv; lia. Qed. + Lemma wt_gen_divides' : forall i, wt (S i) / wt i <> 0. - Proof. + Proof using base_pos. symmetry; apply Z.lt_neq, Z.gt_lt_iff, wt_gen_divides; assumption. Qed. Lemma wt_gen_div_bound : forall i, wt (S i) / wt i <= wt 1. - Proof. + Proof using base_pos. intro; etransitivity. eapply pow_ceil_mul_nat_divide_upperbound; [ omega | ]. - all:destruct sz; Q_cbv; autorewrite with zsimplify_const; + all:destruct base; Q_cbv; autorewrite with zsimplify_const; rewrite ?Pos.mul_1_l, ?Pos.mul_1_r; try assumption; omega. Qed. Lemma wt_gen_divides_chain carry_chain : forall i (H:In i carry_chain), wt (S i) / wt i <> 0. - Proof. intros i ?; apply wt_gen_divides'; assumption. Qed. + Proof using base_pos. intros i ?; apply wt_gen_divides'; assumption. Qed. Lemma wt_gen_divides_chains carry_chains @@ -123,7 +120,7 @@ Section gen. (fun carry_chain => forall i (H:In i carry_chain), wt (S i) / wt i <> 0) carry_chains). - Proof. + Proof using base_pos. induction carry_chains as [|carry_chain carry_chains IHcarry_chains]; constructor; eauto using wt_gen_divides_chain. Qed. @@ -137,9 +134,8 @@ Section gen. end) (Positional.zeros sz) coef_div_modulus. Lemma coef_mod' - (sz_le_log2_m : Z.of_nat sz <= Z.log2_up (Z.pos m)) : mod_eq m (Positional.eval (n:=sz) wt coef') 0. - Proof. + Proof using base_pos. cbv [coef' m_enc']. remember (Positional.zeros sz) as v eqn:Hv. assert (Hv' : mod_eq m (Positional.eval wt v) 0) @@ -163,8 +159,8 @@ Section gen. Qed. End gen. -Ltac pose_wt m sz wt := - let v := (eval cbv [wt_gen] in (wt_gen m sz)) in +Ltac pose_wt base wt := + let v := (eval cbv [wt_gen] in (wt_gen base)) in cache_term v wt. Ltac pose_sz2 sz sz2 := @@ -193,24 +189,31 @@ Ltac pose_sz_le_log2_m sz m sz_le_log2_m := ltac:(vm_decide_no_check) sz_le_log2_m. +Ltac pose_base_pos base base_pos := + cache_proof_with_type_by + ((1 <= base)%Q) + ltac:(vm_decide_no_check) + base_pos. + Ltac pose_m_correct m s c m_correct := cache_proof_with_type_by (Z.pos m = s - Associational.eval c) ltac:(vm_decide_no_check) m_correct. -Ltac pose_m_enc sz m m_enc := - let v := (eval vm_compute in (m_enc' m sz)) in +Ltac pose_m_enc base m sz m_enc := + let v := (eval vm_compute in (m_enc' base m sz)) in let v := (eval compute in v) in (* compute away the type arguments *) cache_term v m_enc. -Ltac pose_coef sz m coef_div_modulus coef := (* subtraction coefficient *) - let v := (eval vm_compute in (coef' m sz coef_div_modulus)) in + +Ltac pose_coef base m sz coef_div_modulus coef := (* subtraction coefficient *) + let v := (eval vm_compute in (coef' base m sz coef_div_modulus)) in cache_term v coef. -Ltac pose_coef_mod sz wt m coef coef_div_modulus sz_le_log2_m coef_mod := +Ltac pose_coef_mod wt coef base m sz coef_div_modulus base_pos coef_mod := cache_proof_with_type_by (mod_eq m (Positional.eval (n:=sz) wt coef) 0) - ltac:(vm_cast_no_check (coef_mod' m sz coef_div_modulus sz_le_log2_m)) + ltac:(vm_cast_no_check (coef_mod' base m sz coef_div_modulus base_pos)) coef_mod. Ltac pose_sz_nonzero sz sz_nonzero := cache_proof_with_type_by @@ -238,6 +241,7 @@ Ltac pose_wt_divides' wt wt_divides wt_divides' := (forall i, wt (S i) / wt i <> 0) ltac:(apply wt_gen_divides'; vm_decide_no_check) wt_divides'. + Ltac pose_wt_divides_chains wt carry_chains wt_divides_chains := let T := (eval cbv [carry_chains List.fold_right List.map] in (List.fold_right @@ -249,7 +253,7 @@ Ltac pose_wt_divides_chains wt carry_chains wt_divides_chains := carry_chains))) in cache_proof_with_type_by T - ltac:(refine (@wt_gen_divides_chains _ _ _ _ carry_chains); vm_decide_no_check) + ltac:(refine (@wt_gen_divides_chains _ _ carry_chains); vm_decide_no_check) wt_divides_chains. Ltac pose_wt_pos wt wt_pos := diff --git a/src/Specific/Framework/ArithmeticSynthesis/BasePackage.v b/src/Specific/Framework/ArithmeticSynthesis/BasePackage.v index d5b4e567c..80e032fb8 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/BasePackage.v +++ b/src/Specific/Framework/ArithmeticSynthesis/BasePackage.v @@ -5,7 +5,7 @@ Require Import Crypto.Specific.Framework.Packages. Require Import Crypto.Util.TagList. Module TAG. - Inductive tags := r | m | wt | sz2 | half_sz | half_sz_nonzero | s_nonzero | sz_le_log2_m | m_correct | m_enc | coef | coef_mod | sz_nonzero | wt_nonzero | wt_nonneg | wt_divides | wt_divides' | wt_divides_chains | wt_pos | wt_multiples | c_small | m_enc_bounded | m_correct_wt. + Inductive tags := r | m | wt | sz2 | half_sz | half_sz_nonzero | s_nonzero | sz_le_log2_m | base_pos | m_correct | m_enc | coef | coef_mod | sz_nonzero | wt_nonzero | wt_nonneg | wt_divides | wt_divides' | wt_divides_chains | wt_pos | wt_multiples | c_small | m_enc_bounded | m_correct_wt. End TAG. Ltac add_r pkg := @@ -22,10 +22,9 @@ Ltac add_m pkg := Tag.update pkg TAG.m m. Ltac add_wt pkg := - let m := Tag.get pkg TAG.m in - let sz := Tag.get pkg TAG.sz in + let base := Tag.get pkg TAG.base in let wt := fresh "wt" in - let wt := pose_wt m sz wt in + let wt := pose_wt base wt in Tag.update pkg TAG.wt wt. Ltac add_sz2 pkg := @@ -59,6 +58,12 @@ Ltac add_sz_le_log2_m pkg := let sz_le_log2_m := pose_sz_le_log2_m sz m sz_le_log2_m in Tag.update pkg TAG.sz_le_log2_m sz_le_log2_m. +Ltac add_base_pos pkg := + let base := Tag.get pkg TAG.base in + let base_pos := fresh "base_pos" in + let base_pos := pose_base_pos base base_pos in + Tag.update pkg TAG.base_pos base_pos. + Ltac add_m_correct pkg := let m := Tag.get pkg TAG.m in let s := Tag.get pkg TAG.s in @@ -68,29 +73,32 @@ Ltac add_m_correct pkg := Tag.update pkg TAG.m_correct m_correct. Ltac add_m_enc pkg := - let sz := Tag.get pkg TAG.sz in + let base := Tag.get pkg TAG.base in let m := Tag.get pkg TAG.m in + let sz := Tag.get pkg TAG.sz in let m_enc := fresh "m_enc" in - let m_enc := pose_m_enc sz m m_enc in + let m_enc := pose_m_enc base m sz m_enc in Tag.update pkg TAG.m_enc m_enc. Ltac add_coef pkg := - let sz := Tag.get pkg TAG.sz in + let base := Tag.get pkg TAG.base in let m := Tag.get pkg TAG.m in + let sz := Tag.get pkg TAG.sz in let coef_div_modulus := Tag.get pkg TAG.coef_div_modulus in let coef := fresh "coef" in - let coef := pose_coef sz m coef_div_modulus coef in + let coef := pose_coef base m sz coef_div_modulus coef in Tag.update pkg TAG.coef coef. Ltac add_coef_mod pkg := - let sz := Tag.get pkg TAG.sz in let wt := Tag.get pkg TAG.wt in - let m := Tag.get pkg TAG.m in let coef := Tag.get pkg TAG.coef in + let base := Tag.get pkg TAG.base in + let m := Tag.get pkg TAG.m in + let sz := Tag.get pkg TAG.sz in let coef_div_modulus := Tag.get pkg TAG.coef_div_modulus in - let sz_le_log2_m := Tag.get pkg TAG.sz_le_log2_m in + let base_pos := Tag.get pkg TAG.base_pos in let coef_mod := fresh "coef_mod" in - let coef_mod := pose_coef_mod sz wt m coef coef_div_modulus sz_le_log2_m coef_mod in + let coef_mod := pose_coef_mod wt coef base m sz coef_div_modulus base_pos coef_mod in Tag.update pkg TAG.coef_mod coef_mod. Ltac add_sz_nonzero pkg := @@ -177,6 +185,7 @@ Ltac add_Base_package pkg := let pkg := add_half_sz_nonzero pkg in let pkg := add_s_nonzero pkg in let pkg := add_sz_le_log2_m pkg in + let pkg := add_base_pos pkg in let pkg := add_m_correct pkg in let pkg := add_m_enc pkg in let pkg := add_coef pkg in @@ -214,6 +223,8 @@ Module MakeBasePackage (PKG : PrePackage). Notation s_nonzero := (ltac:(let v := get_s_nonzero () in exact v)) (only parsing). Ltac get_sz_le_log2_m _ := get TAG.sz_le_log2_m. Notation sz_le_log2_m := (ltac:(let v := get_sz_le_log2_m () in exact v)) (only parsing). + Ltac get_base_pos _ := get TAG.base_pos. + Notation base_pos := (ltac:(let v := get_base_pos () in exact v)) (only parsing). Ltac get_m_correct _ := get TAG.m_correct. Notation m_correct := (ltac:(let v := get_m_correct () in exact v)) (only parsing). Ltac get_m_enc _ := get TAG.m_enc. diff --git a/src/Specific/Framework/ArithmeticSynthesis/Defaults.v b/src/Specific/Framework/ArithmeticSynthesis/Defaults.v index 42c49f46d..36e02bb0e 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/Defaults.v +++ b/src/Specific/Framework/ArithmeticSynthesis/Defaults.v @@ -1,4 +1,5 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef. +Require Import Coq.QArith.QArith_base. Require Import Coq.Lists.List. Import ListNotations. Require Import Crypto.Arithmetic.CoreUnfolder. Require Import Crypto.Arithmetic.Core. Import B. @@ -38,8 +39,8 @@ Local Ltac solve_constant_local_sig := | [ |- { c : Z^?sz | Positional.Fdecode (m:=?M) ?wt c = ?v } ] => (exists (Positional.encode (n:=sz) (modulo:=modulo) (div:=div) wt (F.to_Z (m:=M) v))); lazymatch goal with - | [ sz_nonzero : sz <> 0%nat, sz_le_log2_m : Z.of_nat sz <= Z.log2_up (Z.pos M) |- _ ] - => clear -sz_nonzero sz_le_log2_m + | [ sz_nonzero : sz <> 0%nat, base_pos : (1 <= _)%Q |- _ ] + => clear -base_pos sz_nonzero end end; abstract ( @@ -50,6 +51,7 @@ Local Ltac solve_constant_local_sig := Section gen. Context (m : positive) + (base : Q) (sz : nat) (s : Z) (c : list limb) @@ -59,12 +61,13 @@ Section gen. (square_code : option (Z^sz -> Z^sz)) (sz_nonzero : sz <> 0%nat) (s_nonzero : s <> 0) + (base_pos : (1 <= base)%Q) (sz_le_log2_m : Z.of_nat sz <= Z.log2_up (Z.pos m)). - Local Notation wt := (wt_gen m sz). + Local Notation wt := (wt_gen base). Local Notation sz2 := (sz2' sz). - Local Notation wt_divides' := (wt_gen_divides' m sz sz_nonzero sz_le_log2_m). - Local Notation wt_nonzero := (wt_gen_nonzero m sz). + Local Notation wt_divides' := (wt_gen_divides' base base_pos). + Local Notation wt_nonzero := (wt_gen_nonzero base base_pos). (* side condition needs cbv [Positional.mul_cps Positional.reduce_cps]. *) Context (mul_code_correct @@ -99,9 +102,9 @@ Section gen. Proof. let a := fresh "a" in eexists; cbv beta zeta; intros a. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). pose proof wt_nonzero; pose proof div_mod. - pose proof (wt_gen_divides_chains m sz sz_nonzero sz_le_log2_m carry_chains). + pose proof (wt_gen_divides_chains base base_pos carry_chains). pose proof wt_divides'. let x := constr:(chained_carries' sz wt s c a carry_chains) in presolve_op_F constr:(wt) x; @@ -148,7 +151,7 @@ Section gen. Proof. eexists; cbv beta zeta; intros a b. pose proof wt_nonzero. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). let x := constr:( Positional.add_cps (n := sz) wt a b id) in presolve_op_F constr:(wt) x; @@ -166,7 +169,7 @@ Section gen. let b := fresh "b" in eexists; cbv beta zeta; intros a b. pose proof wt_nonzero. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). let x := constr:( Positional.sub_cps (n:=sz) (coef := coef) wt a b id) in presolve_op_F constr:(wt) x; @@ -182,7 +185,7 @@ Section gen. Proof. eexists; cbv beta zeta; intros a. pose proof wt_nonzero. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). let x := constr:( Positional.opp_cps (n:=sz) (coef := coef) wt a id) in presolve_op_F constr:(wt) x; @@ -198,7 +201,7 @@ Section gen. Proof. eexists; cbv beta zeta; intros a b. pose proof wt_nonzero. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). pose proof (sz2'_nonzero sz sz_nonzero). let x := constr:( Positional.mul_cps (n:=sz) (m:=sz2) wt a b @@ -224,7 +227,7 @@ Section gen. Proof. eexists; cbv beta zeta; intros a. pose proof wt_nonzero. - pose proof (wt_gen0_1 m sz). + pose proof (wt_gen0_1 base). pose proof (sz2'_nonzero sz sz_nonzero). let x := constr:( Positional.mul_cps (n:=sz) (m:=sz2) wt a a @@ -263,7 +266,7 @@ Section gen. (Positional.Fdecode_Fencode_id (sz_nonzero := sz_nonzero) (div_mod := div_mod) - wt (wt_gen0_1 m sz) wt_nonzero wt_divides') + wt (wt_gen0_1 base) wt_nonzero wt_divides') (Positional.eq_Feq_iff wt) _ _ _); lazymatch goal with @@ -324,25 +327,25 @@ Ltac cache_sig_with_type_by_existing_sig ty existing_sig id := ltac:(fun _ => cbv [carry_sig' constant_sig' zero_sig' one_sig' add_sig' sub_sig' mul_sig' square_sig' opp_sig']) ty existing_sig id. -Ltac pose_carry_sig sz m wt s c carry_chains carry_sig := +Ltac pose_carry_sig wt m base sz s c carry_chains carry_sig := cache_sig_with_type_by_existing_sig {carry : (Z^sz -> Z^sz)%type | forall a : Z^sz, let eval := Positional.Fdecode (m := m) wt in eval (carry a) = eval a} - (carry_sig' m sz s c carry_chains) + (carry_sig' m base sz s c carry_chains) carry_sig. -Ltac pose_zero_sig sz m wt sz_nonzero sz_le_log2_m zero_sig := +Ltac pose_zero_sig wt m base sz sz_nonzero base_pos zero_sig := cache_vm_sig_with_type { zero : Z^sz | Positional.Fdecode (m:=m) wt zero = 0%F} - (zero_sig' m sz sz_nonzero sz_le_log2_m) + (zero_sig' m base sz sz_nonzero base_pos) zero_sig. -Ltac pose_one_sig sz m wt sz_nonzero sz_le_log2_m one_sig := +Ltac pose_one_sig wt m base sz sz_nonzero base_pos one_sig := cache_vm_sig_with_type { one : Z^sz | Positional.Fdecode (m:=m) wt one = 1%F} - (one_sig' m sz sz_nonzero sz_le_log2_m) + (one_sig' m base sz sz_nonzero base_pos) one_sig. Ltac pose_a24_sig sz m wt a24 a24_sig := @@ -351,49 +354,49 @@ Ltac pose_a24_sig sz m wt a24 a24_sig := solve_constant_sig a24_sig. -Ltac pose_add_sig sz m wt sz_nonzero add_sig := +Ltac pose_add_sig wt m base sz add_sig := cache_sig_with_type_by_existing_sig { add : (Z^sz -> Z^sz -> Z^sz)%type | forall a b : Z^sz, let eval := Positional.Fdecode (m:=m) wt in eval (add a b) = (eval a + eval b)%F } - (add_sig' m sz sz_nonzero) + (add_sig' m base sz) add_sig. -Ltac pose_sub_sig sz m wt coef sub_sig := +Ltac pose_sub_sig wt m base sz coef sub_sig := cache_sig_with_type_by_existing_sig {sub : (Z^sz -> Z^sz -> Z^sz)%type | forall a b : Z^sz, let eval := Positional.Fdecode (m:=m) wt in eval (sub a b) = (eval a - eval b)%F} - (sub_sig' m sz coef) + (sub_sig' m base sz coef) sub_sig. -Ltac pose_opp_sig sz m wt coef opp_sig := +Ltac pose_opp_sig wt m base sz coef opp_sig := cache_sig_with_type_by_existing_sig {opp : (Z^sz -> Z^sz)%type | forall a : Z^sz, let eval := Positional.Fdecode (m := m) wt in eval (opp a) = F.opp (eval a)} - (opp_sig' m sz coef) + (opp_sig' m base sz coef) opp_sig. -Ltac pose_mul_sig sz m wt s c mul_code sz_nonzero s_nonzero mul_code_correct mul_sig := +Ltac pose_mul_sig wt m base sz s c mul_code sz_nonzero s_nonzero base_pos mul_code_correct mul_sig := cache_sig_with_type_by_existing_sig {mul : (Z^sz -> Z^sz -> Z^sz)%type | forall a b : Z^sz, let eval := Positional.Fdecode (m := m) wt in eval (mul a b) = (eval a * eval b)%F} - (mul_sig' m sz s c mul_code sz_nonzero s_nonzero mul_code_correct) + (mul_sig' m base sz s c mul_code sz_nonzero s_nonzero base_pos mul_code_correct) mul_sig. -Ltac pose_square_sig sz m wt s c square_code sz_nonzero s_nonzero square_code_correct square_sig := +Ltac pose_square_sig wt m base sz s c square_code sz_nonzero s_nonzero base_pos square_code_correct square_sig := cache_sig_with_type_by_existing_sig {square : (Z^sz -> Z^sz)%type | forall a : Z^sz, let eval := Positional.Fdecode (m := m) wt in eval (square a) = (eval a * eval a)%F} - (square_sig' m sz s c square_code sz_nonzero s_nonzero square_code_correct) + (square_sig' m base sz s c square_code sz_nonzero s_nonzero base_pos square_code_correct) square_sig. Ltac pose_ring sz m wt wt_divides' sz_nonzero wt_nonzero zero_sig one_sig opp_sig add_sig sub_sig mul_sig ring := diff --git a/src/Specific/Framework/ArithmeticSynthesis/DefaultsPackage.v b/src/Specific/Framework/ArithmeticSynthesis/DefaultsPackage.v index 4a037f34a..10d6e42ed 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/DefaultsPackage.v +++ b/src/Specific/Framework/ArithmeticSynthesis/DefaultsPackage.v @@ -39,38 +39,41 @@ Ltac add_carry_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.carry_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let s := Tag.get pkg TAG.s in let c := Tag.get pkg TAG.c in let carry_chains := Tag.get pkg TAG.carry_chains in let carry_sig := fresh "carry_sig" in - let carry_sig := pose_carry_sig sz m wt s c carry_chains carry_sig in + let carry_sig := pose_carry_sig wt m base sz s c carry_chains carry_sig in constr:(carry_sig)). Ltac add_zero_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.zero_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let sz_nonzero := Tag.get pkg TAG.sz_nonzero in - let sz_le_log2_m := Tag.get pkg TAG.sz_le_log2_m in + let base_pos := Tag.get pkg TAG.base_pos in let zero_sig := fresh "zero_sig" in - let zero_sig := pose_zero_sig sz m wt sz_nonzero sz_le_log2_m zero_sig in + let zero_sig := pose_zero_sig wt m base sz sz_nonzero base_pos zero_sig in constr:(zero_sig)). Ltac add_one_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.one_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let sz_nonzero := Tag.get pkg TAG.sz_nonzero in - let sz_le_log2_m := Tag.get pkg TAG.sz_le_log2_m in + let base_pos := Tag.get pkg TAG.base_pos in let one_sig := fresh "one_sig" in - let one_sig := pose_one_sig sz m wt sz_nonzero sz_le_log2_m one_sig in + let one_sig := pose_one_sig wt m base sz sz_nonzero base_pos one_sig in constr:(one_sig)). Ltac add_a24_sig pkg := Tag.update_by_tac_if_not_exists @@ -87,66 +90,72 @@ Ltac add_add_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.add_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in - let sz_nonzero := Tag.get pkg TAG.sz_nonzero in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let add_sig := fresh "add_sig" in - let add_sig := pose_add_sig sz m wt sz_nonzero add_sig in + let add_sig := pose_add_sig wt m base sz add_sig in constr:(add_sig)). Ltac add_sub_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.sub_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let coef := Tag.get pkg TAG.coef in let sub_sig := fresh "sub_sig" in - let sub_sig := pose_sub_sig sz m wt coef sub_sig in + let sub_sig := pose_sub_sig wt m base sz coef sub_sig in constr:(sub_sig)). Ltac add_opp_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.opp_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let coef := Tag.get pkg TAG.coef in let opp_sig := fresh "opp_sig" in - let opp_sig := pose_opp_sig sz m wt coef opp_sig in + let opp_sig := pose_opp_sig wt m base sz coef opp_sig in constr:(opp_sig)). Ltac add_mul_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.mul_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let s := Tag.get pkg TAG.s in let c := Tag.get pkg TAG.c in let mul_code := Tag.get pkg TAG.mul_code in let sz_nonzero := Tag.get pkg TAG.sz_nonzero in let s_nonzero := Tag.get pkg TAG.s_nonzero in + let base_pos := Tag.get pkg TAG.base_pos in let mul_code_correct := Tag.get pkg TAG.mul_code_correct in let mul_sig := fresh "mul_sig" in - let mul_sig := pose_mul_sig sz m wt s c mul_code sz_nonzero s_nonzero mul_code_correct mul_sig in + let mul_sig := pose_mul_sig wt m base sz s c mul_code sz_nonzero s_nonzero base_pos mul_code_correct mul_sig in constr:(mul_sig)). Ltac add_square_sig pkg := Tag.update_by_tac_if_not_exists pkg TAG.square_sig - ltac:(fun _ => let sz := Tag.get pkg TAG.sz in + ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in - let wt := Tag.get pkg TAG.wt in + let base := Tag.get pkg TAG.base in + let sz := Tag.get pkg TAG.sz in let s := Tag.get pkg TAG.s in let c := Tag.get pkg TAG.c in let square_code := Tag.get pkg TAG.square_code in let sz_nonzero := Tag.get pkg TAG.sz_nonzero in let s_nonzero := Tag.get pkg TAG.s_nonzero in + let base_pos := Tag.get pkg TAG.base_pos in let square_code_correct := Tag.get pkg TAG.square_code_correct in let square_sig := fresh "square_sig" in - let square_sig := pose_square_sig sz m wt s c square_code sz_nonzero s_nonzero square_code_correct square_sig in + let square_sig := pose_square_sig wt m base sz s c square_code sz_nonzero s_nonzero base_pos square_code_correct square_sig in constr:(square_sig)). Ltac add_ring pkg := Tag.update_by_tac_if_not_exists diff --git a/src/Specific/Framework/ArithmeticSynthesis/Freeze.v b/src/Specific/Framework/ArithmeticSynthesis/Freeze.v index a32f9e220..44d284d8e 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/Freeze.v +++ b/src/Specific/Framework/ArithmeticSynthesis/Freeze.v @@ -1,4 +1,5 @@ Require Import Coq.ZArith.ZArith Coq.ZArith.BinIntDef. +Require Import Coq.QArith.QArith_base. Require Import Coq.Lists.List. Import ListNotations. Require Import Crypto.Arithmetic.CoreUnfolder. Require Import Crypto.Arithmetic.Saturated.CoreUnfolder. @@ -26,17 +27,18 @@ Ltac freeze_preunfold := Section gen. Context (m : positive) + (base : Q) (sz : nat) (c : list limb) (bitwidth : Z) (m_enc : Z^sz) - (sz_nonzero : sz <> 0%nat) - (sz_le_log2_m : Z.of_nat sz <= Z.log2_up (Z.pos m)). + (base_pos : (1 <= base)%Q) + (sz_nonzero : sz <> 0%nat). - Local Notation wt := (wt_gen m sz). + Local Notation wt := (wt_gen base). Local Notation sz2 := (sz2' sz). - Local Notation wt_divides' := (wt_gen_divides' m sz sz_nonzero sz_le_log2_m). - Local Notation wt_nonzero := (wt_gen_nonzero m sz). + Local Notation wt_divides' := (wt_gen_divides' base base_pos). + Local Notation wt_nonzero := (wt_gen_nonzero base base_pos). Context (c_small : 0 < Associational.eval c < wt sz) (m_enc_bounded : Tuple.map (BinInt.Z.land (Z.ones bitwidth)) m_enc = m_enc) @@ -51,12 +53,11 @@ Section gen. eval (freeze a) = eval a }. Proof. eexists; cbv beta zeta; (intros a ?). - pose proof wt_nonzero; pose proof (wt_gen_pos m sz). - pose proof (wt_gen0_1 m sz). - pose proof div_mod; pose proof (wt_gen_divides m sz sz_nonzero sz_le_log2_m). - pose proof (wt_gen_multiples m sz). + pose proof wt_nonzero; pose proof (wt_gen_pos base base_pos). + pose proof (wt_gen0_1 base). + pose proof div_mod; pose proof (wt_gen_divides base base_pos). + pose proof (wt_gen_multiples base base_pos). pose proof div_correct; pose proof modulo_correct. - pose proof (wt_gen_divides_chain m sz sz_nonzero sz_le_log2_m). let x := constr:(freeze (n:=sz) wt (Z.ones bitwidth) m_enc a) in presolve_op_F constr:(wt) x; [ autorewrite with pattern_runtime; reflexivity | ]. @@ -65,7 +66,7 @@ Section gen. Defined. End gen. -Ltac pose_freeze_sig wt m sz c bitwidth m_enc sz_nonzero sz_le_log2_m freeze_sig := +Ltac pose_freeze_sig wt m base sz c bitwidth m_enc base_pos sz_nonzero freeze_sig := cache_sig_with_type_by_existing_sig_helper ltac:(fun _ => cbv [freeze_sig']) {freeze : (Z^sz -> Z^sz)%type | @@ -73,5 +74,5 @@ Ltac pose_freeze_sig wt m sz c bitwidth m_enc sz_nonzero sz_le_log2_m freeze_sig (0 <= Positional.eval wt a < 2 * Z.pos m)-> let eval := Positional.Fdecode (m := m) wt in eval (freeze a) = eval a} - (freeze_sig' m sz c bitwidth m_enc sz_nonzero sz_le_log2_m) + (freeze_sig' m base sz c bitwidth m_enc base_pos sz_nonzero) freeze_sig. diff --git a/src/Specific/Framework/ArithmeticSynthesis/FreezePackage.v b/src/Specific/Framework/ArithmeticSynthesis/FreezePackage.v index 885bfde09..1a4b405b6 100644 --- a/src/Specific/Framework/ArithmeticSynthesis/FreezePackage.v +++ b/src/Specific/Framework/ArithmeticSynthesis/FreezePackage.v @@ -15,14 +15,15 @@ Ltac add_freeze_sig pkg := TAG.freeze_sig ltac:(fun _ => let wt := Tag.get pkg TAG.wt in let m := Tag.get pkg TAG.m in + let base := Tag.get pkg TAG.base in let sz := Tag.get pkg TAG.sz in let c := Tag.get pkg TAG.c in let bitwidth := Tag.get pkg TAG.bitwidth in let m_enc := Tag.get pkg TAG.m_enc in + let base_pos := Tag.get pkg TAG.base_pos in let sz_nonzero := Tag.get pkg TAG.sz_nonzero in - let sz_le_log2_m := Tag.get pkg TAG.sz_le_log2_m in let freeze_sig := fresh "freeze_sig" in - let freeze_sig := pose_freeze_sig wt m sz c bitwidth m_enc sz_nonzero sz_le_log2_m freeze_sig in + let freeze_sig := pose_freeze_sig wt m base sz c bitwidth m_enc base_pos sz_nonzero freeze_sig in constr:(freeze_sig)). Ltac add_Freeze_package pkg := let pkg := add_freeze_sig pkg in diff --git a/src/Specific/Framework/CurveParameters.v b/src/Specific/Framework/CurveParameters.v index 8911dccfc..f8d245a67 100644 --- a/src/Specific/Framework/CurveParameters.v +++ b/src/Specific/Framework/CurveParameters.v @@ -11,7 +11,7 @@ Local Set Primitive Projections. Module Export Notations := RawCurveParameters.Notations. Module TAG. (* namespacing *) - Inductive tags := CP | sz | bitwidth | s | c | carry_chains | a24 | coef_div_modulus | goldilocks | montgomery | upper_bound_of_exponent | allowable_bit_widths | freeze_allowable_bit_widths | modinv_fuel | mul_code | square_code. + Inductive tags := CP | sz | base | bitwidth | s | c | carry_chains | a24 | coef_div_modulus | goldilocks | montgomery | upper_bound_of_exponent | allowable_bit_widths | freeze_allowable_bit_widths | modinv_fuel | mul_code | square_code. End TAG. Module Export CurveParameters. @@ -32,6 +32,7 @@ Module Export CurveParameters. Record CurveParameters := { sz : nat; + base : Q; bitwidth : Z; s : Z; c : list limb; @@ -52,6 +53,7 @@ Module Export CurveParameters. Declare Reduction cbv_CurveParameters := cbv [sz + base bitwidth s c @@ -90,6 +92,7 @@ Module Export CurveParameters. : CurveParameters := Eval cbv_RawCurveParameters in let sz := RawCurveParameters.sz CP in + let base := RawCurveParameters.base CP in let bitwidth := RawCurveParameters.bitwidth CP in let montgomery := RawCurveParameters.montgomery CP in let s := RawCurveParameters.s CP in @@ -109,6 +112,7 @@ Module Export CurveParameters. {| sz := sz; + base := base; bitwidth := bitwidth; s := s; c := c; @@ -146,6 +150,7 @@ Module Export CurveParameters. lazymatch v with | ({| sz := ?sz'; + base := ?base'; bitwidth := ?bitwidth'; s := ?s'; c := ?c'; @@ -162,6 +167,7 @@ Module Export CurveParameters. modinv_fuel := ?modinv_fuel' |}) => let sz' := do_compute sz' in + let base' := do_compute base' in let bitwidth' := do_compute bitwidth' in let carry_chains' := do_compute carry_chains' in let goldilocks' := do_compute goldilocks' in @@ -171,6 +177,7 @@ Module Export CurveParameters. let modinv_fuel' := do_compute modinv_fuel' in constr:({| sz := sz'; + base := base'; bitwidth := bitwidth'; s := s'; c := c'; @@ -194,6 +201,8 @@ Module Export CurveParameters. Ltac internal_pose_of_CP CP proj id := let P_proj := (eval cbv_CurveParameters in (proj CP)) in cache_term P_proj id. + Ltac pose_base CP base := + internal_pose_of_CP CP CurveParameters.base base. Ltac pose_sz CP sz := internal_pose_of_CP CP CurveParameters.sz sz. Ltac pose_bitwidth CP bitwidth := @@ -226,6 +235,12 @@ Module Export CurveParameters. internal_pose_of_CP CP CurveParameters.square_code square_code. (* Everything below this line autogenerated by remake_packages.py *) + Ltac add_base pkg := + let CP := Tag.get pkg TAG.CP in + let base := fresh "base" in + let base := pose_base CP base in + Tag.update pkg TAG.base base. + Ltac add_sz pkg := let CP := Tag.get pkg TAG.CP in let sz := fresh "sz" in @@ -317,6 +332,7 @@ Module Export CurveParameters. Tag.update pkg TAG.square_code square_code. Ltac add_CurveParameters_package pkg := + let pkg := add_base pkg in let pkg := add_sz pkg in let pkg := add_bitwidth pkg in let pkg := add_s pkg in diff --git a/src/Specific/Framework/CurveParametersPackage.v b/src/Specific/Framework/CurveParametersPackage.v index 75ef1f7e7..458c1d4ea 100644 --- a/src/Specific/Framework/CurveParametersPackage.v +++ b/src/Specific/Framework/CurveParametersPackage.v @@ -23,6 +23,8 @@ Ltac if_montgomery pkg tac_true tac_false arg := Module MakeCurveParametersPackage (PKG : PrePackage). Module Import MakeCurveParametersPackageInternal := MakePackageBase PKG. + Ltac get_base _ := get TAG.base. + Notation base := (ltac:(let v := get_base () in exact v)) (only parsing). Ltac get_sz _ := get TAG.sz. Notation sz := (ltac:(let v := get_sz () in exact v)) (only parsing). Ltac get_bitwidth _ := get TAG.bitwidth. diff --git a/src/Specific/Framework/RawCurveParameters.v b/src/Specific/Framework/RawCurveParameters.v index 8adff1f69..b84089eaf 100644 --- a/src/Specific/Framework/RawCurveParameters.v +++ b/src/Specific/Framework/RawCurveParameters.v @@ -1,9 +1,12 @@ +Require Export Coq.QArith.QArith_base. Require Export Coq.ZArith.BinInt. Require Export Coq.Lists.List. Require Export Crypto.Util.ZUtil.Notations. Require Crypto.Util.Tuple. Local Set Primitive Projections. +Coercion QArith_base.inject_Z : Z >-> Q. +Coercion Z.of_nat : nat >-> Z. Module Export Notations. (* import/export tracking *) Export ListNotations. @@ -18,6 +21,7 @@ End Notations. Record CurveParameters := { sz : nat; + base : Q; bitwidth : Z; s : Z; c : list limb; @@ -42,6 +46,7 @@ Record CurveParameters := Declare Reduction cbv_RawCurveParameters := cbv [sz + base bitwidth s c diff --git a/src/Specific/Framework/make_curve.py b/src/Specific/Framework/make_curve.py index 70b88069d..516b67868 100755 --- a/src/Specific/Framework/make_curve.py +++ b/src/Specific/Framework/make_curve.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import with_statement import json, sys, os, math, re, shutil, io +from fractions import Fraction def compute_bitwidth(base): return 2**int(math.ceil(math.log(base, 2))) @@ -11,6 +12,10 @@ def default_carry_chains(sz): def compute_s(modulus_str): base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', r'([0-9\^ +\*-]*)$')), modulus_str).groups() return '%s^%s' % (base, exp) +def reformat_base(base): + if '.' not in base: return base + int_part, frac_part = base.split('.') + return int_part + ' + ' + str(Fraction('.' + frac_part)) def compute_c(modulus_str): base, exp, rest = re.match(r'\s*'.join(('^', '(2)', r'\^', '([0-9]+)', r'([0-9\^ +\*-]*)$')), modulus_str).groups() if rest.strip() == '': return [] @@ -193,6 +198,7 @@ def make_curve_parameters(parameters): assert(all(ch in '0123456789^+- ' for ch in parameters['modulus'])) modulus = eval(parameters['modulus'].replace('^', '**')) base = float(parameters['base']) + replacements['reformatted_base'] = reformat_base(parameters['base']) replacements['bitwidth'] = parameters.get('bitwidth', str(compute_bitwidth(base))) bitwidth = int(replacements['bitwidth']) replacements['sz'] = parameters.get('sz', str(compute_sz(modulus, base))) @@ -242,6 +248,7 @@ Base: %(base)s Definition curve : CurveParameters := {| sz := %(sz)s%%nat; + base := %(reformatted_base)s; bitwidth := %(bitwidth)s; s := %(s)s; c := %(c)s; -- cgit v1.2.3