diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-13 01:43:06 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-13 01:43:06 -0400 |
commit | b4b711cba32a21806c6c0aae53be40c04af60cb3 (patch) | |
tree | 40360bd900ac173527921123cc6c2442766762d5 /src/Arithmetic/MontgomeryReduction | |
parent | 3b1d3856b138d84cbb0429a6bcdaa9080233fa9b (diff) |
WBW-montgomery: Fill in most context variables
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
4 files changed, 662 insertions, 549 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v new file mode 100644 index 000000000..c7bf317ec --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Definition.v @@ -0,0 +1,81 @@ +(*** Word-By-Word Montgomery Multiplication *) +(** This file implements Montgomery Form, Montgomery Reduction, and + Montgomery Multiplication on an abstract [T]. We follow "Fast Prime + Field Elliptic Curve Cryptography with 256 Bit Primes", + https://eprint.iacr.org/2013/816.pdf. *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. + +(** Quoting from page 7 of "Fast Prime + Field Elliptic Curve Cryptography with 256 Bit Primes", + https://eprint.iacr.org/2013/816.pdf: *) +(** * Algorithm 1: Word-by-Word Montgomery Multiplication (WW-MM) *) +(** Input: [p < 2ˡ] (odd modulus), + [0 ≤ a, b < p], [l = s×k] + Output: [a×b×2⁻ˡ mod p] + Pre-computed: [k0 = -p⁻¹ mod 2ˢ] + Flow +<< +1. T = a×b + For i = 1 to k do + 2. T1 = T mod 2ˢ + 3. Y = T1 × k0 mod 2ˢ + 4. T2 = Y × p + 5. T3 = (T + T2) + 6. T = T3 / 2ˢ + End For +7. If T ≥ p then X = T – p; + else X = T +Return X +>> *) +Local Open Scope Z_scope. + +Section WordByWordMontgomery. + Local Coercion Z.pos : positive >-> Z. + Context + {T : Type} + {eval : T -> Z} + {numlimbs : T -> nat} + {zero : nat -> T} + {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {scmul : Z -> T -> T} (* uses double-output multiply *) + {R : positive} + {add : T -> T -> T} (* joins carry *) + {drop_high : T -> T} (* drops the highest limb *) + (N : T). + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. + Context (B : T) (k : Z). + Context (A S : T). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + Local Definition A_a := dlet p := divmod A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. + Local Definition S1 := add S (scmul a B). + Local Definition s := snd (divmod S1). + Local Definition q := s * k mod r. + Local Definition S2 := add S1 (scmul q N). + Local Definition S3 := fst (divmod S2). + Local Definition S4 := drop_high S3. + End Iteration. + + Section loop. + Context (A B : T) (k : Z) (S' : T). + + Definition redc_body : T * T -> T * T + := fun '(A, S') => (A' A, S4 B k A S'). + + Fixpoint redc_loop (count : nat) : T * T -> T * T + := match count with + | O => fun A_S => A_S + | S count' => fun A_S => redc_loop count' (redc_body A_S) + end. + + Definition redc : T + := snd (redc_loop (numlimbs A) (A, zero (1 + numlimbs B))). + End loop. +End WordByWordMontgomery. + +Create HintDb word_by_word_montgomery. +Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Proofs.v new file mode 100644 index 000000000..056b816a2 --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Proofs.v @@ -0,0 +1,486 @@ +(*** Word-By-Word Montgomery Multiplication Proofs *) +Require Import Coq.Arith.Arith. +Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.NatUtil. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.Tactics.SetEvars. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Local Open Scope Z_scope. + +Section WordByWordMontgomery. + Context + {T : Type} + {eval : T -> Z} + {numlimbs : T -> nat} + {zero : nat -> T} + {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {r_big : r > 1} + {small : T -> Prop} + {eval_zero : forall n, eval (zero n) = 0} + {numlimbs_zero : forall n, numlimbs (zero n) = n} + {eval_div : forall v, small v -> eval (fst (divmod v)) = eval v / r} + {eval_mod : forall v, small v -> snd (divmod v) = eval v mod r} + {small_div : forall v, small v -> small (fst (divmod v))} + {numlimbs_div : forall v, numlimbs (fst (divmod v)) = pred (numlimbs v)} + {scmul : Z -> T -> T} (* uses double-output multiply *) + {eval_scmul: forall a v, eval (scmul a v) = a * eval v} + {numlimbs_scmul : forall a v, 0 <= a < r -> numlimbs (scmul a v) = S (numlimbs v)} + {R : positive} + {R_numlimbs : nat} + {R_correct : R = r^Z.of_nat R_numlimbs :> Z} + {add : T -> T -> T} (* joins carry *) + {eval_add : forall a b, eval (add a b) = eval a + eval b} + {small_add : forall a b, small (add a b)} + {numlimbs_add : forall a b, numlimbs (add a b) = Datatypes.S (max (numlimbs a) (numlimbs b))} + {drop_high : T -> T} (* drops things after [S R_numlimbs] *) + {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} + {numlimbs_drop_high : forall v, numlimbs (drop_high v) = min (numlimbs v) (S R_numlimbs)} + (N : T) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) + (N_lt_R : eval N < R) + (B : T) + (B_bounds : 0 <= eval B < R) + ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)). + Context (k : Z) (k_correct : k * eval N mod r = -1). + + Create HintDb push_numlimbs discriminated. + Create HintDb push_eval discriminated. + Local Ltac t_small := + repeat first [ assumption + | apply small_add + | apply small_div + | apply Z_mod_lt + | solve [ auto ] + | lia + | progress autorewrite with push_eval + | progress autorewrite with push_numlimbs ]. + Hint Rewrite + eval_zero + eval_div + eval_mod + eval_add + eval_scmul + eval_drop_high + using (repeat autounfold with word_by_word_montgomery; t_small) + : push_eval. + Hint Rewrite + numlimbs_zero + numlimbs_div + numlimbs_add + numlimbs_scmul + numlimbs_drop_high + using (repeat autounfold with word_by_word_montgomery; t_small) + : push_numlimbs. + Hint Rewrite <- Max.succ_max_distr pred_Sn Min.succ_min_distr : push_numlimbs. + + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. + Context (A S : T) + (small_A : small A) + (S_nonneg : 0 <= eval S). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + + Local Coercion eval : T >-> Z. + + Local Notation a := (@WordByWord.Abstract.Definition.a T divmod A). + Local Notation A' := (@WordByWord.Abstract.Definition.A' T divmod A). + Local Notation S1 := (@WordByWord.Abstract.Definition.S1 T divmod scmul add B A S). + Local Notation S2 := (@WordByWord.Abstract.Definition.S2 T divmod r scmul add N B k A S). + Local Notation S3 := (@WordByWord.Abstract.Definition.S3 T divmod r scmul add N B k A S). + Local Notation S4 := (@WordByWord.Abstract.Definition.S4 T divmod r scmul add drop_high N B k A S). + + Lemma S3_bound + : eval S < eval N + eval B + -> eval S3 < eval N + eval B. + Proof. + assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) + by (intros x y; pose proof (Z_mod_lt x y); omega). + intro HS. + unfold S3, WordByWord.Abstract.Definition.S2, WordByWord.Abstract.Definition.S1. + autorewrite with push_eval; []. + eapply Z.le_lt_trans. + { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); + [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. + Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; + repeat autounfold with word_by_word_montgomery; + autorewrite with push_eval; + try Z.zero_bounds; + auto with lia. } + rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. + autorewrite with zsimplify. + omega. + Qed. + + Lemma small_A' + : small A'. + Proof. + repeat autounfold with word_by_word_montgomery; auto. + Qed. + + Lemma small_S3 + : small S3. + Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. + + Lemma S3_nonneg : 0 <= eval S3. + Proof. + repeat autounfold with word_by_word_montgomery; + autorewrite with push_eval; []. + rewrite ?Npos_correct; Z.zero_bounds; lia. + Qed. + + Lemma S4_nonneg : 0 <= eval S4. + Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed. + + Lemma S4_bound + : eval S < eval N + eval B + -> eval S4 < eval N + eval B. + Proof. + intro H; pose proof (S3_bound H); pose proof S3_nonneg. + unfold S4. + rewrite eval_drop_high by apply small_S3. + rewrite Z.mod_small by nia. + assumption. + Qed. + + Lemma numlimbs_S4 : numlimbs S4 = min (max (1 + numlimbs S) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). + Proof. + cbn [plus]. + repeat autounfold with word_by_word_montgomery. + repeat autorewrite with push_numlimbs. + change Init.Nat.max with Nat.max. + rewrite <- ?(Max.max_assoc (numlimbs S)). + reflexivity. + Qed. + + Lemma S1_eq : eval S1 = S + a*B. + Proof. + cbv [S1 a WordByWord.Abstract.Definition.A']. + repeat autorewrite with push_eval. + reflexivity. + Qed. + + Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N. + Proof. + cbv [S2 WordByWord.Abstract.Definition.q WordByWord.Abstract.Definition.s]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. + Qed. + + Lemma S2_mod_r : S2 mod r = 0. + cbv [S2 WordByWord.Abstract.Definition.q WordByWord.Abstract.Definition.s]; autorewrite with push_eval. + assert (r > 0) by lia. + assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). + { destruct (Z.eq_dec r 1) as [H'|H']. + { rewrite H'; split; reflexivity. } + { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } + autorewrite with pull_Zmod. + replace 0 with (0 mod r) by apply Zmod_0_l. + eapply F.eq_of_Z_iff. + repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. + rewrite <-Algebra.Hierarchy.associative. + replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). + { cbv [F.of_Z F.add]; simpl. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. + simpl. + rewrite (proj1 Hr), Z.mul_sub_distr_l. + push_Zmod; pull_Zmod. + autorewrite with zsimplify; reflexivity. } + { rewrite <- F.of_Z_mul. + rewrite F.of_Z_mod. + rewrite k_correct. + cbv [F.of_Z F.add F.opp F.one]; simpl. + change (-(1)) with (-1) in *. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. + rewrite (proj1 Hr), (proj2 Hr); reflexivity. } + Qed. + + Lemma S3_mod_N + : S3 mod N = (S + a*B)*ri mod N. + Proof. + cbv [S3]; autorewrite with push_eval cancel_pair. + pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; + cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. + etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| + rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. + rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval. + autorewrite with push_Zmod; + rewrite S2_mod_r; + autorewrite with zsimplify. + reflexivity. + Qed. + + Lemma S4_mod_N + (Hbound : eval S < eval N + eval B) + : S4 mod N = (S + a*B)*ri mod N. + Proof. + pose proof (S3_bound Hbound); pose proof S3_nonneg. + unfold S4; autorewrite with push_eval. + rewrite (Z.mod_small _ (r * _)) by nia. + apply S3_mod_N. + Qed. + End Iteration. + + Local Notation redc_body := (@redc_body T divmod r scmul add drop_high N B k). + Local Notation redc_loop := (@redc_loop T divmod r scmul add drop_high N B k). + Local Notation redc A := (@redc T numlimbs zero divmod r scmul add drop_high N A B k). + + Lemma redc_loop_comm_body count + : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S). + Proof. + induction count as [|count IHcount]; try reflexivity. + simpl; intro; rewrite IHcount; reflexivity. + Qed. + + Section body. + Context (A_S : T * T). + Let A:=fst A_S. + Let S:=snd A_S. + Let A_a:=divmod A. + Let a:=snd A_a. + Context (small_A : small A) + (S_bound : 0 <= eval S < eval N + eval B). + + Lemma small_fst_redc_body : small (fst (redc_body A_S)). + Proof. destruct A_S; apply small_A'; assumption. Qed. + Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). + Proof. destruct A_S; apply S4_nonneg; assumption. Qed. + + Lemma snd_redc_body_mod_N + : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). + Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed. + + Lemma fst_redc_body + : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. + Proof. + destruct A_S; simpl; unfold WordByWord.Abstract.Definition.A', WordByWord.Abstract.Definition.A_a, Let_In, a, A_a, A; simpl. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma fst_redc_body_mod_N + : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). + Proof. + rewrite fst_redc_body. + etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. + unfold a, A_a, A. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma redc_body_bound + : eval S < eval N + eval B + -> eval (snd (redc_body A_S)) < eval N + eval B. + Proof. + destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. + Qed. + + Lemma numlimbs_redc_body : numlimbs (snd (redc_body A_S)) + = min (max (1 + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). + Proof. destruct A_S; apply numlimbs_S4; assumption. Qed. + End body. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.of_nat !_. + Local Ltac induction_loop count IHcount + := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | rewrite redc_loop_comm_body in * ]. + Lemma redc_loop_good A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : small (fst (redc_loop count A_S)) + /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. + induction_loop count IHcount; auto; []. + change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). + destruct_head'_and. + repeat first [ apply conj + | apply small_fst_redc_body + | apply redc_body_bound + | apply snd_redc_body_nonneg + | solve [ auto ] ]. + Qed. + + Lemma redc_loop_bound A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. apply redc_loop_good; assumption. Qed. + + Local Ltac t_min_max_step _ := + match goal with + | [ |- context[Init.Nat.max ?x ?y] ] + => first [ rewrite (Max.max_l x y) by omega + | rewrite (Max.max_r x y) by omega ] + | [ |- context[Init.Nat.min ?x ?y] ] + => first [ rewrite (Min.min_l x y) by omega + | rewrite (Min.min_r x y) by omega ] + | _ => progress change Init.Nat.max with Nat.max + | _ => progress change Init.Nat.min with Nat.min + end. + + Lemma numlimbs_redc_loop A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + (Hnumlimbs : (R_numlimbs <= numlimbs (snd A_S))%nat) + : numlimbs (snd (redc_loop count A_S)) + = match count with + | O => numlimbs (snd A_S) + | S _ => 1 + R_numlimbs + end%nat. + Proof. + assert (Hgen + : numlimbs (snd (redc_loop count A_S)) + = match count with + | O => numlimbs (snd A_S) + | S _ => min (max (count + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs) + end). + { induction_loop count IHcount; [ reflexivity | ]. + rewrite numlimbs_redc_body by (try apply redc_loop_good; auto). + rewrite IHcount; clear IHcount. + destruct count; [ reflexivity | ]. + destruct (Compare_dec.le_lt_dec (1 + max (1 + numlimbs B) (numlimbs N)) (S count + numlimbs (snd A_S))), + (Compare_dec.le_lt_dec (1 + R_numlimbs) (S count + numlimbs (snd A_S))), + (Compare_dec.le_lt_dec (1 + R_numlimbs) (1 + max (1 + numlimbs B) (numlimbs N))); + repeat first [ reflexivity + | t_min_max_step () + | progress autorewrite with push_numlimbs + | rewrite Nat.min_comm, Nat.min_max_distr ]. } + rewrite Hgen; clear Hgen. + destruct count; [ reflexivity | ]. + repeat apply Max.max_case_strong; apply Min.min_case_strong; omega. + Qed. + + + Lemma fst_redc_loop A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). + Proof. + induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { rewrite fst_redc_body, IHcount + by (apply redc_loop_good; auto). + rewrite Zdiv_Zdiv by Z.zero_bounds. + rewrite <- (Z.pow_1_r r) at 2. + rewrite <- Z.pow_add_r by lia. + replace (Z.of_nat count + 1) with (Z.of_nat (S count)) by (simpl; lia). + reflexivity. } + Qed. + + Lemma fst_redc_loop_mod_N A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) mod (eval N) + = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) + * ri^(Z.of_nat count) mod (eval N). + Proof. + rewrite fst_redc_loop by assumption. + destruct count. + { simpl; autorewrite with zsimplify; reflexivity. } + { etransitivity; + [ eapply Z.div_to_inv_modulo; + try solve [ eassumption + | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] + | ]. + { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. + { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } + { lia. } } + reflexivity. } + Qed. + + Local Arguments Z.pow : simpl never. + Lemma snd_redc_loop_mod_N A_S count + (Hsmall : small (fst A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : (eval (snd (redc_loop count A_S))) mod (eval N) + = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). + Proof. + induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { simpl; rewrite snd_redc_body_mod_N + by (apply redc_loop_good; auto). + push_Zmod; rewrite IHcount; pull_Zmod. + autorewrite with push_eval; [ | apply redc_loop_good; auto.. ]; []. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + destruct A_S as [A S]. + cbn [fst snd]. + change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). + rewrite !Z.mul_add_distr_r. + rewrite <- !Z.mul_assoc. + replace (ri^(Z.of_nat count) * ri) with (ri^(Z.of_nat (Datatypes.S count))) + by (change (Datatypes.S count) with (1 + count)%nat; + autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). + rewrite <- !Z.add_assoc. + apply Z.add_mod_Proper; [ reflexivity | ]. + unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). + rewrite fst_redc_loop by (try apply redc_loop_good; auto; omega). + cbn [fst]. + rewrite Z.mod_pull_div by lia. + erewrite Z.div_to_inv_modulo; + [ + | solve [ eassumption | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] + | erewrite <- Z.pow_mul_l, <- Z.pow_1_l; + [ apply Z.pow_mod_Proper; [ eassumption | reflexivity ] + | lia ] ]. + pull_Zmod. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia + | rewrite (Z.mul_comm _ ri) + | rewrite (Z.mul_assoc _ ri _) + | rewrite (Z.mul_comm _ (ri^_)) + | rewrite (Z.mul_assoc _ (ri^_) _) ]. + repeat first [ rewrite <- Z.mul_assoc + | rewrite <- Z.mul_add_distr_l + | rewrite (Z.mul_comm _ (eval B)) + | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; + rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) + | rewrite Zplus_minus + | reflexivity ]. } + Qed. + + Lemma redc_bound A + (small_A : small A) + : 0 <= eval (redc A) < eval N + eval B. + Proof. + unfold redc. + apply redc_loop_good; simpl; autorewrite with push_eval; + rewrite ?Npos_correct; auto; lia. + Qed. + + Lemma numlimbs_redc_gen A (small_A : small A) (Hnumlimbs : (R_numlimbs <= numlimbs B)%nat) + : numlimbs (redc A) + = match numlimbs A with + | O => S (numlimbs B) + | _ => S R_numlimbs + end. + Proof. + unfold redc; rewrite numlimbs_redc_loop by (cbn [fst snd]; t_small); + cbn [snd]; rewrite ?numlimbs_zero. + reflexivity. + Qed. + Lemma numlimbs_redc A (small_A : small A) (Hnumlimbs : R_numlimbs = numlimbs B) + : numlimbs (redc A) = S (numlimbs B). + Proof. rewrite numlimbs_redc_gen; subst; auto; destruct (numlimbs A); reflexivity. Qed. + + Lemma redc_mod_N A (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat (numlimbs A)) + : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). + Proof. + unfold redc. + rewrite snd_redc_loop_mod_N; cbn [fst snd]; + autorewrite with push_eval zsimplify; + [ | rewrite ?Npos_correct; auto; lia.. ]. + Z.rewrite_mod_small. + reflexivity. + Qed. +End WordByWordMontgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index c7bf317ec..898adcff7 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -1,81 +1,17 @@ (*** Word-By-Word Montgomery Multiplication *) (** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [T]. We follow "Fast Prime - Field Elliptic Curve Cryptography with 256 Bit Primes", - https://eprint.iacr.org/2013/816.pdf. *) -Require Import Coq.ZArith.ZArith. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.LetIn. - -(** Quoting from page 7 of "Fast Prime - Field Elliptic Curve Cryptography with 256 Bit Primes", - https://eprint.iacr.org/2013/816.pdf: *) -(** * Algorithm 1: Word-by-Word Montgomery Multiplication (WW-MM) *) -(** Input: [p < 2ˡ] (odd modulus), - [0 ≤ a, b < p], [l = s×k] - Output: [a×b×2⁻ˡ mod p] - Pre-computed: [k0 = -p⁻¹ mod 2ˢ] - Flow -<< -1. T = a×b - For i = 1 to k do - 2. T1 = T mod 2ˢ - 3. Y = T1 × k0 mod 2ˢ - 4. T2 = Y × p - 5. T3 = (T + T2) - 6. T = T3 / 2ˢ - End For -7. If T ≥ p then X = T – p; - else X = T -Return X ->> *) -Local Open Scope Z_scope. - -Section WordByWordMontgomery. - Local Coercion Z.pos : positive >-> Z. - Context - {T : Type} - {eval : T -> Z} - {numlimbs : T -> nat} - {zero : nat -> T} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {scmul : Z -> T -> T} (* uses double-output multiply *) - {R : positive} - {add : T -> T -> T} (* joins carry *) - {drop_high : T -> T} (* drops the highest limb *) - (N : T). - - (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) - Section Iteration. - Context (B : T) (k : Z). - Context (A S : T). - (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) - Local Definition A_a := dlet p := divmod A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. - Local Definition S1 := add S (scmul a B). - Local Definition s := snd (divmod S1). - Local Definition q := s * k mod r. - Local Definition S2 := add S1 (scmul q N). - Local Definition S3 := fst (divmod S2). - Local Definition S4 := drop_high S3. - End Iteration. - - Section loop. - Context (A B : T) (k : Z) (S' : T). - - Definition redc_body : T * T -> T * T - := fun '(A, S') => (A' A, S4 B k A S'). - - Fixpoint redc_loop (count : nat) : T * T -> T * T - := match count with - | O => fun A_S => A_S - | S count' => fun A_S => redc_loop count' (redc_body A_S) - end. - - Definition redc : T - := snd (redc_loop (numlimbs A) (A, zero (1 + numlimbs B))). - End loop. -End WordByWordMontgomery. - -Create HintDb word_by_word_montgomery. -Hint Unfold S4 S3 S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. + Montgomery Multiplication on an abstract [list ℤ]. We follow + "Fast Prime Field Elliptic Curve Cryptography with 256 Bit + Primes", https://eprint.iacr.org/2013/816.pdf. *) +Require Import Coq.ZArith.BinInt. +Require Import Crypto.Arithmetic.Saturated. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. + +Section redc. + (** XXX TODO: Figure out how to fill in these context variables *) + Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *). + + (** XXX TODO: pick better names for the arguments to this definition. *) + Definition redc {r : positive} {R_numlimbs : nat} (N A B : T) (k : Z) : T + := @redc T numlimbs zero divmod r (@scmul (Z.pos r) mul_split) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N A B k. +End redc. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index c90b55fbc..d51c16673 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -1,486 +1,96 @@ (*** Word-By-Word Montgomery Multiplication Proofs *) -Require Import Coq.Arith.Arith. -Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia. -Require Import Crypto.Util.LetIn. -Require Import Crypto.Util.Prod. -Require Import Crypto.Util.NatUtil. -Require Import Crypto.Util.ZUtil. -Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic. +Require Import Coq.ZArith.BinInt. +Require Import Coq.micromega.Lia. +Require Import Crypto.Arithmetic.Saturated. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Proofs. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. -Require Import Crypto.Algebra.Ring. -Require Import Crypto.Util.Sigma. -Require Import Crypto.Util.Tactics.SetEvars. -Require Import Crypto.Util.Tactics.SubstEvars. -Require Import Crypto.Util.Tactics.DestructHead. -Local Open Scope Z_scope. +Require Import Crypto.Util.ZUtil. +Local Open Scope Z_scope. +Local Coercion Z.pos : positive >-> Z. Section WordByWordMontgomery. - Context - {T : Type} - {eval : T -> Z} - {numlimbs : T -> nat} - {zero : nat -> T} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {r : positive} - {r_big : r > 1} - {small : T -> Prop} - {eval_zero : forall n, eval (zero n) = 0} - {numlimbs_zero : forall n, numlimbs (zero n) = n} - {eval_div : forall v, small v -> eval (fst (divmod v)) = eval v / r} - {eval_mod : forall v, small v -> snd (divmod v) = eval v mod r} - {small_div : forall v, small (fst (divmod v))} - {numlimbs_div : forall v, numlimbs (fst (divmod v)) = pred (numlimbs v)} - {scmul : Z -> T -> T} (* uses double-output multiply *) - {eval_scmul: forall a v, eval (scmul a v) = a * eval v} - {numlimbs_scmul : forall a v, 0 <= a < r -> numlimbs (scmul a v) = S (numlimbs v)} - {R : positive} - {R_numlimbs : nat} - {R_correct : R = r^Z.of_nat R_numlimbs :> Z} - {add : T -> T -> T} (* joins carry *) - {eval_add : forall a b, eval (add a b) = eval a + eval b} - {small_add : forall a b, small (add a b)} - {numlimbs_add : forall a b, numlimbs (add a b) = Datatypes.S (max (numlimbs a) (numlimbs b))} - {drop_high : T -> T} (* drops things after [S R_numlimbs] *) - {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} - {numlimbs_drop_high : forall v, numlimbs (drop_high v) = min (numlimbs v) (S R_numlimbs)} - (N : T) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) - (N_lt_R : eval N < R) - (B : T) - (B_bounds : 0 <= eval B < R) - ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)). - Context (k : Z) (k_correct : k * eval N mod r = -1). - - Create HintDb push_numlimbs discriminated. - Create HintDb push_eval discriminated. - Local Ltac t_small := - repeat first [ assumption - | apply small_add - | apply small_div - | apply Z_mod_lt - | solve [ auto ] - | lia - | progress autorewrite with push_eval - | progress autorewrite with push_numlimbs ]. - Hint Rewrite - eval_zero - eval_div - eval_mod - eval_add - eval_scmul - eval_drop_high - using (repeat autounfold with word_by_word_montgomery; t_small) - : push_eval. - Hint Rewrite - numlimbs_zero - numlimbs_div - numlimbs_add - numlimbs_scmul - numlimbs_drop_high - using (repeat autounfold with word_by_word_montgomery; t_small) - : push_numlimbs. - Hint Rewrite <- Max.succ_max_distr pred_Sn Min.succ_min_distr : push_numlimbs. - - - (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) - Section Iteration. - Context (A S : T) - (small_A : small A) - (S_nonneg : 0 <= eval S). - (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) - - Local Coercion eval : T >-> Z. - - Local Notation a := (@WordByWord.Definition.a T divmod A). - Local Notation A' := (@WordByWord.Definition.A' T divmod A). - Local Notation S1 := (@WordByWord.Definition.S1 T divmod scmul add B A S). - Local Notation S2 := (@WordByWord.Definition.S2 T divmod r scmul add N B k A S). - Local Notation S3 := (@WordByWord.Definition.S3 T divmod r scmul add N B k A S). - Local Notation S4 := (@WordByWord.Definition.S4 T divmod r scmul add drop_high N B k A S). - - Lemma S3_bound - : eval S < eval N + eval B - -> eval S3 < eval N + eval B. - Proof. - assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) - by (intros x y; pose proof (Z_mod_lt x y); omega). - intro HS. - unfold S3, WordByWord.Definition.S2, WordByWord.Definition.S1. - autorewrite with push_eval; []. - eapply Z.le_lt_trans. - { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); - [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. - Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; - repeat autounfold with word_by_word_montgomery; - autorewrite with push_eval; - try Z.zero_bounds; - auto with lia. } - rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. - autorewrite with zsimplify. - omega. - Qed. - - Lemma small_A' - : small A'. - Proof. - repeat autounfold with word_by_word_montgomery; auto. - Qed. - - Lemma small_S3 - : small S3. - Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. - - Lemma S3_nonneg : 0 <= eval S3. - Proof. - repeat autounfold with word_by_word_montgomery; - autorewrite with push_eval; []. - rewrite ?Npos_correct; Z.zero_bounds; lia. - Qed. - - Lemma S4_nonneg : 0 <= eval S4. - Proof. unfold S4; rewrite eval_drop_high by apply small_S3; Z.zero_bounds. Qed. - - Lemma S4_bound - : eval S < eval N + eval B - -> eval S4 < eval N + eval B. - Proof. - intro H; pose proof (S3_bound H); pose proof S3_nonneg. - unfold S4. - rewrite eval_drop_high by apply small_S3. - rewrite Z.mod_small by nia. - assumption. - Qed. - - Lemma numlimbs_S4 : numlimbs S4 = min (max (1 + numlimbs S) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). - Proof. - cbn [plus]. - repeat autounfold with word_by_word_montgomery. - repeat autorewrite with push_numlimbs. - change Init.Nat.max with Nat.max. - rewrite <- ?(Max.max_assoc (numlimbs S)). - reflexivity. - Qed. - - Lemma S1_eq : eval S1 = S + a*B. - Proof. - cbv [S1 a WordByWord.Definition.A']. - repeat autorewrite with push_eval. - reflexivity. - Qed. - - Lemma S2_mod_N : (eval S2) mod N = (S + a*B) mod N. - Proof. - cbv [S2 WordByWord.Definition.q WordByWord.Definition.s]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. - Qed. - - Lemma S2_mod_r : S2 mod r = 0. - cbv [S2 WordByWord.Definition.q WordByWord.Definition.s]; autorewrite with push_eval. - assert (r > 0) by lia. - assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). - { destruct (Z.eq_dec r 1) as [H'|H']. - { rewrite H'; split; reflexivity. } - { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } - autorewrite with pull_Zmod. - replace 0 with (0 mod r) by apply Zmod_0_l. - eapply F.eq_of_Z_iff. - repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. - rewrite <-Algebra.Hierarchy.associative. - replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). - { cbv [F.of_Z F.add]; simpl. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. - simpl. - rewrite (proj1 Hr), Z.mul_sub_distr_l. - push_Zmod; pull_Zmod. - autorewrite with zsimplify; reflexivity. } - { rewrite <- F.of_Z_mul. - rewrite F.of_Z_mod. - rewrite k_correct. - cbv [F.of_Z F.add F.opp F.one]; simpl. - change (-(1)) with (-1) in *. - apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. - rewrite (proj1 Hr), (proj2 Hr); reflexivity. } - Qed. - - Lemma S3_mod_N - : S3 mod N = (S + a*B)*ri mod N. - Proof. - cbv [S3]; autorewrite with push_eval cancel_pair. - pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; - cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. - etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| - rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. - rewrite <-S2_mod_N; repeat (f_equal; []); autorewrite with push_eval. - autorewrite with push_Zmod; - rewrite S2_mod_r; - autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma S4_mod_N - (Hbound : eval S < eval N + eval B) - : S4 mod N = (S + a*B)*ri mod N. - Proof. - pose proof (S3_bound Hbound); pose proof S3_nonneg. - unfold S4; autorewrite with push_eval. - rewrite (Z.mod_small _ (r * _)) by nia. - apply S3_mod_N. - Qed. - End Iteration. - - Local Notation redc_body := (@redc_body T divmod r scmul add drop_high N B k). - Local Notation redc_loop := (@redc_loop T divmod r scmul add drop_high N B k). - Local Notation redc A := (@redc T numlimbs zero divmod r scmul add drop_high N A B k). - - Lemma redc_loop_comm_body count - : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S). + (** XXX TODO: Figure out how to fill in these context variables *) + Context {mul_split : Z -> Z -> Z -> Z * Z} (* first argument is where to split output; [mul_split s x y] gives ((x * y) mod s, (x * y) / s) *) + {mul_split_mod : forall s x y, + fst (mul_split s x y) = (x * y) mod s} + {mul_split_div : forall s x y, + snd (mul_split s x y) = (x * y) / s}. + + (** XXX TODO: pick better names for things like [R_numlimbs] *) + Context (r : positive) + (R_numlimbs : nat). + Local Notation small := (@small (Z.pos r)). + Local Notation eval := (@eval (Z.pos r)). + Local Notation add := (@add (Z.pos r)). + Local Notation scmul := (@scmul (Z.pos r) mul_split). + Local Notation eval_zero := (@eval_zero (Z.pos r)). + Local Notation eval_div := (@eval_div (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation eval_mod := (@eval_mod (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation small_div := (@small_div (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation numlimbs_div := (@numlimbs_div (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation eval_scmul := (@eval_scmul (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation numlimbs_scmul := (@numlimbs_scmul (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation eval_add := (@eval_add (Z.pos r) (Zorder.Zgt_pos_0 _)). + Local Notation small_add := (@small_add (Z.pos r) (Zorder.Zgt_pos_0 _)). + Local Notation numlimbs_add := (@numlimbs_add (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div). + Local Notation drop_high := (@drop_high (S R_numlimbs)). + Local Notation numlimbs_drop_high := (@numlimbs_drop_high (Z.pos r) (Zorder.Zgt_pos_0 _) mul_split mul_split_mod mul_split_div (S R_numlimbs)). + Context (N A B : T) + (k : Z) + ri + (r_big : r > 1) + (small_A : small A) + (Hnumlimbs_le : (R_numlimbs <= numlimbs B)%nat) + (Hnumlimbs_eq : R_numlimbs = numlimbs B) + (A_bound : 0 <= eval A < Z.pos r ^ Z.of_nat (numlimbs A)) + (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) + (N_bound : 0 < eval N < r^Z.of_nat R_numlimbs) + (B_bound' : 0 <= eval B < r^Z.of_nat R_numlimbs) + (k_correct : k * eval N mod r = -1). + Let R : positive := match (Z.pos r ^ Z.of_nat R_numlimbs)%Z with + | Z.pos R => R + | _ => 1%positive + end. + Let Npos : positive := match eval N with + | Z.pos N => N + | _ => 1%positive + end. + Local Lemma R_correct : Z.pos R = Z.pos r ^ Z.of_nat R_numlimbs. Proof. - induction count as [|count IHcount]; try reflexivity. - simpl; intro; rewrite IHcount; reflexivity. + assert (0 < r^Z.of_nat R_numlimbs) by (apply Z.pow_pos_nonneg; lia). + subst R; destruct (Z.pos r ^ Z.of_nat R_numlimbs) eqn:?; [ | reflexivity | ]; + lia. Qed. + Local Lemma Npos_correct: eval N = Z.pos Npos. + Proof. subst Npos; destruct (eval N); [ | reflexivity | ]; lia. Qed. + Local Lemma N_lt_R : eval N < R. + Proof. rewrite R_correct; apply N_bound. Qed. + Local Lemma B_bound : 0 <= eval B < R. + Proof. rewrite R_correct; apply B_bound'. Qed. - Section body. - Context (A_S : T * T). - Let A:=fst A_S. - Let S:=snd A_S. - Let A_a:=divmod A. - Let a:=snd A_a. - Context (small_A : small A) - (S_bound : 0 <= eval S < eval N + eval B). - - Lemma small_fst_redc_body : small (fst (redc_body A_S)). - Proof. destruct A_S; apply small_A'; assumption. Qed. - Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). - Proof. destruct A_S; apply S4_nonneg; assumption. Qed. - - Lemma snd_redc_body_mod_N - : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). - Proof. destruct A_S; apply S4_mod_N; auto; omega. Qed. - - Lemma fst_redc_body - : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. - Proof. - destruct A_S; simpl; unfold WordByWord.Definition.A', WordByWord.Definition.A_a, Let_In, a, A_a, A; simpl. - autorewrite with push_eval. - reflexivity. - Qed. - Lemma fst_redc_body_mod_N - : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). - Proof. - rewrite fst_redc_body. - etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. - unfold a, A_a, A. - autorewrite with push_eval. - reflexivity. - Qed. - - Lemma redc_body_bound - : eval S < eval N + eval B - -> eval (snd (redc_body A_S)) < eval N + eval B. - Proof. - destruct A_S; apply S4_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. - Qed. - - Lemma numlimbs_redc_body : numlimbs (snd (redc_body A_S)) - = min (max (1 + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs). - Proof. destruct A_S; apply numlimbs_S4; assumption. Qed. - End body. - - Local Arguments Z.pow !_ !_. - Local Arguments Z.of_nat !_. - Local Ltac induction_loop count IHcount - := induction count as [|count IHcount]; intros; cbn [redc_loop] in *; [ | rewrite redc_loop_comm_body in * ]. - Lemma redc_loop_good A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : small (fst (redc_loop count A_S)) - /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Local Lemma eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs). Proof. - induction_loop count IHcount; auto; []. - change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). - destruct_head'_and. - repeat first [ apply conj - | apply small_fst_redc_body - | apply redc_body_bound - | apply snd_redc_body_nonneg - | solve [ auto ] ]. + intros; erewrite eval_drop_high by (eassumption || lia). + f_equal; unfold uweight. + rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity. Qed. - Lemma redc_loop_bound A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. - Proof. apply redc_loop_good; assumption. Qed. - - Local Ltac t_min_max_step _ := - match goal with - | [ |- context[Init.Nat.max ?x ?y] ] - => first [ rewrite (Max.max_l x y) by omega - | rewrite (Max.max_r x y) by omega ] - | [ |- context[Init.Nat.min ?x ?y] ] - => first [ rewrite (Min.min_l x y) by omega - | rewrite (Min.min_r x y) by omega ] - | _ => progress change Init.Nat.max with Nat.max - | _ => progress change Init.Nat.min with Nat.min - end. - - Lemma numlimbs_redc_loop A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - (Hnumlimbs : (R_numlimbs <= numlimbs (snd A_S))%nat) - : numlimbs (snd (redc_loop count A_S)) - = match count with - | O => numlimbs (snd A_S) - | S _ => 1 + R_numlimbs - end%nat. - Proof. - assert (Hgen - : numlimbs (snd (redc_loop count A_S)) - = match count with - | O => numlimbs (snd A_S) - | S _ => min (max (count + numlimbs (snd A_S)) (1 + max (1 + numlimbs B) (numlimbs N))) (1 + R_numlimbs) - end). - { induction_loop count IHcount; [ reflexivity | ]. - rewrite numlimbs_redc_body by (try apply redc_loop_good; auto). - rewrite IHcount; clear IHcount. - destruct count; [ reflexivity | ]. - destruct (Compare_dec.le_lt_dec (1 + max (1 + numlimbs B) (numlimbs N)) (S count + numlimbs (snd A_S))), - (Compare_dec.le_lt_dec (1 + R_numlimbs) (S count + numlimbs (snd A_S))), - (Compare_dec.le_lt_dec (1 + R_numlimbs) (1 + max (1 + numlimbs B) (numlimbs N))); - repeat first [ reflexivity - | t_min_max_step () - | progress autorewrite with push_numlimbs - | rewrite Nat.min_comm, Nat.min_max_distr ]. } - rewrite Hgen; clear Hgen. - destruct count; [ reflexivity | ]. - repeat apply Max.max_case_strong; apply Min.min_case_strong; omega. - Qed. + Local Notation redc := (@redc mul_split r R_numlimbs N A B k). - - Lemma fst_redc_loop A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { rewrite fst_redc_body, IHcount - by (apply redc_loop_good; auto). - rewrite Zdiv_Zdiv by Z.zero_bounds. - rewrite <- (Z.pow_1_r r) at 2. - rewrite <- Z.pow_add_r by lia. - replace (Z.of_nat count + 1) with (Z.of_nat (S count)) by (simpl; lia). - reflexivity. } - Qed. - - Lemma fst_redc_loop_mod_N A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : eval (fst (redc_loop count A_S)) mod (eval N) - = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) - * ri^(Z.of_nat count) mod (eval N). - Proof. - rewrite fst_redc_loop by assumption. - destruct count. - { simpl; autorewrite with zsimplify; reflexivity. } - { etransitivity; - [ eapply Z.div_to_inv_modulo; - try solve [ eassumption - | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] - | ]. - { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. - { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } - { lia. } } - reflexivity. } - Qed. - - Local Arguments Z.pow : simpl never. - Lemma snd_redc_loop_mod_N A_S count - (Hsmall : small (fst A_S)) - (Hbound : 0 <= eval (snd A_S) < eval N + eval B) - : (eval (snd (redc_loop count A_S))) mod (eval N) - = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). - Proof. - induction_loop count IHcount. - { simpl; autorewrite with zsimplify; reflexivity. } - { simpl; rewrite snd_redc_body_mod_N - by (apply redc_loop_good; auto). - push_Zmod; rewrite IHcount; pull_Zmod. - autorewrite with push_eval; [ | apply redc_loop_good; auto.. ]; []. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - destruct A_S as [A S]. - cbn [fst snd]. - change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). - rewrite !Z.mul_add_distr_r. - rewrite <- !Z.mul_assoc. - replace (ri^(Z.of_nat count) * ri) with (ri^(Z.of_nat (Datatypes.S count))) - by (change (Datatypes.S count) with (1 + count)%nat; - autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). - rewrite <- !Z.add_assoc. - apply Z.add_mod_Proper; [ reflexivity | ]. - unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). - rewrite fst_redc_loop by (try apply redc_loop_good; auto; omega). - cbn [fst]. - rewrite Z.mod_pull_div by lia. - erewrite Z.div_to_inv_modulo; - [ - | solve [ eassumption | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] - | erewrite <- Z.pow_mul_l, <- Z.pow_1_l; - [ apply Z.pow_mod_Proper; [ eassumption | reflexivity ] - | lia ] ]. - pull_Zmod. - match goal with - | [ |- ?x mod ?N = ?y mod ?N ] - => change (Z.equiv_modulo N x y) - end. - repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia - | rewrite (Z.mul_comm _ ri) - | rewrite (Z.mul_assoc _ ri _) - | rewrite (Z.mul_comm _ (ri^_)) - | rewrite (Z.mul_assoc _ (ri^_) _) ]. - repeat first [ rewrite <- Z.mul_assoc - | rewrite <- Z.mul_add_distr_l - | rewrite (Z.mul_comm _ (eval B)) - | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; - rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) - | rewrite Zplus_minus - | reflexivity ]. } - Qed. - - Lemma redc_bound A - (small_A : small A) - : 0 <= eval (redc A) < eval N + eval B. - Proof. - unfold redc. - apply redc_loop_good; simpl; autorewrite with push_eval; - rewrite ?Npos_correct; auto; lia. - Qed. - - Lemma numlimbs_redc_gen A (small_A : small A) (Hnumlimbs : (R_numlimbs <= numlimbs B)%nat) - : numlimbs (redc A) + Definition redc_bound : 0 <= eval redc < eval N + eval B + := @redc_bound T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A. + Definition numlimbs_redc_gen + : numlimbs redc = match numlimbs A with | O => S (numlimbs B) | _ => S R_numlimbs - end. - Proof. - unfold redc; rewrite numlimbs_redc_loop by (cbn [fst snd]; t_small); - cbn [snd]; rewrite ?numlimbs_zero. - reflexivity. - Qed. - Lemma numlimbs_redc A (small_A : small A) (Hnumlimbs : R_numlimbs = numlimbs B) - : numlimbs (redc A) = S (numlimbs B). - Proof. rewrite numlimbs_redc_gen; subst; auto; destruct (numlimbs A); reflexivity. Qed. - - Lemma redc_mod_N A (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat (numlimbs A)) - : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). - Proof. - unfold redc. - rewrite snd_redc_loop_mod_N; cbn [fst snd]; - autorewrite with push_eval zsimplify; - [ | rewrite ?Npos_correct; auto; lia.. ]. - Z.rewrite_mod_small. - reflexivity. - Qed. + end + := @numlimbs_redc_gen T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_le. + Definition numlimbs_redc : numlimbs redc = S (numlimbs B) + := @numlimbs_redc T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_eq. + Definition redc_mod_N + : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N) + := @redc_mod_N T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A small_A A_bound. End WordByWordMontgomery. |