From 4d6d788ff04ec7b08d522099b6841fd16d6fae5c Mon Sep 17 00:00:00 2001 From: jadep Date: Sat, 5 Nov 2016 18:32:57 -0400 Subject: Automatically generate code for field operations with different primes --- src/SpecificGen/2213_32.json | 7 + src/SpecificGen/2519_32.json | 7 + src/SpecificGen/25519_32.json | 7 + src/SpecificGen/25519_64.json | 7 + src/SpecificGen/41417_32.json | 7 + src/SpecificGen/5211_32.json | 7 + src/SpecificGen/GF2213_32.v | 694 +++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GF2519_32.v | 676 ++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GF25519_32.v | 694 +++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GF25519_64.v | 694 +++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GF41417_32.v | 676 ++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GF5211_32.v | 676 ++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GFtemplate3mod4 | 676 ++++++++++++++++++++++++++++++++++++++ src/SpecificGen/GFtemplate5mod8 | 694 +++++++++++++++++++++++++++++++++++++++ src/SpecificGen/README.md | 5 + src/SpecificGen/fill_template.py | 39 +++ 16 files changed, 5566 insertions(+) create mode 100644 src/SpecificGen/2213_32.json create mode 100644 src/SpecificGen/2519_32.json create mode 100644 src/SpecificGen/25519_32.json create mode 100644 src/SpecificGen/25519_64.json create mode 100644 src/SpecificGen/41417_32.json create mode 100644 src/SpecificGen/5211_32.json create mode 100644 src/SpecificGen/GF2213_32.v create mode 100644 src/SpecificGen/GF2519_32.v create mode 100644 src/SpecificGen/GF25519_32.v create mode 100644 src/SpecificGen/GF25519_64.v create mode 100644 src/SpecificGen/GF41417_32.v create mode 100644 src/SpecificGen/GF5211_32.v create mode 100644 src/SpecificGen/GFtemplate3mod4 create mode 100644 src/SpecificGen/GFtemplate5mod8 create mode 100644 src/SpecificGen/README.md create mode 100644 src/SpecificGen/fill_template.py (limited to 'src/SpecificGen') diff --git a/src/SpecificGen/2213_32.json b/src/SpecificGen/2213_32.json new file mode 100644 index 000000000..fe000da25 --- /dev/null +++ b/src/SpecificGen/2213_32.json @@ -0,0 +1,7 @@ +{ + "k" : 221, + "c" : 3, + "n" : 8, + "w" : 32, + "ch" : "[0;1;2;3;4;5;6;7;0;1]" +} diff --git a/src/SpecificGen/2519_32.json b/src/SpecificGen/2519_32.json new file mode 100644 index 000000000..f2aabdb70 --- /dev/null +++ b/src/SpecificGen/2519_32.json @@ -0,0 +1,7 @@ +{ + "k" : 251, + "c" : 9, + "n" : 10, + "w" : 32, + "ch" : "[0;1;2;3;4;5;6;7;8;9;0;1]" +} diff --git a/src/SpecificGen/25519_32.json b/src/SpecificGen/25519_32.json new file mode 100644 index 000000000..383c03531 --- /dev/null +++ b/src/SpecificGen/25519_32.json @@ -0,0 +1,7 @@ +{ + "k" : 255, + "c" : 19, + "n" : 10, + "w" : 32, + "ch" : "[0;1;2;3;4;5;6;7;8;9;0;1]" +} diff --git a/src/SpecificGen/25519_64.json b/src/SpecificGen/25519_64.json new file mode 100644 index 000000000..b4acfda31 --- /dev/null +++ b/src/SpecificGen/25519_64.json @@ -0,0 +1,7 @@ +{ + "k" : 255, + "c" : 19, + "n" : 5, + "w" : 64, + "ch" : "[0;1;2;3;4;0;1]" +} diff --git a/src/SpecificGen/41417_32.json b/src/SpecificGen/41417_32.json new file mode 100644 index 000000000..0a55e4c0b --- /dev/null +++ b/src/SpecificGen/41417_32.json @@ -0,0 +1,7 @@ +{ + "k" : 414, + "c" : 17, + "n" : 18, + "w" : 32, + "ch" : "[0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;0;1]" +} diff --git a/src/SpecificGen/5211_32.json b/src/SpecificGen/5211_32.json new file mode 100644 index 000000000..dc43b67b7 --- /dev/null +++ b/src/SpecificGen/5211_32.json @@ -0,0 +1,7 @@ +{ + "k" : 521, + "c" : 1, + "n" : 20, + "w" : 32, + "ch" : "[0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18;19;0;1]" +} diff --git a/src/SpecificGen/GF2213_32.v b/src/SpecificGen/GF2213_32.v new file mode 100644 index 000000000..fe45e3423 --- /dev/null +++ b/src/SpecificGen/GF2213_32.v @@ -0,0 +1,694 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^221 - 3. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 32%Z. +Definition int_width := 32%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 8%nat 221. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 32 6 ++ 29 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 8 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 8 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7) + (g0, g1, g2, g3, g4, g5, g6, g7)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +(* Now that we have [pow], we can compute sqrt of -1 for use + in sqrt function (this is not needed unless the prime is + 5 mod 8) *) +Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). + +Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). +Proof. + cbv [rep]. + apply F.eq_to_Z_iff. + vm_compute. + reflexivity. +Qed. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field2213 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 8 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 8 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 8 (freeze_opt (int_width := int_width) c_ (to_list 8 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 8 (freeze_opt (int_width := int_width) c_ (to_list 8 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 8 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7) + (g0, g1, g2, g3, g4, g5, g6, g7)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 8 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7) + (g0, g1, g2, g3, g4, g5, g6, g7)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (powx powx_squared f : fe) : + { f' : fe | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powx powx_squared f}. +Proof. + eexists. + cbv [sqrt_5mod8_opt int_width]. + apply Proper_Let_In_nd_changebody; [reflexivity|intro]. + set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. + reflexivity. +Defined. + +Definition sqrt (powx powx_squared f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powx powx_squared f). + +Definition sqrt_correct (powx powx_squared f : fe) + : sqrt powx powx_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powx powx_squared f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powx powx_squared f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GF2519_32.v b/src/SpecificGen/GF2519_32.v new file mode 100644 index 000000000..ebbf7a24d --- /dev/null +++ b/src/SpecificGen/GF2519_32.v @@ -0,0 +1,676 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^251 - 9. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 32%Z. +Definition int_width := 32%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 10%nat 251. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 32 7 ++ 27 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 4 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 4 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field2519 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 10 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 10 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 10 (freeze_opt (int_width := int_width) c_ (to_list 10 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 10 (freeze_opt (int_width := int_width) c_ (to_list 10 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 10 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 10 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (f : fe) : + { f' : fe | f' = sqrt_3mod4_opt k_ c_ one_ f}. +Proof. + eexists. + cbv [sqrt_3mod4_opt int_width]. + rewrite <- pow_correct. + reflexivity. +Defined. + +Definition sqrt (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). + +Definition sqrt_correct (f : fe) + : sqrt f = sqrt_3mod4_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GF25519_32.v b/src/SpecificGen/GF25519_32.v new file mode 100644 index 000000000..80e51c2d7 --- /dev/null +++ b/src/SpecificGen/GF25519_32.v @@ -0,0 +1,694 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^255 - 19. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 32%Z. +Definition int_width := 32%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 10%nat 255. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 32 7 ++ 31 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 8 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 8 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +(* Now that we have [pow], we can compute sqrt of -1 for use + in sqrt function (this is not needed unless the prime is + 5 mod 8) *) +Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). + +Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). +Proof. + cbv [rep]. + apply F.eq_to_Z_iff. + vm_compute. + reflexivity. +Qed. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field25519 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 10 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 10 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 10 (freeze_opt (int_width := int_width) c_ (to_list 10 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 10 (freeze_opt (int_width := int_width) c_ (to_list 10 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 10 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 10 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (powx powx_squared f : fe) : + { f' : fe | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powx powx_squared f}. +Proof. + eexists. + cbv [sqrt_5mod8_opt int_width]. + apply Proper_Let_In_nd_changebody; [reflexivity|intro]. + set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. + reflexivity. +Defined. + +Definition sqrt (powx powx_squared f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powx powx_squared f). + +Definition sqrt_correct (powx powx_squared f : fe) + : sqrt powx powx_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powx powx_squared f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powx powx_squared f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GF25519_64.v b/src/SpecificGen/GF25519_64.v new file mode 100644 index 000000000..134090d0e --- /dev/null +++ b/src/SpecificGen/GF25519_64.v @@ -0,0 +1,694 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^255 - 19. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 64%Z. +Definition int_width := 64%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 5%nat 255. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 64 3 ++ 63 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 8 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 8 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4) := f in + let '(g0, g1, g2, g3, g4) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4) + (g0, g1, g2, g3, g4)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +(* Now that we have [pow], we can compute sqrt of -1 for use + in sqrt function (this is not needed unless the prime is + 5 mod 8) *) +Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). + +Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). +Proof. + cbv [rep]. + apply F.eq_to_Z_iff. + vm_compute. + reflexivity. +Qed. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field25519 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 5 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 5 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 5 (freeze_opt (int_width := int_width) c_ (to_list 5 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 5 (freeze_opt (int_width := int_width) c_ (to_list 5 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 5 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4) := f in + let '(g0, g1, g2, g3, g4) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4) + (g0, g1, g2, g3, g4)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 5 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4) := f in + let '(g0, g1, g2, g3, g4) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4) + (g0, g1, g2, g3, g4)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (powx powx_squared f : fe) : + { f' : fe | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powx powx_squared f}. +Proof. + eexists. + cbv [sqrt_5mod8_opt int_width]. + apply Proper_Let_In_nd_changebody; [reflexivity|intro]. + set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. + reflexivity. +Defined. + +Definition sqrt (powx powx_squared f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powx powx_squared f). + +Definition sqrt_correct (powx powx_squared f : fe) + : sqrt powx powx_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powx powx_squared f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powx powx_squared f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GF41417_32.v b/src/SpecificGen/GF41417_32.v new file mode 100644 index 000000000..7e7f5ace8 --- /dev/null +++ b/src/SpecificGen/GF41417_32.v @@ -0,0 +1,676 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^414 - 17. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 32%Z. +Definition int_width := 32%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 18%nat 414. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 32 12 ++ 30 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 4 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 4 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field41417 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 18 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 18 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 18 (freeze_opt (int_width := int_width) c_ (to_list 18 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 18 (freeze_opt (int_width := int_width) c_ (to_list 18 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 18 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 18 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (f : fe) : + { f' : fe | f' = sqrt_3mod4_opt k_ c_ one_ f}. +Proof. + eexists. + cbv [sqrt_3mod4_opt int_width]. + rewrite <- pow_correct. + reflexivity. +Defined. + +Definition sqrt (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). + +Definition sqrt_correct (f : fe) + : sqrt f = sqrt_3mod4_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GF5211_32.v b/src/SpecificGen/GF5211_32.v new file mode 100644 index 000000000..f04fe80f2 --- /dev/null +++ b/src/SpecificGen/GF5211_32.v @@ -0,0 +1,676 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^521 - 1. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := 32%Z. +Definition int_width := 32%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus 20%nat 521. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev [0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18;19;0;1])%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat 32 16 ++ 9 :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 4 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 4 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19) := g in + proj1_sig (mul_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19)). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field5211 : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list 20 f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + proj1_sig (ge_modulus_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19)). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list 20 f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 20 (freeze_opt (int_width := int_width) c_ (to_list 20 f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + proj1_sig (freeze_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19)). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 20 (freeze_opt (int_width := int_width) c_ (to_list 20 f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z 20 Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19) := g in + proj1_sig (fieldwiseb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19)). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z 20 Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + let '(g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19) := g in + proj1_sig (eqb_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) + (g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15, g16, g17, g18, g19)). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (f : fe) : + { f' : fe | f' = sqrt_3mod4_opt k_ c_ one_ f}. +Proof. + eexists. + cbv [sqrt_3mod4_opt int_width]. + rewrite <- pow_correct. + reflexivity. +Defined. + +Definition sqrt (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). + +Definition sqrt_correct (f : fe) + : sqrt f = sqrt_3mod4_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19) := f in + proj1_sig (pack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16, f17, f18, f19)). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16) := f in + proj1_sig (unpack_simpl_sig (f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15, f16)). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GFtemplate3mod4 b/src/SpecificGen/GFtemplate3mod4 new file mode 100644 index 000000000..6c25c91ca --- /dev/null +++ b/src/SpecificGen/GFtemplate3mod4 @@ -0,0 +1,676 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^{{{k}}} - {{{c}}}. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := {{{w}}}%Z. +Definition int_width := {{{w}}}%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus {{{n}}}%nat {{{k}}}. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev {{{ch}}})%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat {{{w}}} {{{kdivw}}} ++ {{{kmodw}}} :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 4 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 4 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (mul_simpl_sig {{{enum f}}} + {{{enum g}}}). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field{{{k}}}{{{c}}} : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list {{{n}}} f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '{{{enum f}}} := f in + proj1_sig (ge_modulus_sig {{{enum f}}}). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list {{{n}}} f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 {{{n}}} (freeze_opt (int_width := int_width) c_ (to_list {{{n}}} f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '{{{enum f}}} := f in + proj1_sig (freeze_sig {{{enum f}}}). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 {{{n}}} (freeze_opt (int_width := int_width) c_ (to_list {{{n}}} f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z {{{n}}} Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (fieldwiseb_sig {{{enum f}}} + {{{enum g}}}). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z {{{n}}} Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (eqb_sig {{{enum f}}} + {{{enum g}}}). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (f : fe) : + { f' : fe | f' = sqrt_3mod4_opt k_ c_ one_ f}. +Proof. + eexists. + cbv [sqrt_3mod4_opt int_width]. + rewrite <- pow_correct. + reflexivity. +Defined. + +Definition sqrt (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig f). + +Definition sqrt_correct (f : fe) + : sqrt f = sqrt_3mod4_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '{{{enum f}}} := f in + proj1_sig (pack_simpl_sig {{{enum f}}}). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '{{{enumw f}}} := f in + proj1_sig (unpack_simpl_sig {{{enumw f}}}). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/GFtemplate5mod8 b/src/SpecificGen/GFtemplate5mod8 new file mode 100644 index 000000000..4b8d4a9e4 --- /dev/null +++ b/src/SpecificGen/GFtemplate5mod8 @@ -0,0 +1,694 @@ +Require Import Crypto.BaseSystem. +Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystem. +Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs. +Require Import Crypto.ModularArithmetic.ModularBaseSystemOpt. +Require Import Coq.Lists.List Crypto.Util.ListUtil. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Decidable. +Require Import Crypto.Algebra. +Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Local Open Scope Z. + +(* BEGIN precomputation. *) + +Definition modulus : Z := Eval compute in 2^{{{k}}} - {{{c}}}. +Lemma prime_modulus : prime modulus. Admitted. +Definition freeze_input_bound := {{{w}}}%Z. +Definition int_width := {{{w}}}%Z. + +Instance params : PseudoMersenneBaseParams modulus. + construct_params prime_modulus {{{n}}}%nat {{{k}}}. +Defined. + +Definition length_fe := Eval compute in length limb_widths. +Definition fe := Eval compute in (tuple Z length_fe). + +Definition mul2modulus : fe := + Eval compute in (from_list_default 0%Z (length limb_widths) (construct_mul2modulus params)). + +Instance subCoeff : SubtractionCoefficient. + apply Build_SubtractionCoefficient with (coeff := mul2modulus). + vm_decide. +Defined. + +Instance carryChain : CarryChain limb_widths. + apply Build_CarryChain with (carry_chain := (rev {{{ch}}})%nat). + intros. + repeat (destruct H as [|H]; [subst; vm_compute; repeat constructor | ]). + contradiction H. +Defined. + +Definition freezePreconditions : FreezePreconditions freeze_input_bound int_width. +Proof. + constructor; compute_preconditions. +Defined. + +(* Wire format for [pack] and [unpack] *) +Definition wire_widths := Eval compute in (repeat {{{w}}} {{{kdivw}}} ++ {{{kmodw}}} :: nil). + +Definition wire_digits := Eval compute in (tuple Z (length wire_widths)). + +Lemma wire_widths_nonneg : forall w, In w wire_widths -> 0 <= w. +Proof. + intros. + repeat (destruct H as [|H]; [subst; vm_compute; congruence | ]). + contradiction H. +Qed. + +Lemma bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn wire_widths (length wire_widths). +Proof. + reflexivity. +Qed. + +Lemma modulus_gt_2 : 2 < modulus. Proof. cbv; congruence. Qed. + +(* Temporarily, we'll use addition chains equivalent to double-and-add. This is pending + finding the real, more optimal chains from previous work. *) +Fixpoint pow2Chain'' p (pow2_index acc_index : nat) chain_acc : list (nat * nat) := + match p with + | xI p' => pow2Chain'' p' 1 0 + (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + | xO p' => pow2Chain'' p' 0 (S acc_index) + (chain_acc ++ (pow2_index, pow2_index)::nil) + | xH => (chain_acc ++ (pow2_index, pow2_index) :: (0%nat, S acc_index) :: nil) + end. + +Fixpoint pow2Chain' p index := + match p with + | xI p' => pow2Chain'' p' 0 0 (repeat (0,0)%nat index) + | xO p' => pow2Chain' p' (S index) + | xH => repeat (0,0)%nat index + end. + +Definition pow2_chain p := + match p with + | xH => nil + | _ => pow2Chain' p 0 + end. + +Definition invChain := Eval compute in pow2_chain (Z.to_pos (modulus - 2)). + +Instance inv_ec : ExponentiationChain (modulus - 2). + apply Build_ExponentiationChain with (chain := invChain). + reflexivity. +Defined. + +(* Note : use caution copying square root code to other primes. The (modulus / 8 + 1) chains are + for primes that are 5 mod 8; if the prime is 3 mod 4 then use (modulus / 4 + 1). *) +Definition sqrtChain := Eval compute in pow2_chain (Z.to_pos (modulus / 8 + 1)). + +Instance sqrt_ec : ExponentiationChain (modulus / 8 + 1). + apply Build_ExponentiationChain with (chain := sqrtChain). + reflexivity. +Defined. + +Arguments chain {_ _ _} _. + +(* END precomputation *) + +(* Precompute constants *) +Definition k_ := Eval compute in k. +Definition k_subst : k = k_ := eq_refl k_. + +Definition c_ := Eval compute in c. +Definition c_subst : c = c_ := eq_refl c_. + +Definition one_ := Eval compute in one. +Definition one_subst : one = one_ := eq_refl one_. + +Definition zero_ := Eval compute in zero. +Definition zero_subst : zero = zero_ := eq_refl zero_. + +Definition modulus_digits_ := Eval compute in ModularBaseSystemList.modulus_digits. +Definition modulus_digits_subst : ModularBaseSystemList.modulus_digits = modulus_digits_ := eq_refl modulus_digits_. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb Z.leb ModularBaseSystemListZOperations.neg ModularBaseSystemListZOperations.cmovl ModularBaseSystemListZOperations.cmovne. + +Definition app_n2 {T} (f : wire_digits) (P : wire_digits -> T) : T. +Proof. + cbv [wire_digits] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n2_correct {T} f (P : wire_digits -> T) : app_n2 f P = P f. +Proof. + intros. + cbv [wire_digits] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition app_n {T} (f : fe) (P : fe -> T) : T. +Proof. + cbv [fe] in *. + set (f0 := f). + repeat (let g := fresh "g" in destruct f as [f g]). + apply P. + apply f0. +Defined. + +Definition app_n_correct {T} f (P : fe -> T) : app_n f P = P f. +Proof. + intros. + cbv [fe] in *. + repeat match goal with [p : (_*Z)%type |- _ ] => destruct p end. + reflexivity. +Qed. + +Definition appify2 {T} (op : fe -> fe -> T) (f g : fe) := + app_n f (fun f0 => (app_n g (fun g0 => op f0 g0))). + +Lemma appify2_correct : forall {T} op f g, @appify2 T op f g = op f g. +Proof. + intros. cbv [appify2]. + etransitivity; apply app_n_correct. +Qed. + +Definition uncurry_unop_fe {T} (op : fe -> T) + := Eval compute in Tuple.uncurry (n:=length_fe) op. +Definition curry_unop_fe {T} op : fe -> T + := Eval compute in fun f => app_n f (Tuple.curry (n:=length_fe) op). +Definition uncurry_binop_fe {T} (op : fe -> fe -> T) + := Eval compute in uncurry_unop_fe (fun f => uncurry_unop_fe (op f)). +Definition curry_binop_fe {T} op : fe -> fe -> T + := Eval compute in appify2 (fun f => curry_unop_fe (curry_unop_fe op f)). + +Definition uncurry_unop_wire_digits {T} (op : wire_digits -> T) + := Eval compute in Tuple.uncurry (n:=length wire_widths) op. +Definition curry_unop_wire_digits {T} op : wire_digits -> T + := Eval compute in fun f => app_n2 f (Tuple.curry (n:=length wire_widths) op). + +Definition add_sig (f g : fe) : + { fg : fe | fg = add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj1_sig (add_sig f g). + +Definition add_correct (f g : fe) + : add f g = add_opt f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (add_sig f g). + +Definition carry_add_sig (f g : fe) : + { fg : fe | fg = carry_add_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_add (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj1_sig (carry_add_sig f g). + +Definition carry_add_correct (f g : fe) + : carry_add f g = carry_add_opt f g := + Eval cbv beta iota delta [proj1_sig carry_add_sig] in + proj2_sig (carry_add_sig f g). + +Definition sub_sig (f g : fe) : + { fg : fe | fg = sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + reflexivity. +Defined. + +Definition sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj1_sig (sub_sig f g). + +Definition sub_correct (f g : fe) + : sub f g = sub_opt f g := + Eval cbv beta iota delta [proj1_sig sub_sig] in + proj2_sig (sub_sig f g). + +Definition carry_sub_sig (f g : fe) : + { fg : fe | fg = carry_sub_opt f g}. +Proof. + eexists. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition carry_sub (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj1_sig (carry_sub_sig f g). + +Definition carry_sub_correct (f g : fe) + : carry_sub f g = carry_sub_opt f g := + Eval cbv beta iota delta [proj1_sig carry_sub_sig] in + proj2_sig (carry_sub_sig f g). + +(* For multiplication, we add another layer of definition so that we can + rewrite under the [let] binders. *) +Definition mul_simpl_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. (* N.B. The slow part of this is computing with [Z_div_opt]. + It would be much faster if we could take advantage of + the form of [base_from_limb_widths] when doing + division, so we could do subtraction instead. *) + autorewrite with zsimplify_fast. + reflexivity. +Defined. + +Definition mul_simpl (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_simpl_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (mul_simpl_sig {{{enum f}}} + {{{enum g}}}). + +Definition mul_simpl_correct (f g : fe) + : mul_simpl f g = carry_mul_opt k_ c_ f g. +Proof. + pose proof (proj2_sig (mul_simpl_sig f g)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition mul_sig (f g : fe) : + { fg : fe | fg = carry_mul_opt k_ c_ f g}. +Proof. + eexists. + rewrite <-mul_simpl_correct. + rewrite <-(@appify2_correct fe). + cbv. + autorewrite with zsimplify_fast zsimplify_Z_to_pos; cbv. + autorewrite with zsimplify_Z_to_pos; cbv. + reflexivity. +Defined. + +Definition mul (f g : fe) : fe := + Eval cbv beta iota delta [proj1_sig mul_sig] in + proj1_sig (mul_sig f g). + +Definition mul_correct (f g : fe) + : mul f g = carry_mul_opt k_ c_ f g := + Eval cbv beta iota delta [proj1_sig add_sig] in + proj2_sig (mul_sig f g). + +Definition opp_sig (f : fe) : + { g : fe | g = opp_opt f }. +Proof. + eexists. + cbv [opp_opt]. + rewrite <-sub_correct. + rewrite zero_subst. + cbv [sub]. + reflexivity. +Defined. + +Definition opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig opp_sig] in proj1_sig (opp_sig f). + +Definition opp_correct (f : fe) + : opp f = opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (opp_sig f). + +Definition carry_opp_sig (f : fe) : + { g : fe | g = carry_opp_opt f }. +Proof. + eexists. + cbv [carry_opp_opt]. + rewrite <-carry_sub_correct. + rewrite zero_subst. + cbv [carry_sub]. + reflexivity. +Defined. + +Definition carry_opp (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig carry_opp_sig] in proj1_sig (carry_opp_sig f). + +Definition carry_opp_correct (f : fe) + : carry_opp f = carry_opp_opt f + := Eval cbv beta iota delta [proj2_sig add_sig] in proj2_sig (carry_opp_sig f). + +Definition pow (f : fe) chain := fold_chain_opt one_ mul chain [f]. + +Lemma pow_correct (f : fe) : forall chain, pow f chain = pow_opt k_ c_ one_ f chain. +Proof. + cbv [pow pow_opt]; intros. + rewrite !fold_chain_opt_correct. + apply Proper_fold_chain; try reflexivity. + intros; subst; apply mul_correct. +Qed. + +(* Now that we have [pow], we can compute sqrt of -1 for use + in sqrt function (this is not needed unless the prime is + 5 mod 8) *) +Local Transparent Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition sqrt_m1 := Eval vm_compute in (pow (encode (F.of_Z _ 2)) (pow2_chain (Z.to_pos ((modulus - 1) / 4)))). + +Lemma sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F). +Proof. + cbv [rep]. + apply F.eq_to_Z_iff. + vm_compute. + reflexivity. +Qed. + +Local Opaque Z.shiftr Z.shiftl Z.land Z.mul Z.add Z.sub Z.lor Let_In Z.eqb Z.ltb andb. + +Definition inv_sig (f : fe) : + { g : fe | g = inv_opt k_ c_ one_ f }. +Proof. + eexists; cbv [inv_opt]. + rewrite <-pow_correct. + cbv - [mul]. + reflexivity. +Defined. + +Definition inv (f : fe) : fe + := Eval cbv beta iota delta [proj1_sig inv_sig] in proj1_sig (inv_sig f). + +Definition inv_correct (f : fe) + : inv f = inv_opt k_ c_ one_ f + := Eval cbv beta iota delta [proj2_sig inv_sig] in proj2_sig (inv_sig f). + +Definition mbs_field := modular_base_system_field modulus_gt_2. + +Import Morphisms. + +Local Existing Instance prime_modulus. + +Lemma field_and_homomorphisms + : @field fe eq zero_ one_ opp add sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ add mul encode + /\ @Ring.is_homomorphism + fe eq one_ add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite opp_correct, opp_opt_correct; apply opp_rep; reflexivity. } + { intros; rewrite add_correct, add_opt_correct; apply add_rep; reflexivity. } + { intros; rewrite sub_correct, sub_opt_correct; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition field{{{k}}}{{{c}}} : @field fe eq zero_ one_ opp add sub mul inv div := proj1 field_and_homomorphisms. + +Lemma carry_field_and_homomorphisms + : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div + /\ @Ring.is_homomorphism + (F modulus) Logic.eq F.one F.add F.mul + fe eq one_ carry_add mul encode + /\ @Ring.is_homomorphism + fe eq one_ carry_add mul + (F modulus) Logic.eq F.one F.add F.mul + decode. +Proof. + eapply @Field.field_and_homomorphism_from_redundant_representation. + { exact (F.field_modulo _). } + { apply encode_rep. } + { reflexivity. } + { reflexivity. } + { reflexivity. } + { intros; rewrite carry_opp_correct, carry_opp_opt_correct, carry_opp_rep; apply opp_rep; reflexivity. } + { intros; rewrite carry_add_correct, carry_add_opt_correct, carry_add_rep; apply add_rep; reflexivity. } + { intros; rewrite carry_sub_correct, carry_sub_opt_correct, carry_sub_rep; apply sub_rep; reflexivity. } + { intros; rewrite mul_correct, carry_mul_opt_correct by reflexivity; apply carry_mul_rep; reflexivity. } + { intros; rewrite inv_correct, inv_opt_correct by reflexivity; apply inv_rep; reflexivity. } + { intros; apply encode_rep. } +Qed. + +Definition carry_field : @field fe eq zero_ one_ carry_opp carry_add carry_sub mul inv div := proj1 carry_field_and_homomorphisms. + +Lemma homomorphism_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one add mul encode. +Proof. apply field_and_homomorphisms. Qed. + +Lemma homomorphism_F_decode + : @Ring.is_homomorphism fe eq one add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply field_and_homomorphisms. Qed. + + +Lemma homomorphism_carry_F_encode + : @Ring.is_homomorphism (F modulus) Logic.eq F.one F.add F.mul fe eq one carry_add mul encode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Lemma homomorphism_carry_F_decode + : @Ring.is_homomorphism fe eq one carry_add mul (F modulus) Logic.eq F.one F.add F.mul decode. +Proof. apply carry_field_and_homomorphisms. Qed. + +Definition ge_modulus_sig (f : fe) : + { b : Z | b = ge_modulus_opt (to_list {{{n}}} f) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [ge_modulus_opt]. + rewrite !modulus_digits_subst. + cbv. + reflexivity. +Defined. + +Definition ge_modulus (f : fe) : Z := + Eval cbv beta iota delta [proj1_sig ge_modulus_sig] in + let '{{{enum f}}} := f in + proj1_sig (ge_modulus_sig {{{enum f}}}). + +Definition ge_modulus_correct (f : fe) : + ge_modulus f = ge_modulus_opt (to_list {{{n}}} f). +Proof. + pose proof (proj2_sig (ge_modulus_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition freeze_sig (f : fe) : + { f' : fe | f' = from_list_default 0 {{{n}}} (freeze_opt (int_width := int_width) c_ (to_list {{{n}}} f)) }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists; cbv [freeze_opt int_width]. + cbv [to_list to_list']. + cbv [conditional_subtract_modulus_opt]. + rewrite !modulus_digits_subst. + cbv - [from_list_default]. + (* TODO(jgross,jadep): use Reflective linearization here? *) + repeat ( + set_evars; rewrite app_Let_In_nd; subst_evars; + eapply Proper_Let_In_nd_changebody; [reflexivity|intro]). + cbv [from_list_default from_list_default']. + reflexivity. +Defined. + +Definition freeze (f : fe) : fe := + Eval cbv beta iota delta [proj1_sig freeze_sig] in + let '{{{enum f}}} := f in + proj1_sig (freeze_sig {{{enum f}}}). + +Definition freeze_correct (f : fe) + : freeze f = from_list_default 0 {{{n}}} (freeze_opt (int_width := int_width) c_ (to_list {{{n}}} f)). +Proof. + pose proof (proj2_sig (freeze_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Defined. + +Definition fieldwiseb_sig (f g : fe) : + { b | b = @fieldwiseb Z Z {{{n}}} Z.eqb f g }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv. + reflexivity. +Defined. + +Definition fieldwiseb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig fieldwiseb_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (fieldwiseb_sig {{{enum f}}} + {{{enum g}}}). + +Lemma fieldwiseb_correct (f g : fe) + : fieldwiseb f g = @Tuple.fieldwiseb Z Z {{{n}}} Z.eqb f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (fieldwiseb_sig f' g')). +Qed. + +Definition eqb_sig (f g : fe) : + { b | b = eqb int_width f g }. +Proof. + cbv [eqb]. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [ModularBaseSystem.freeze int_width]. + rewrite <-!from_list_default_eq with (d := 0). + rewrite <-!(freeze_opt_correct c_) by auto using length_to_list. + rewrite <-!freeze_correct. + rewrite <-fieldwiseb_correct. + reflexivity. +Defined. + +Definition eqb (f g : fe) : bool + := Eval cbv beta iota delta [proj1_sig eqb_sig] in + let '{{{enum f}}} := f in + let '{{{enum g}}} := g in + proj1_sig (eqb_sig {{{enum f}}} + {{{enum g}}}). + +Lemma eqb_correct (f g : fe) + : eqb f g = ModularBaseSystem.eqb int_width f g. +Proof. + set (f' := f); set (g' := g). + hnf in f, g; destruct_head' prod. + exact (proj2_sig (eqb_sig f' g')). +Qed. + +Definition sqrt_sig (powx powx_squared f : fe) : + { f' : fe | f' = sqrt_5mod8_opt (int_width := int_width) k_ c_ sqrt_m1 powx powx_squared f}. +Proof. + eexists. + cbv [sqrt_5mod8_opt int_width]. + apply Proper_Let_In_nd_changebody; [reflexivity|intro]. + set_evars. rewrite <-!mul_correct, <-eqb_correct. subst_evars. + reflexivity. +Defined. + +Definition sqrt (powx powx_squared f : fe) : fe + := Eval cbv beta iota delta [proj1_sig sqrt_sig] in proj1_sig (sqrt_sig powx powx_squared f). + +Definition sqrt_correct (powx powx_squared f : fe) + : sqrt powx powx_squared f = sqrt_5mod8_opt k_ c_ sqrt_m1 powx powx_squared f + := Eval cbv beta iota delta [proj2_sig sqrt_sig] in proj2_sig (sqrt_sig powx powx_squared f). + +Definition pack_simpl_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [pack_opt]. + repeat (rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition pack_simpl (f : fe) := + Eval cbv beta iota delta [proj1_sig pack_simpl_sig] in + let '{{{enum f}}} := f in + proj1_sig (pack_simpl_sig {{{enum f}}}). + +Definition pack_simpl_correct (f : fe) + : pack_simpl f = pack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (pack_simpl_sig f)). + cbv [fe] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition pack_sig (f : fe) : + { f' | f' = pack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-pack_simpl_correct. + rewrite <-(@app_n_correct wire_digits). + cbv. + reflexivity. +Defined. + +Definition pack (f : fe) : wire_digits := + Eval cbv beta iota delta [proj1_sig pack_sig] in proj1_sig (pack_sig f). + +Definition pack_correct (f : fe) + : pack f = pack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (pack_sig f). + +Definition unpack_simpl_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + eexists. + cbv [unpack_opt]. + repeat ( + rewrite <-convert'_opt_correct; + cbv - [from_list_default_opt Conversion.convert']). + repeat progress rewrite ?Z.shiftl_0_r, ?Z.shiftr_0_r, ?Z.land_0_l, ?Z.lor_0_l, ?Z.land_same_r. + cbv [from_list_default_opt]. + reflexivity. +Defined. + +Definition unpack_simpl (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_simpl_sig] in + let '{{{enumw f}}} := f in + proj1_sig (unpack_simpl_sig {{{enumw f}}}). + +Definition unpack_simpl_correct (f : wire_digits) + : unpack_simpl f = unpack_opt params wire_widths_nonneg bits_eq f. +Proof. + pose proof (proj2_sig (unpack_simpl_sig f)). + cbv [wire_digits] in *. + repeat match goal with p : (_ * Z)%type |- _ => destruct p end. + assumption. +Qed. + +Definition unpack_sig (f : wire_digits) : + { f' | f' = unpack_opt params wire_widths_nonneg bits_eq f }. +Proof. + eexists. + rewrite <-unpack_simpl_correct. + rewrite <-(@app_n2_correct fe). + cbv. + reflexivity. +Defined. + +Definition unpack (f : wire_digits) : fe := + Eval cbv beta iota delta [proj1_sig unpack_sig] in proj1_sig (unpack_sig f). + +Definition unpack_correct (f : wire_digits) + : unpack f = unpack_opt params wire_widths_nonneg bits_eq f + := Eval cbv beta iota delta [proj2_sig pack_sig] in proj2_sig (unpack_sig f). diff --git a/src/SpecificGen/README.md b/src/SpecificGen/README.md new file mode 100644 index 000000000..165e755d5 --- /dev/null +++ b/src/SpecificGen/README.md @@ -0,0 +1,5 @@ +Usage: + +python fill_template.py 41417_32.json + +(overwrites GF41417_32.v) diff --git a/src/SpecificGen/fill_template.py b/src/SpecificGen/fill_template.py new file mode 100644 index 000000000..172ec1079 --- /dev/null +++ b/src/SpecificGen/fill_template.py @@ -0,0 +1,39 @@ +import os, sys, json + +enum = lambda n, s : "(" + ", ".join([s + str(x) for x in range(n)]) + ")" + +params = open(sys.argv[1]) +replacements = json.load(params) +params.close() +replacements["kmodw"] = replacements["k"] % replacements["w"] +replacements["kdivw"] = int(replacements["k"] / replacements["w"]) +replacements["enum f"] = enum(replacements["n"], "f") +replacements["enum g"] = enum(replacements["n"], "g") +replacements["enumw f"] = enum(replacements["kdivw"] + 1, "f") +replacements = {k : str(v) for k,v in replacements.items()} + +OUT = "GF" + replacements["k"] + replacements["c"] + "_" + replacements["w"] + ".v" + +if len(sys.argv) > 2: + OUT = sys.argv[2] + +if int(replacements["c"]) % 8 == 1: + TEMPLATE = "GFtemplate3mod4" +else: + TEMPLATE = "GFtemplate5mod8" + +BEGIN_FIELD = "{{{" +END_FIELD = "}}}" +field = lambda s : BEGIN_FIELD + s + END_FIELD + +inp = open(TEMPLATE) +out = open(OUT, "w+") + +for line in inp: + new_line = line + for w in replacements: + new_line = new_line.replace(field(w), replacements[w]) + out.write(new_line) + +inp.close() +out.close() -- cgit v1.2.3