diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-16 17:02:30 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-06-16 19:41:10 -0400 |
commit | bf86eb3bb543191bb75784767f39c6d2253c5bac (patch) | |
tree | 95e4603c1a1f2804910285c16fb24b016941e976 /src/Arithmetic/MontgomeryReduction | |
parent | c2cfdbede87ffb0489384fe41365961fbd4d1df8 (diff) |
Switch to using tuples for word-by-word Montgomery
The new parameterized definitions and proofs are in
WordByWord/Abstract/Dependent/*; the old ones are untouched (and unused)
in WordByWord/Abstract/*. I replaced definitions I didn't know how to
write in the Saturated API with the use of an axiom.
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
4 files changed, 543 insertions, 69 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v new file mode 100644 index 000000000..71c8ea117 --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v @@ -0,0 +1,67 @@ +(*** Word-By-Word Montgomery Multiplication *) +(** This file implements Montgomery Form, Montgomery Reduction, and + Montgomery Multiplication on an abstract [T : ℕ → Type]. See + https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion + of the algorithm; note that it may be that none of the algorithms + there exactly match what we're doing here. *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. + +Local Open Scope Z_scope. + +Section WordByWordMontgomery. + Local Coercion Z.pos : positive >-> Z. + Context + {T : nat -> Type} + {eval : forall {n}, T n -> Z} + {zero : forall {n}, T n} + {divmod : forall {n}, T (S n) -> T n * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {R : positive} + {R_numlimbs : nat} + {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *) + {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *) + {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) + (N : T (S R_numlimbs)). + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. + Context (pred_A_numlimbs : nat) + (B : T R_numlimbs) (k : Z) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + Local Definition A_a := dlet p := divmod _ A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. + Local Definition S1 := add _ S (scmul _ a B). + Local Definition s := snd (divmod _ S1). + Local Definition q := s * k mod r. 
+ Local Definition S2 := add _ S1 (scmul _ q N). + Local Definition S3 := fst (divmod _ S2). + Local Definition S4 := drop_high S3. + End Iteration. + + Section loop. + Context (A_numlimbs : nat) + (A : T A_numlimbs) + (B : T R_numlimbs) + (k : Z) + (S' : T (S R_numlimbs)). + + Definition redc_body {pred_A_numlimbs} : T (S pred_A_numlimbs) * T (S R_numlimbs) + -> T pred_A_numlimbs * T (S R_numlimbs) + := fun '(A, S') => (A' _ A, S4 _ B k A S'). + + Fixpoint redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs) + := match count return T count * _ -> _ with + | O => fun A_S => A_S + | S count' => fun A_S => redc_loop count' (redc_body A_S) + end. + + Definition redc : T (S R_numlimbs) + := snd (redc_loop A_numlimbs (A, zero (1 + R_numlimbs))). + End loop. +End WordByWordMontgomery. + +Create HintDb word_by_word_montgomery. +Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v new file mode 100644 index 000000000..e09add277 --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v @@ -0,0 +1,411 @@ +(*** Word-By-Word Montgomery Multiplication Proofs *) +Require Import Coq.Arith.Arith. +Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.NatUtil. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.Tactics.SetEvars. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Local Open Scope Z_scope. 
+ +Section WordByWordMontgomery. + Context + {T : nat -> Type} + {eval : forall {n}, T n -> Z} + {zero : forall {n}, T n} + {divmod : forall {n}, T (S n) -> T n * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {r_big : r > 1} + {R : positive} + {R_numlimbs : nat} + {R_correct : R = r^Z.of_nat R_numlimbs :> Z} + {small : forall {n}, T n -> Prop} + {eval_zero : forall n, eval (@zero n) = 0} + {eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r} + {eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r} + {small_div : forall n v, small v -> small (fst (@divmod n v))} + {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *) + {eval_scmul: forall n a v, eval (@scmul n a v) = a * eval v} + {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *) + {eval_add : forall n a b, eval (@add n a b) = eval a + eval b} + {small_add : forall n a b, small (@add n a b)} + {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) + {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} + (N : T (S R_numlimbs)) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) + (N_lt_R : eval N < R) + (B : T R_numlimbs) + (B_bounds : 0 <= eval B < R) + ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) + (k : Z) (k_correct : k * eval N mod r = -1). + + Create HintDb push_eval discriminated. + Local Ltac t_small := + repeat first [ assumption + | apply small_add + | apply small_div + | apply Z_mod_lt + | solve [ auto ] + | lia + | progress autorewrite with push_eval ]. + Hint Rewrite + eval_zero + eval_div + eval_mod + eval_add + eval_scmul + eval_drop_high + using (repeat autounfold with word_by_word_montgomery; t_small) + : push_eval. + + Local Arguments eval {_} _. + Local Arguments small {_} _. + Local Arguments divmod {_} _. + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. 
+ Context (pred_A_numlimbs : nat) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)) + (small_A : small A) + (S_nonneg : 0 <= eval S). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + + Local Coercion eval : T >-> Z. + + Local Notation a := (@WordByWord.Abstract.Dependent.Definition.a T (@divmod) pred_A_numlimbs A). + Local Notation A' := (@WordByWord.Abstract.Dependent.Definition.A' T (@divmod) pred_A_numlimbs A). + Local Notation S1 := (@WordByWord.Abstract.Dependent.Definition.S1 T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S). + Local Notation s := (@WordByWord.Abstract.Dependent.Definition.s T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S). + Local Notation q := (@WordByWord.Abstract.Dependent.Definition.q T (@divmod) r R_numlimbs scmul add pred_A_numlimbs B k A S). + Local Notation S2 := (@WordByWord.Abstract.Dependent.Definition.S2 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S). + Local Notation S3 := (@WordByWord.Abstract.Dependent.Definition.S3 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S). + Local Notation S4 := (@WordByWord.Abstract.Dependent.Definition.S4 T (@divmod) r R_numlimbs scmul add drop_high N pred_A_numlimbs B k A S). + + Lemma S3_bound + : eval S < eval N + eval B + -> eval S3 < eval N + eval B. + Proof. + assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) + by (intros x y; pose proof (Z_mod_lt x y); omega). + intro HS. + unfold S3, S2, S1. + autorewrite with push_eval; []. + eapply Z.le_lt_trans. + { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); + [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. + Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; + repeat autounfold with word_by_word_montgomery; + autorewrite with push_eval; + try Z.zero_bounds; + auto with lia. } + rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. + autorewrite with zsimplify. + simpl; omega. 
+ Qed. + + Lemma small_A' + : small A'. + Proof. + repeat autounfold with word_by_word_montgomery; auto. + Qed. + + Lemma small_S3 + : small S3. + Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. + + Lemma S3_nonneg : 0 <= eval S3. + Proof. + repeat autounfold with word_by_word_montgomery; + autorewrite with push_eval; []. + rewrite ?Npos_correct; Z.zero_bounds; lia. + Qed. + + Lemma S4_nonneg : 0 <= eval S4. + Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed. + + Lemma S4_bound + : eval S < eval N + eval B + -> eval S4 < eval N + eval B. + Proof. + intro H; pose proof (S3_bound H); pose proof S3_nonneg. + unfold S4. + rewrite eval_drop_high by apply small_S3. + rewrite Z.mod_small by nia. + assumption. + Qed. + + Lemma S1_eq : eval S1 = S + a*B. + Proof. + cbv [S1 a A']. + repeat autorewrite with push_eval. + reflexivity. + Qed. + + Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N. + Proof. + cbv [S2]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. + Qed. + + Lemma S2_mod_r : S2 mod r = 0. + Proof. + cbv [S2 q s]; autorewrite with push_eval. + assert (r > 0) by lia. + assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). + { destruct (Z.eq_dec r 1) as [H'|H']. + { rewrite H'; split; reflexivity. } + { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } + autorewrite with pull_Zmod. + replace 0 with (0 mod r) by apply Zmod_0_l. + eapply F.eq_of_Z_iff. + repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. + rewrite <-Algebra.Hierarchy.associative. + replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). + { cbv [F.of_Z F.add]; simpl. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. + simpl. + rewrite (proj1 Hr), Z.mul_sub_distr_l. + push_Zmod; pull_Zmod. + autorewrite with zsimplify; reflexivity. } + { rewrite <- F.of_Z_mul. + rewrite F.of_Z_mod. + rewrite k_correct. 
+ cbv [F.of_Z F.add F.opp F.one]; simpl. + change (-(1)) with (-1) in *. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. + rewrite (proj1 Hr), (proj2 Hr); reflexivity. } + Qed. + + Lemma S3_mod_N + : S3 mod N = (S + a*B)*ri mod N. + Proof. + cbv [S3]; autorewrite with push_eval cancel_pair. + pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; + cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. + etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| + rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. + rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval. + autorewrite with push_Zmod; + rewrite S2_mod_r; + autorewrite with zsimplify. + reflexivity. + Qed. + + Lemma S4_mod_N + (Hbound : eval S < eval N + eval B) + : S4 mod N = (S + a*B)*ri mod N. + Proof. + pose proof (S3_bound Hbound); pose proof S3_nonneg. + unfold S4; autorewrite with push_eval. + rewrite (Z.mod_small _ (r * _)) by nia. + apply S3_mod_N. + Qed. + End Iteration. + + Local Notation redc_body := (@redc_body T (@divmod) r R_numlimbs scmul add drop_high N B k). + Local Notation redc_loop := (@redc_loop T (@divmod) r R_numlimbs scmul add drop_high N B k). + Local Notation redc A := (@redc T zero (@divmod) r R_numlimbs scmul add drop_high N _ A B k). + + (*Lemma redc_loop_comm_body count + : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S). + Proof. + induction count as [|count IHcount]; try reflexivity. + simpl; intro; rewrite IHcount; reflexivity. + Qed.*) + + Section body. + Context (pred_A_numlimbs : nat) + (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)). + Let A:=fst A_S. + Let S:=snd A_S. + Let A_a:=divmod A. + Let a:=snd A_a. + Context (small_A : small A) + (S_bound : 0 <= eval S < eval N + eval B). + + Lemma small_fst_redc_body : small (fst (redc_body A_S)). + Proof. destruct A_S; apply small_A'; assumption. Qed. + Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). + Proof. 
destruct A_S; apply S4_nonneg; assumption. Qed. + + Lemma snd_redc_body_mod_N + : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). + Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed. + + Lemma fst_redc_body + : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. + Proof. + destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma fst_redc_body_mod_N + : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). + Proof. + rewrite fst_redc_body. + etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. + unfold a, A_a, A. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma redc_body_bound + : eval S < eval N + eval B + -> eval (snd (redc_body A_S)) < eval N + eval B. + Proof. + destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. + Qed. + End body. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.of_nat !_. + Local Ltac induction_loop count IHcount + := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | (*rewrite redc_loop_comm_body in * *) ]. + Lemma redc_loop_good count A_S + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : small (fst (redc_loop count A_S)) + /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. + induction_loop count IHcount; auto; []. + change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). + destruct_head'_and. + repeat first [ apply conj + | apply small_fst_redc_body + | apply redc_body_bound + | apply snd_redc_body_nonneg + | apply IHcount + | solve [ auto ] ]. + Qed. + + Lemma redc_loop_bound count A_S + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. apply redc_loop_good; assumption. Qed. 
+ + Local Ltac handle_IH_small := + repeat first [ apply redc_loop_good + | apply small_fst_redc_body + | apply redc_body_bound + | apply snd_redc_body_nonneg + | apply conj + | progress destruct_head' and + | solve [ auto ] ]. + + Lemma fst_redc_loop count A_S + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). + Proof. + induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { rewrite IHcount, fst_redc_body by handle_IH_small. + change (1 + R_numlimbs)%nat with (S R_numlimbs) in *. + rewrite Zdiv_Zdiv by Z.zero_bounds. + rewrite <- (Z.pow_1_r r) at 1. + rewrite <- Z.pow_add_r by lia. + replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia. + reflexivity. } + Qed. + + Lemma fst_redc_loop_mod_N count A_S + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) mod (eval N) + = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) + * ri^(Z.of_nat count) mod (eval N). + Proof. + rewrite fst_redc_loop by assumption. + destruct count. + { simpl; autorewrite with zsimplify; reflexivity. } + { etransitivity; + [ eapply Z.div_to_inv_modulo; + try solve [ eassumption + | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] + | ]. + { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. + { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } + { lia. } } + reflexivity. } + Qed. + + Local Arguments Z.pow : simpl never. + Lemma snd_redc_loop_mod_N count A_S + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : (eval (snd (redc_loop count A_S))) mod (eval N) + = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). + Proof. + induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { rewrite IHcount by handle_IH_small. 
+ push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod. + autorewrite with push_eval; []. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + destruct A_S as [A S]. + cbn [fst snd]. + change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). + rewrite !Z.mul_add_distr_r. + rewrite <- !Z.mul_assoc. + replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count))) + by (change (Datatypes.S count) with (1 + count)%nat; + autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). + rewrite <- !Z.add_assoc. + apply Z.add_mod_Proper; [ reflexivity | ]. + unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). + rewrite Z.mod_pull_div by auto with zarith lia. + push_Zmod. + erewrite Z.div_to_inv_modulo; + [ + | apply Z.lt_gt; lia + | eassumption ]. + pull_Zmod. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia + | rewrite (Z.mul_comm _ ri) + | rewrite (Z.mul_assoc _ ri _) + | rewrite (Z.mul_comm _ (ri^_)) + | rewrite (Z.mul_assoc _ (ri^_) _) ]. + repeat first [ rewrite <- Z.mul_assoc + | rewrite <- Z.mul_add_distr_l + | rewrite (Z.mul_comm _ (eval B)) + | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; + rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) + | rewrite Zplus_minus + | rewrite (Z.mul_comm r (r^_)) + | reflexivity ]. } + Qed. + + Lemma redc_bound A_numlimbs (A : T A_numlimbs) + (small_A : small A) + : 0 <= eval (redc A) < eval N + eval B. + Proof. + unfold redc. + apply redc_loop_good; simpl; autorewrite with push_eval; + rewrite ?Npos_correct; auto; lia. + Qed. 
+ + Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) + : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). + Proof. + unfold redc. + rewrite snd_redc_loop_mod_N; cbn [fst snd]; + autorewrite with push_eval zsimplify; + [ | rewrite ?Npos_correct; auto; lia.. ]. + Z.rewrite_mod_small. + reflexivity. + Qed. +End WordByWordMontgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index c2048889d..55724b940 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -1,12 +1,12 @@ (*** Word-By-Word Montgomery Multiplication *) (** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [list ℤ]. See + Montgomery Multiplication on an abstract [ℤⁿ]. See https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion of the algorithm; note that it may be that none of the algorithms there exactly match what we're doing here. *) Require Import Coq.ZArith.ZArith. Require Import Crypto.Arithmetic.Saturated. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. Require Import Crypto.Util.Notations. Require Import Crypto.Util.LetIn. @@ -18,39 +18,54 @@ Section WordByWordMontgomery. Context {r : positive} {R_numlimbs : nat} - (N : T). + (N' : T R_numlimbs). + (** TODO(andreser): Add a comment here about why we take in [N : T + R_numlimbs]; we need [N : T (S R_numlimbs)] so that the limb + arithmetic works out exactly (so that we can add [q * N] of + length [S (S R_numlimbs)] to [S1] of length [S (S R_numlimbs)] + and then do [divmod] to get something of length [S + R_numlimbs]. *) + Local Notation N := (join0 N'). 
- Definition redc_body_no_cps (B : T) (k : Z) (A_S : T * T) : T * T - := @redc_body T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k A_S. - Definition redc_loop_no_cps (B : T) (k : Z) (count : nat) (A_S : T * T) : T * T - := @redc_loop T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S. - Definition redc_no_cps (A B : T) (k : Z) : T - := @redc T numlimbs zero divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N A B k. + Definition redc_body_no_cps (B : T R_numlimbs) (k : Z) {pred_A_numlimbs} (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)) + : T pred_A_numlimbs * T (S R_numlimbs) + := @redc_body T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k _ A_S. + Definition redc_loop_no_cps (B : T R_numlimbs) (k : Z) (count : nat) (A_S : T count * T (S R_numlimbs)) + : T 0 * T (S R_numlimbs) + := @redc_loop T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S. + Definition redc_no_cps {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs) + := @redc T (@zero) (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N _ A B k. 
- Definition redc_body_cps (A B : T) (k : Z) (S' : T) {cpsT} (rest : T * T -> cpsT) : cpsT + Definition redc_body_cps {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs)) + {cpsT} (rest : T pred_A_numlimbs * T (S R_numlimbs) -> cpsT) + : cpsT := divmod_cps A (fun '(A, a) => - @scmul_cps r a B _ (fun aB => @add_cps r S' aB _ (fun S1 => + @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 => divmod_cps S1 (fun '(_, s) => dlet q := s * k mod r in - @scmul_cps r q N _ (fun qN => @add_cps r S1 qN _ (fun S2 => + @scmul_cps r _ q N _ (fun qN => @add_cps r _ S1 qN _ (fun S2 => divmod_cps S2 (fun '(S3, _) => @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))). Section loop. - Context (A B : T) (k : Z) {cpsT : Type}. - Fixpoint redc_loop_cps (count : nat) (rest : T * T -> cpsT) : T * T -> cpsT + Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}. + Fixpoint redc_loop_cps (count : nat) (rest : T 0 * T (S R_numlimbs) -> cpsT) : T count * T (S R_numlimbs) -> cpsT := match count with | O => rest | S count' => fun '(A, S') => redc_body_cps A B k S' (redc_loop_cps count' rest) end. - Definition redc_cps (rest : T -> cpsT) : cpsT - := redc_loop_cps (numlimbs A) (fun '(A, S') => rest S') (A, zero (1 + numlimbs B)). + Definition redc_cps (rest : T (S R_numlimbs) -> cpsT) : cpsT + := redc_loop_cps A_numlimbs (fun '(A, S') => rest S') (A, zero). End loop. - Definition redc_body (A B : T) (k : Z) (S' : T) : T * T := redc_body_cps A B k S' id. - Definition redc_loop (B : T) (k : Z) (count : nat) : T * T -> T * T := redc_loop_cps B k count id. - Definition redc (A B : T) (k : Z) : T := redc_cps A B k id. + Definition redc_body {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs)) + : T pred_A_numlimbs * T (S R_numlimbs) + := redc_body_cps A B k S' id. 
+ Definition redc_loop (B : T R_numlimbs) (k : Z) (count : nat) : T count * T (S R_numlimbs) -> T 0 * T (S R_numlimbs) + := redc_loop_cps B k count id. + Definition redc {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs) + := redc_cps A B k id. End WordByWordMontgomery. Hint Opaque redc redc_body redc_loop : uncps. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index 96646c61c..6dfd6a10a 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -2,8 +2,8 @@ Require Import Coq.ZArith.BinInt. Require Import Coq.micromega.Lia. Require Import Crypto.Arithmetic.Saturated. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. -Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Proofs. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Proofs. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Tactics.BreakMatch. @@ -19,25 +19,24 @@ Section WordByWordMontgomery. Local Notation add := (@add (Z.pos r)). Local Notation scmul := (@scmul (Z.pos r)). Local Notation eval_zero := (@eval_zero (Z.pos r)). + Local Notation eval_join0 := (@eval_zero (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation eval_div := (@eval_div (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation eval_mod := (@eval_mod (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation small_div := (@small_div (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation numlimbs_div := (@numlimbs_div (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation eval_scmul := (@eval_scmul (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation numlimbs_scmul := (@numlimbs_scmul (Z.pos r) (Zorder.Zgt_pos_0 _)). 
Local Notation eval_add := (@eval_add (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation small_add := (@small_add (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation numlimbs_add := (@numlimbs_add (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation drop_high := (@drop_high (S R_numlimbs)). - Local Notation numlimbs_drop_high := (@numlimbs_drop_high (Z.pos r) (Zorder.Zgt_pos_0 _) (S R_numlimbs)). - Context (N A B : T) - (k : Z) - ri + Context (A_numlimbs : nat) + (N' : T R_numlimbs) + (A : T A_numlimbs) + (B : T R_numlimbs) + (k : Z). + Local Notation N := (join0 N'). + Context ri (r_big : r > 1) (small_A : small A) - (Hnumlimbs_le : (R_numlimbs <= numlimbs B)%nat) - (Hnumlimbs_eq : R_numlimbs = numlimbs B) - (A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat (numlimbs A)) + (A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat A_numlimbs) (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) (N_bound : 0 < eval N < r^Z.of_nat R_numlimbs) (B_bound' : 0 <= eval B < r^Z.of_nat R_numlimbs) @@ -71,33 +70,24 @@ Section WordByWordMontgomery. rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity. Qed. - Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N). - Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N). - Local Notation redc_body := (@redc_body r R_numlimbs N). - Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k). - Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k). - Local Notation redc_loop := (@redc_loop r R_numlimbs N B k). - Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A B k). - Local Notation redc_cps := (@redc_cps r R_numlimbs N A B k). - Local Notation redc := (@redc r R_numlimbs N A B k). + Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N'). + Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N'). + Local Notation redc_body := (@redc_body r R_numlimbs N'). + Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N' B k). 
+ Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N' B k). + Local Notation redc_loop := (@redc_loop r R_numlimbs N' B k). + Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N' A_numlimbs A B k). + Local Notation redc_cps := (@redc_cps r R_numlimbs N' A_numlimbs A B k). + Local Notation redc := (@redc r R_numlimbs N' A_numlimbs A B k). Definition redc_no_cps_bound : 0 <= eval redc_no_cps < eval N + eval B - := @redc_bound T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A. - Definition numlimbs_redc_no_cps_gen - : numlimbs redc_no_cps - = match numlimbs A with - | O => S (numlimbs B) - | _ => S R_numlimbs - end - := @numlimbs_redc_gen T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_le. - Definition numlimbs_redc_no_cps : numlimbs redc_no_cps = S (numlimbs B) - := @numlimbs_redc T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_eq. + := @redc_bound T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A_numlimbs A small_A. 
Definition redc_no_cps_mod_N - : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N) - := @redc_mod_N T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A small_A A_bound. + : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N) + := @redc_mod_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A_numlimbs A small_A A_bound. - Lemma redc_body_cps_id (A' S' : T) {cpsT} f - : @redc_body_cps A' B k S' cpsT f = f (redc_body A' B k S'). + Lemma redc_body_cps_id pred_A_numlimbs (A' : T (S pred_A_numlimbs)) (S' : T (S R_numlimbs)) {cpsT} f + : @redc_body_cps pred_A_numlimbs A' B k S' cpsT f = f (redc_body A' B k S'). Proof. unfold redc_body, redc_body_cps, LetIn.Let_In. repeat first [ reflexivity @@ -105,7 +95,7 @@ Section WordByWordMontgomery. | progress autorewrite with uncps ]. Qed. - Lemma redc_loop_cps_id (count : nat) (A_S : T * T) {cpsT} f + Lemma redc_loop_cps_id (count : nat) (A_S : T count * T (S R_numlimbs)) {cpsT} f : @redc_loop_cps cpsT count f A_S = f (redc_loop count A_S). Proof. unfold redc_loop. @@ -121,10 +111,10 @@ Section WordByWordMontgomery. etransitivity; rewrite redc_loop_cps_id; [ | reflexivity ]; break_innermost_match; reflexivity. Qed. - Lemma redc_body_id_no_cps A' S' - : redc_body A' B k S' = redc_body_no_cps B k (A', S'). + Lemma redc_body_id_no_cps pred_A_numlimbs A' S' + : @redc_body pred_A_numlimbs A' B k S' = redc_body_no_cps B k (A', S'). Proof. - unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Definition.redc_body, LetIn.Let_In, id. 
+ unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Dependent.Definition.redc_body, LetIn.Let_In, id. repeat autounfold with word_by_word_montgomery. repeat first [ reflexivity | progress cbn [fst snd id] @@ -144,24 +134,15 @@ Section WordByWordMontgomery. Qed. Lemma redc_cps_id_no_cps : redc = redc_no_cps. Proof. - unfold redc, redc_no_cps, redc_cps, Abstract.Definition.redc. + unfold redc, redc_no_cps, redc_cps, Abstract.Dependent.Definition.redc. rewrite redc_loop_cps_id, (surjective_pairing (redc_loop _ _)). rewrite redc_loop_cps_id_no_cps; reflexivity. Qed. Lemma redc_bound : 0 <= eval redc < eval N + eval B. Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound. Qed. - Lemma numlimbs_redc_gen - : numlimbs redc - = match numlimbs A with - | O => S (numlimbs B) - | _ => S R_numlimbs - end. - Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps_gen. Qed. - Lemma numlimbs_redc : numlimbs redc = S (numlimbs B). - Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps. Qed. Lemma redc_mod_N - : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). + : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_mod_N. Qed. End WordByWordMontgomery. |