aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/MontgomeryReduction
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-06-16 17:02:30 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2017-06-16 19:41:10 -0400
commitbf86eb3bb543191bb75784767f39c6d2253c5bac (patch)
tree95e4603c1a1f2804910285c16fb24b016941e976 /src/Arithmetic/MontgomeryReduction
parentc2cfdbede87ffb0489384fe41365961fbd4d1df8 (diff)
Switch to using tuples for word-by-word montgomery
The new parameterized definitions and proofs are in WordByWord/Abstract/Dependent/*; the old ones are untouched (and unused) in WordByWord/Abstract/*. I replaced definitions I didn't know how to write in the Saturated API with the use of an axiom.
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v67
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v411
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v53
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v81
4 files changed, 543 insertions, 69 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v
new file mode 100644
index 000000000..71c8ea117
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v
@@ -0,0 +1,67 @@
+(*** Word-By-Word Montgomery Multiplication *)
+(** This file implements Montgomery Form, Montgomery Reduction, and
+ Montgomery Multiplication on an abstract [T : ℕ → Type]. See
+ https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion
+ of the algorithm; note that it may be that none of the algorithms
+ there exactly match what we're doing here. *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.LetIn.
+
+Local Open Scope Z_scope.
+
+Section WordByWordMontgomery.
+ Local Coercion Z.pos : positive >-> Z.
+ Context
+ {T : nat -> Type}
+ {eval : forall {n}, T n -> Z}
+ {zero : forall {n}, T n}
+ {divmod : forall {n}, T (S n) -> T n * Z} (* returns all-but-lowest-limb and lowest limb, in that order *)
+ {r : positive}
+ {R : positive}
+ {R_numlimbs : nat}
+ {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *)
+ {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *)
+ {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *)
+ (N : T (S R_numlimbs)).
+
+ (* Recurse for as many iterations as A has limbs, varying A := A, S := 0, r, bounds *)
+ Section Iteration.
+ Context (pred_A_numlimbs : nat)
+ (B : T R_numlimbs) (k : Z)
+ (A : T (S pred_A_numlimbs))
+ (S : T (S R_numlimbs)).
+ (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *)
+ Local Definition A_a := dlet p := divmod _ A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a.
+ Local Definition S1 := add _ S (scmul _ a B).
+ Local Definition s := snd (divmod _ S1).
+ Local Definition q := s * k mod r.
+ Local Definition S2 := add _ S1 (scmul _ q N).
+ Local Definition S3 := fst (divmod _ S2).
+ Local Definition S4 := drop_high S3.
+ End Iteration.
+
+ Section loop.
+ Context (A_numlimbs : nat)
+ (A : T A_numlimbs)
+ (B : T R_numlimbs)
+ (k : Z)
+ (S' : T (S R_numlimbs)).
+
+ Definition redc_body {pred_A_numlimbs} : T (S pred_A_numlimbs) * T (S R_numlimbs)
+ -> T pred_A_numlimbs * T (S R_numlimbs)
+ := fun '(A, S') => (A' _ A, S4 _ B k A S').
+
+ Fixpoint redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs)
+ := match count return T count * _ -> _ with
+ | O => fun A_S => A_S
+ | S count' => fun A_S => redc_loop count' (redc_body A_S)
+ end.
+
+ Definition redc : T (S R_numlimbs)
+ := snd (redc_loop A_numlimbs (A, zero (1 + R_numlimbs))).
+ End loop.
+End WordByWordMontgomery.
+
+Create HintDb word_by_word_montgomery.
+Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery.
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v
new file mode 100644
index 000000000..e09add277
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v
@@ -0,0 +1,411 @@
+(*** Word-By-Word Montgomery Multiplication Proofs *)
+Require Import Coq.Arith.Arith.
+Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia.
+Require Import Crypto.Util.LetIn.
+Require Import Crypto.Util.Prod.
+Require Import Crypto.Util.NatUtil.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic.
+Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition.
+Require Import Crypto.Algebra.Ring.
+Require Import Crypto.Util.Sigma.
+Require Import Crypto.Util.Tactics.SetEvars.
+Require Import Crypto.Util.Tactics.SubstEvars.
+Require Import Crypto.Util.Tactics.DestructHead.
+Local Open Scope Z_scope.
+
+Section WordByWordMontgomery.
+ Context
+ {T : nat -> Type}
+ {eval : forall {n}, T n -> Z}
+ {zero : forall {n}, T n}
+ {divmod : forall {n}, T (S n) -> T n * Z} (* returns all-but-lowest-limb and lowest limb, in that order *)
+ {r : positive}
+ {r_big : r > 1}
+ {R : positive}
+ {R_numlimbs : nat}
+ {R_correct : R = r^Z.of_nat R_numlimbs :> Z}
+ {small : forall {n}, T n -> Prop}
+ {eval_zero : forall n, eval (@zero n) = 0}
+ {eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r}
+ {eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r}
+ {small_div : forall n v, small v -> small (fst (@divmod n v))}
+ {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *)
+ {eval_scmul: forall n a v, eval (@scmul n a v) = a * eval v}
+ {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *)
+ {eval_add : forall n a b, eval (@add n a b) = eval a + eval b}
+ {small_add : forall n a b, small (@add n a b)}
+ {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *)
+ {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)}
+ (N : T (S R_numlimbs)) (Npos : positive) (Npos_correct: eval N = Z.pos Npos)
+ (N_lt_R : eval N < R)
+ (B : T R_numlimbs)
+ (B_bounds : 0 <= eval B < R)
+ ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N))
+ (k : Z) (k_correct : k * eval N mod r = -1).
+
+ Create HintDb push_eval discriminated.
+ Local Ltac t_small :=
+ repeat first [ assumption
+ | apply small_add
+ | apply small_div
+ | apply Z_mod_lt
+ | solve [ auto ]
+ | lia
+ | progress autorewrite with push_eval ].
+ Hint Rewrite
+ eval_zero
+ eval_div
+ eval_mod
+ eval_add
+ eval_scmul
+ eval_drop_high
+ using (repeat autounfold with word_by_word_montgomery; t_small)
+ : push_eval.
+
+ Local Arguments eval {_} _.
+ Local Arguments small {_} _.
+ Local Arguments divmod {_} _.
+
+ (* Recurse for as many iterations as A has limbs, varying A := A, S := 0, r, bounds *)
+ Section Iteration.
+ Context (pred_A_numlimbs : nat)
+ (A : T (S pred_A_numlimbs))
+ (S : T (S R_numlimbs))
+ (small_A : small A)
+ (S_nonneg : 0 <= eval S).
+ (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *)
+
+ Local Coercion eval : T >-> Z.
+
+ Local Notation a := (@WordByWord.Abstract.Dependent.Definition.a T (@divmod) pred_A_numlimbs A).
+ Local Notation A' := (@WordByWord.Abstract.Dependent.Definition.A' T (@divmod) pred_A_numlimbs A).
+ Local Notation S1 := (@WordByWord.Abstract.Dependent.Definition.S1 T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S).
+ Local Notation s := (@WordByWord.Abstract.Dependent.Definition.s T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S).
+ Local Notation q := (@WordByWord.Abstract.Dependent.Definition.q T (@divmod) r R_numlimbs scmul add pred_A_numlimbs B k A S).
+ Local Notation S2 := (@WordByWord.Abstract.Dependent.Definition.S2 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S).
+ Local Notation S3 := (@WordByWord.Abstract.Dependent.Definition.S3 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S).
+ Local Notation S4 := (@WordByWord.Abstract.Dependent.Definition.S4 T (@divmod) r R_numlimbs scmul add drop_high N pred_A_numlimbs B k A S).
+
+ Lemma S3_bound
+ : eval S < eval N + eval B
+ -> eval S3 < eval N + eval B.
+ Proof.
+ assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1)
+ by (intros x y; pose proof (Z_mod_lt x y); omega).
+ intro HS.
+ unfold S3, S2, S1.
+ autorewrite with push_eval; [].
+ eapply Z.le_lt_trans.
+ { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r);
+ [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ].
+ Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia;
+ repeat autounfold with word_by_word_montgomery;
+ autorewrite with push_eval;
+ try Z.zero_bounds;
+ auto with lia. }
+ rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia.
+ autorewrite with zsimplify.
+ simpl; omega.
+ Qed.
+
+ Lemma small_A'
+ : small A'.
+ Proof.
+ repeat autounfold with word_by_word_montgomery; auto.
+ Qed.
+
+ Lemma small_S3
+ : small S3.
+ Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed.
+
+ Lemma S3_nonneg : 0 <= eval S3.
+ Proof.
+ repeat autounfold with word_by_word_montgomery;
+ autorewrite with push_eval; [].
+ rewrite ?Npos_correct; Z.zero_bounds; lia.
+ Qed.
+
+ Lemma S4_nonneg : 0 <= eval S4.
+ Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed.
+
+ Lemma S4_bound
+ : eval S < eval N + eval B
+ -> eval S4 < eval N + eval B.
+ Proof.
+ intro H; pose proof (S3_bound H); pose proof S3_nonneg.
+ unfold S4.
+ rewrite eval_drop_high by apply small_S3.
+ rewrite Z.mod_small by nia.
+ assumption.
+ Qed.
+
+ Lemma S1_eq : eval S1 = S + a*B.
+ Proof.
+ cbv [S1 a A'].
+ repeat autorewrite with push_eval.
+ reflexivity.
+ Qed.
+
+ Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N.
+ Proof.
+ cbv [S2]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity.
+ Qed.
+
+ Lemma S2_mod_r : S2 mod r = 0.
+ Proof.
+ cbv [S2 q s]; autorewrite with push_eval.
+ assert (r > 0) by lia.
+ assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1).
+ { destruct (Z.eq_dec r 1) as [H'|H'].
+ { rewrite H'; split; reflexivity. }
+ { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } }
+ autorewrite with pull_Zmod.
+ replace 0 with (0 mod r) by apply Zmod_0_l.
+ eapply F.eq_of_Z_iff.
+ repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod.
+ rewrite <-Algebra.Hierarchy.associative.
+ replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one).
+ { cbv [F.of_Z F.add]; simpl.
+ apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ].
+ simpl.
+ rewrite (proj1 Hr), Z.mul_sub_distr_l.
+ push_Zmod; pull_Zmod.
+ autorewrite with zsimplify; reflexivity. }
+ { rewrite <- F.of_Z_mul.
+ rewrite F.of_Z_mod.
+ rewrite k_correct.
+ cbv [F.of_Z F.add F.opp F.one]; simpl.
+ change (-(1)) with (-1) in *.
+ apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl.
+ rewrite (proj1 Hr), (proj2 Hr); reflexivity. }
+ Qed.
+
+ Lemma S3_mod_N
+ : S3 mod N = (S + a*B)*ri mod N.
+ Proof.
+ cbv [S3]; autorewrite with push_eval cancel_pair.
+ pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH;
+ cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH.
+ etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)|
+ rewrite (fun a => Z.mul_mod_l a ri N); reflexivity].
+ rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval.
+ autorewrite with push_Zmod;
+ rewrite S2_mod_r;
+ autorewrite with zsimplify.
+ reflexivity.
+ Qed.
+
+ Lemma S4_mod_N
+ (Hbound : eval S < eval N + eval B)
+ : S4 mod N = (S + a*B)*ri mod N.
+ Proof.
+ pose proof (S3_bound Hbound); pose proof S3_nonneg.
+ unfold S4; autorewrite with push_eval.
+ rewrite (Z.mod_small _ (r * _)) by nia.
+ apply S3_mod_N.
+ Qed.
+ End Iteration.
+
+ Local Notation redc_body := (@redc_body T (@divmod) r R_numlimbs scmul add drop_high N B k).
+ Local Notation redc_loop := (@redc_loop T (@divmod) r R_numlimbs scmul add drop_high N B k).
+ Local Notation redc A := (@redc T zero (@divmod) r R_numlimbs scmul add drop_high N _ A B k).
+
+ (*Lemma redc_loop_comm_body count
+ : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S).
+ Proof.
+ induction count as [|count IHcount]; try reflexivity.
+ simpl; intro; rewrite IHcount; reflexivity.
+ Qed.*)
+
+ Section body.
+ Context (pred_A_numlimbs : nat)
+ (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)).
+ Let A:=fst A_S.
+ Let S:=snd A_S.
+ Let A_a:=divmod A.
+ Let a:=snd A_a.
+ Context (small_A : small A)
+ (S_bound : 0 <= eval S < eval N + eval B).
+
+ Lemma small_fst_redc_body : small (fst (redc_body A_S)).
+ Proof. destruct A_S; apply small_A'; assumption. Qed.
+ Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)).
+ Proof. destruct A_S; apply S4_nonneg; assumption. Qed.
+
+ Lemma snd_redc_body_mod_N
+ : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N).
+ Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed.
+
+ Lemma fst_redc_body
+ : (eval (fst (redc_body A_S))) = eval (fst A_S) / r.
+ Proof.
+ destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl.
+ autorewrite with push_eval.
+ reflexivity.
+ Qed.
+
+ Lemma fst_redc_body_mod_N
+ : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N).
+ Proof.
+ rewrite fst_redc_body.
+ etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ].
+ unfold a, A_a, A.
+ autorewrite with push_eval.
+ reflexivity.
+ Qed.
+
+ Lemma redc_body_bound
+ : eval S < eval N + eval B
+ -> eval (snd (redc_body A_S)) < eval N + eval B.
+ Proof.
+ destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega.
+ Qed.
+ End body.
+
+ Local Arguments Z.pow !_ !_.
+ Local Arguments Z.of_nat !_.
+ Local Ltac induction_loop count IHcount
+ := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | (*rewrite redc_loop_comm_body in * *) ].
+ Lemma redc_loop_good count A_S
+ (Hsmall : small (fst A_S))
+ (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+ : small (fst (redc_loop count A_S))
+ /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
+ Proof.
+ induction_loop count IHcount; auto; [].
+ change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *).
+ destruct_head'_and.
+ repeat first [ apply conj
+ | apply small_fst_redc_body
+ | apply redc_body_bound
+ | apply snd_redc_body_nonneg
+ | apply IHcount
+ | solve [ auto ] ].
+ Qed.
+
+ Lemma redc_loop_bound count A_S
+ (Hsmall : small (fst A_S))
+ (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+ : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
+ Proof. apply redc_loop_good; assumption. Qed.
+
+ Local Ltac handle_IH_small :=
+ repeat first [ apply redc_loop_good
+ | apply small_fst_redc_body
+ | apply redc_body_bound
+ | apply snd_redc_body_nonneg
+ | apply conj
+ | progress destruct_head' and
+ | solve [ auto ] ].
+
+ Lemma fst_redc_loop count A_S
+ (Hsmall : small (fst A_S))
+ (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+ : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count).
+ Proof.
+ induction_loop count IHcount.
+ { simpl; autorewrite with zsimplify; reflexivity. }
+ { rewrite IHcount, fst_redc_body by handle_IH_small.
+ change (1 + R_numlimbs)%nat with (S R_numlimbs) in *.
+ rewrite Zdiv_Zdiv by Z.zero_bounds.
+ rewrite <- (Z.pow_1_r r) at 1.
+ rewrite <- Z.pow_add_r by lia.
+ replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia.
+ reflexivity. }
+ Qed.
+
+ Lemma fst_redc_loop_mod_N count A_S
+ (Hsmall : small (fst A_S))
+ (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+ : eval (fst (redc_loop count A_S)) mod (eval N)
+ = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count)
+ * ri^(Z.of_nat count) mod (eval N).
+ Proof.
+ rewrite fst_redc_loop by assumption.
+ destruct count.
+ { simpl; autorewrite with zsimplify; reflexivity. }
+ { etransitivity;
+ [ eapply Z.div_to_inv_modulo;
+ try solve [ eassumption
+ | apply Z.lt_gt, Z.pow_pos_nonneg; lia ]
+ | ].
+ { erewrite <- Z.pow_mul_l, <- Z.pow_1_l.
+ { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. }
+ { lia. } }
+ reflexivity. }
+ Qed.
+
+ Local Arguments Z.pow : simpl never.
+ Lemma snd_redc_loop_mod_N count A_S
+ (Hsmall : small (fst A_S))
+ (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+ : (eval (snd (redc_loop count A_S))) mod (eval N)
+ = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N).
+ Proof.
+ induction_loop count IHcount.
+ { simpl; autorewrite with zsimplify; reflexivity. }
+ { rewrite IHcount by handle_IH_small.
+ push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod.
+ autorewrite with push_eval; [].
+ match goal with
+ | [ |- ?x mod ?N = ?y mod ?N ]
+ => change (Z.equiv_modulo N x y)
+ end.
+ destruct A_S as [A S].
+ cbn [fst snd].
+ change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)).
+ rewrite !Z.mul_add_distr_r.
+ rewrite <- !Z.mul_assoc.
+ replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count)))
+ by (change (Datatypes.S count) with (1 + count)%nat;
+ autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia).
+ rewrite <- !Z.add_assoc.
+ apply Z.add_mod_Proper; [ reflexivity | ].
+ unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)).
+ rewrite Z.mod_pull_div by auto with zarith lia.
+ push_Zmod.
+ erewrite Z.div_to_inv_modulo;
+ [
+ | apply Z.lt_gt; lia
+ | eassumption ].
+ pull_Zmod.
+ match goal with
+ | [ |- ?x mod ?N = ?y mod ?N ]
+ => change (Z.equiv_modulo N x y)
+ end.
+ repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia
+ | rewrite (Z.mul_comm _ ri)
+ | rewrite (Z.mul_assoc _ ri _)
+ | rewrite (Z.mul_comm _ (ri^_))
+ | rewrite (Z.mul_assoc _ (ri^_) _) ].
+ repeat first [ rewrite <- Z.mul_assoc
+ | rewrite <- Z.mul_add_distr_l
+ | rewrite (Z.mul_comm _ (eval B))
+ | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia;
+ rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds)
+ | rewrite Zplus_minus
+ | rewrite (Z.mul_comm r (r^_))
+ | reflexivity ]. }
+ Qed.
+
+ Lemma redc_bound A_numlimbs (A : T A_numlimbs)
+ (small_A : small A)
+ : 0 <= eval (redc A) < eval N + eval B.
+ Proof.
+ unfold redc.
+ apply redc_loop_good; simpl; autorewrite with push_eval;
+ rewrite ?Npos_correct; auto; lia.
+ Qed.
+
+ Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
+ : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
+ Proof.
+ unfold redc.
+ rewrite snd_redc_loop_mod_N; cbn [fst snd];
+ autorewrite with push_eval zsimplify;
+ [ | rewrite ?Npos_correct; auto; lia.. ].
+ Z.rewrite_mod_small.
+ reflexivity.
+ Qed.
+End WordByWordMontgomery.
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
index c2048889d..55724b940 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
@@ -1,12 +1,12 @@
(*** Word-By-Word Montgomery Multiplication *)
(** This file implements Montgomery Form, Montgomery Reduction, and
- Montgomery Multiplication on an abstract [list ℤ]. See
+ Montgomery Multiplication on an abstract [ℤⁿ]. See
https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion
of the algorithm; note that it may be that none of the algorithms
there exactly match what we're doing here. *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Arithmetic.Saturated.
-Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition.
+Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.LetIn.
@@ -18,39 +18,54 @@ Section WordByWordMontgomery.
Context
{r : positive}
{R_numlimbs : nat}
- (N : T).
+ (N' : T R_numlimbs).
+ (** TODO(andreser): Add a comment here about why we take in [N' : T
+ R_numlimbs]; we need [N : T (S R_numlimbs)] so that the limb
+ arithmetic works out exactly (so that we can add [q * N] of
+ length [S (S R_numlimbs)] to [S1] of length [S (S R_numlimbs)]
+ and then do [divmod] to get something of length [S
+ R_numlimbs]). *)
+ Local Notation N := (join0 N').
- Definition redc_body_no_cps (B : T) (k : Z) (A_S : T * T) : T * T
- := @redc_body T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k A_S.
- Definition redc_loop_no_cps (B : T) (k : Z) (count : nat) (A_S : T * T) : T * T
- := @redc_loop T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S.
- Definition redc_no_cps (A B : T) (k : Z) : T
- := @redc T numlimbs zero divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N A B k.
+ Definition redc_body_no_cps (B : T R_numlimbs) (k : Z) {pred_A_numlimbs} (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs))
+ : T pred_A_numlimbs * T (S R_numlimbs)
+ := @redc_body T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k _ A_S.
+ Definition redc_loop_no_cps (B : T R_numlimbs) (k : Z) (count : nat) (A_S : T count * T (S R_numlimbs))
+ : T 0 * T (S R_numlimbs)
+ := @redc_loop T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S.
+ Definition redc_no_cps {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs)
+ := @redc T (@zero) (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N _ A B k.
- Definition redc_body_cps (A B : T) (k : Z) (S' : T) {cpsT} (rest : T * T -> cpsT) : cpsT
+ Definition redc_body_cps {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs))
+ {cpsT} (rest : T pred_A_numlimbs * T (S R_numlimbs) -> cpsT)
+ : cpsT
:= divmod_cps A (fun '(A, a) =>
- @scmul_cps r a B _ (fun aB => @add_cps r S' aB _ (fun S1 =>
+ @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 =>
divmod_cps S1 (fun '(_, s) =>
dlet q := s * k mod r in
- @scmul_cps r q N _ (fun qN => @add_cps r S1 qN _ (fun S2 =>
+ @scmul_cps r _ q N _ (fun qN => @add_cps r _ S1 qN _ (fun S2 =>
divmod_cps S2 (fun '(S3, _) =>
@drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))).
Section loop.
- Context (A B : T) (k : Z) {cpsT : Type}.
- Fixpoint redc_loop_cps (count : nat) (rest : T * T -> cpsT) : T * T -> cpsT
+ Context {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) {cpsT : Type}.
+ Fixpoint redc_loop_cps (count : nat) (rest : T 0 * T (S R_numlimbs) -> cpsT) : T count * T (S R_numlimbs) -> cpsT
:= match count with
| O => rest
| S count' => fun '(A, S') => redc_body_cps A B k S' (redc_loop_cps count' rest)
end.
- Definition redc_cps (rest : T -> cpsT) : cpsT
- := redc_loop_cps (numlimbs A) (fun '(A, S') => rest S') (A, zero (1 + numlimbs B)).
+ Definition redc_cps (rest : T (S R_numlimbs) -> cpsT) : cpsT
+ := redc_loop_cps A_numlimbs (fun '(A, S') => rest S') (A, zero).
End loop.
- Definition redc_body (A B : T) (k : Z) (S' : T) : T * T := redc_body_cps A B k S' id.
- Definition redc_loop (B : T) (k : Z) (count : nat) : T * T -> T * T := redc_loop_cps B k count id.
- Definition redc (A B : T) (k : Z) : T := redc_cps A B k id.
+ Definition redc_body {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs))
+ : T pred_A_numlimbs * T (S R_numlimbs)
+ := redc_body_cps A B k S' id.
+ Definition redc_loop (B : T R_numlimbs) (k : Z) (count : nat) : T count * T (S R_numlimbs) -> T 0 * T (S R_numlimbs)
+ := redc_loop_cps B k count id.
+ Definition redc {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs)
+ := redc_cps A B k id.
End WordByWordMontgomery.
Hint Opaque redc redc_body redc_loop : uncps.
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
index 96646c61c..6dfd6a10a 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
@@ -2,8 +2,8 @@
Require Import Coq.ZArith.BinInt.
Require Import Coq.micromega.Lia.
Require Import Crypto.Arithmetic.Saturated.
-Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition.
-Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Proofs.
+Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Definition.
+Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Dependent.Proofs.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Tactics.BreakMatch.
@@ -19,25 +19,24 @@ Section WordByWordMontgomery.
Local Notation add := (@add (Z.pos r)).
Local Notation scmul := (@scmul (Z.pos r)).
Local Notation eval_zero := (@eval_zero (Z.pos r)).
+ Local Notation eval_join0 := (@eval_zero (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation eval_div := (@eval_div (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation eval_mod := (@eval_mod (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation small_div := (@small_div (Z.pos r) (Zorder.Zgt_pos_0 _)).
- Local Notation numlimbs_div := (@numlimbs_div (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation eval_scmul := (@eval_scmul (Z.pos r) (Zorder.Zgt_pos_0 _)).
- Local Notation numlimbs_scmul := (@numlimbs_scmul (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation eval_add := (@eval_add (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation small_add := (@small_add (Z.pos r) (Zorder.Zgt_pos_0 _)).
- Local Notation numlimbs_add := (@numlimbs_add (Z.pos r) (Zorder.Zgt_pos_0 _)).
Local Notation drop_high := (@drop_high (S R_numlimbs)).
- Local Notation numlimbs_drop_high := (@numlimbs_drop_high (Z.pos r) (Zorder.Zgt_pos_0 _) (S R_numlimbs)).
- Context (N A B : T)
- (k : Z)
- ri
+ Context (A_numlimbs : nat)
+ (N' : T R_numlimbs)
+ (A : T A_numlimbs)
+ (B : T R_numlimbs)
+ (k : Z).
+ Local Notation N := (join0 N').
+ Context ri
(r_big : r > 1)
(small_A : small A)
- (Hnumlimbs_le : (R_numlimbs <= numlimbs B)%nat)
- (Hnumlimbs_eq : R_numlimbs = numlimbs B)
- (A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat (numlimbs A))
+ (A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat A_numlimbs)
(ri_correct : r*ri mod (eval N) = 1 mod (eval N))
(N_bound : 0 < eval N < r^Z.of_nat R_numlimbs)
(B_bound' : 0 <= eval B < r^Z.of_nat R_numlimbs)
@@ -71,33 +70,24 @@ Section WordByWordMontgomery.
rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity.
Qed.
- Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N).
- Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N).
- Local Notation redc_body := (@redc_body r R_numlimbs N).
- Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k).
- Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k).
- Local Notation redc_loop := (@redc_loop r R_numlimbs N B k).
- Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A B k).
- Local Notation redc_cps := (@redc_cps r R_numlimbs N A B k).
- Local Notation redc := (@redc r R_numlimbs N A B k).
+ Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N').
+ Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N').
+ Local Notation redc_body := (@redc_body r R_numlimbs N').
+ Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N' B k).
+ Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N' B k).
+ Local Notation redc_loop := (@redc_loop r R_numlimbs N' B k).
+ Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N' A_numlimbs A B k).
+ Local Notation redc_cps := (@redc_cps r R_numlimbs N' A_numlimbs A B k).
+ Local Notation redc := (@redc r R_numlimbs N' A_numlimbs A B k).
Definition redc_no_cps_bound : 0 <= eval redc_no_cps < eval N + eval B
- := @redc_bound T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A.
- Definition numlimbs_redc_no_cps_gen
- : numlimbs redc_no_cps
- = match numlimbs A with
- | O => S (numlimbs B)
- | _ => S R_numlimbs
- end
- := @numlimbs_redc_gen T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_le.
- Definition numlimbs_redc_no_cps : numlimbs redc_no_cps = S (numlimbs B)
- := @numlimbs_redc T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_eq.
+ := @redc_bound T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A_numlimbs A small_A.
Definition redc_no_cps_mod_N
- : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N)
- := @redc_mod_N T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A small_A A_bound.
+ : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N)
+ := @redc_mod_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A_numlimbs A small_A A_bound.
- Lemma redc_body_cps_id (A' S' : T) {cpsT} f
- : @redc_body_cps A' B k S' cpsT f = f (redc_body A' B k S').
+ Lemma redc_body_cps_id pred_A_numlimbs (A' : T (S pred_A_numlimbs)) (S' : T (S R_numlimbs)) {cpsT} f
+ : @redc_body_cps pred_A_numlimbs A' B k S' cpsT f = f (redc_body A' B k S').
Proof.
unfold redc_body, redc_body_cps, LetIn.Let_In.
repeat first [ reflexivity
@@ -105,7 +95,7 @@ Section WordByWordMontgomery.
| progress autorewrite with uncps ].
Qed.
- Lemma redc_loop_cps_id (count : nat) (A_S : T * T) {cpsT} f
+ Lemma redc_loop_cps_id (count : nat) (A_S : T count * T (S R_numlimbs)) {cpsT} f
: @redc_loop_cps cpsT count f A_S = f (redc_loop count A_S).
Proof.
unfold redc_loop.
@@ -121,10 +111,10 @@ Section WordByWordMontgomery.
etransitivity; rewrite redc_loop_cps_id; [ | reflexivity ]; break_innermost_match; reflexivity.
Qed.
- Lemma redc_body_id_no_cps A' S'
- : redc_body A' B k S' = redc_body_no_cps B k (A', S').
+ Lemma redc_body_id_no_cps pred_A_numlimbs A' S'
+ : @redc_body pred_A_numlimbs A' B k S' = redc_body_no_cps B k (A', S').
Proof.
- unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Definition.redc_body, LetIn.Let_In, id.
+ unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Dependent.Definition.redc_body, LetIn.Let_In, id.
repeat autounfold with word_by_word_montgomery.
repeat first [ reflexivity
| progress cbn [fst snd id]
@@ -144,24 +134,15 @@ Section WordByWordMontgomery.
Qed.
Lemma redc_cps_id_no_cps : redc = redc_no_cps.
Proof.
- unfold redc, redc_no_cps, redc_cps, Abstract.Definition.redc.
+ unfold redc, redc_no_cps, redc_cps, Abstract.Dependent.Definition.redc.
rewrite redc_loop_cps_id, (surjective_pairing (redc_loop _ _)).
rewrite redc_loop_cps_id_no_cps; reflexivity.
Qed.
Lemma redc_bound : 0 <= eval redc < eval N + eval B.
Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound. Qed.
- Lemma numlimbs_redc_gen
- : numlimbs redc
- = match numlimbs A with
- | O => S (numlimbs B)
- | _ => S R_numlimbs
- end.
- Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps_gen. Qed.
- Lemma numlimbs_redc : numlimbs redc = S (numlimbs B).
- Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps. Qed.
Lemma redc_mod_N
- : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N).
+ : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_mod_N. Qed.
End WordByWordMontgomery.