diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-10 00:40:01 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-06-10 00:41:11 -0400 |
commit | 2311a022266ea0595244ac398e9f2c073801484c (patch) | |
tree | e057b1ee7bcecffece92e06c4a1b03d54667c1c4 /src/Arithmetic/MontgomeryReduction | |
parent | b6cc64b9915c9fb77deb77dd37b20e067817d5b1 (diff) |
More work in progress on montgomery proofs
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v | 79 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v | 370 |
2 files changed, 336 insertions, 113 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index fba7b4c8b..f440cf16f 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -1,14 +1,11 @@ (*** Word-By-Word Montgomery Multiplication *) (** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on [tuple ℤ]. We follow "Fast Prime + Montgomery Multiplication on an abstract [T]. We follow "Fast Prime Field Elliptic Curve Cryptography with 256 Bit Primes", https://eprint.iacr.org/2013/816.pdf. *) Require Import Coq.ZArith.ZArith. -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated. -Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. -Require Import Crypto.Util.Tuple. Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. (** Quoting from page 7 of "Fast Prime Field Elliptic Curve Cryptography with 256 Bit Primes", @@ -33,32 +30,50 @@ Require Import Crypto.Util.Notations. Return X >> *) Local Open Scope Z_scope. -Section columns. - (** TODO(jadep): implement these *) - Context {T : Type} {length : T -> nat} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {scmul : Z -> T -> T} (* uses double-output multiply *) - {add : T -> T -> T * Z} (* produces carry *) - {join : T * Z -> T} - {zero : nat -> T} - (A B : T) - (bound : Z) - (N : T) - (k : Z) (* [(-1 mod N) mod bound] *). - Definition redc_body : T * T -> T * T - := fun '(A, S') - => let '(A, a) := divmod A in - let '(S', _) := add S' (scmul a B) in - let '(_, q) := divmod (scmul k S') in - let '(S', _) := divmod (join (add S' (scmul q N))) in - (A, S'). - Fixpoint redc_loop (count : nat) : T * T -> T * T - := match count with - | O => fun A_S => A_S - | S count' => fun A_S => redc_loop count' (redc_body A_S) - end. +Section WordByWordMontgomery. + Local Coercion Z.pos : positive >-> Z. + Context + {T : Type} + {eval : T -> Z} + {numlimbs : T -> nat} + {zero : nat -> T} + {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {scmul : Z -> T -> T} (* uses double-output multiply *) + {R : positive} + {add : T -> T -> T} (* joins carry *) + (N : T). - Definition redc : T - := snd (redc_loop (length A) (A, zero (1 + length B))). -End columns. + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. + Context (B : T) (k : Z). + Context (A S : T). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + Local Definition A_a := dlet p := divmod A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. + Local Definition S1 := add S (scmul a B). + Local Definition s := snd (divmod S1). + Local Definition q := s * k mod r. + Local Definition cS2 := add S1 (scmul q N). + Local Definition S3 := fst (divmod cS2). + End Iteration. + + Section loop. + Context (A B : T) (k : Z) (S' : T). + + Definition redc_body : T * T -> T * T + := fun '(A, S') => (A' A, S3 B k A S'). + + Fixpoint redc_loop (count : nat) : T * T -> T * T + := match count with + | O => fun A_S => A_S + | S count' => fun A_S => redc_loop count' (redc_body A_S) + end. + + Definition redc : T + := snd (redc_loop (numlimbs A) (A, zero (1 + numlimbs B))). + End loop. +End WordByWordMontgomery. + +Create HintDb word_by_word_montgomery. +Hint Unfold A_a A' a S1 s q cS2 S3 : word_by_word_montgomery. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index bbf7a68f1..e0b2d26d6 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -1,85 +1,170 @@ (*** Word-By-Word Montgomery Multiplication Proofs *) -Require Import Coq.ZArith.ZArith. -Require Import Coq.omega.Omega. -Require Import Crypto.Arithmetic.Core. -Require Import Crypto.Arithmetic.Saturated. -Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. -Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. +Require Import Coq.ZArith.BinInt Coq.ZArith.ZArith Coq.ZArith.Zdiv Coq.micromega.Lia. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems Crypto.Spec.ModularArithmetic. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.ZUtil.EquivModulo. -Require Import Crypto.Util.Notations. -Require Import Crypto.Util.Tactics.BreakMatch. - +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.Tactics.SetEvars. +Require Import Crypto.Util.Tactics.SubstEvars. Local Open Scope Z_scope. -Section montgomery. - Context {T : Type} {length : T -> nat} - {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) - {scmul : Z -> T -> T} (* uses double-output multiply *) - {add : T -> T -> T * Z} (* produces carry *) - {join : T * Z -> T} - {zero : nat -> T} - {to_Z : T -> Z} - (A B : T) - (bound : Z) - (N : T) - (k : Z) (* [(-1 mod N) mod bound] *) - (divmod_div : forall v, to_Z (fst (divmod v)) = to_Z v / bound) - (divmod_mod : forall v, snd (divmod v) = to_Z v mod bound) - (scmul_correct : forall a v, to_Z (scmul a v) = a * to_Z v) - (join_add_correct : forall a b, to_Z (join (add a b)) = to_Z a * to_Z b) - (length_divmod_div : forall v, length (fst (divmod v)) = pred (length v)) - (length_join : forall v, length (join v) = S (length (fst v))) - (length_add : forall a b, length (fst (add a b)) = max (length a) (length b)) - (length_scmul : forall a v, 0 <= a < bound -> length (scmul a v) = S (length v)) - (bound_pos : 0 < bound). - Local Infix "≡" := (Z.equiv_modulo bound). - - Local Notation redc_body := (@redc_body T divmod scmul add join B N k). - Local Notation redc_loop := (@redc_loop T divmod scmul add join B N k). - Local Notation redc := (@redc T length divmod scmul add join zero A B N k). - - Local Ltac start := - unfold redc_body; - repeat match goal with - | [ H : _ * _ |- _ ] => destruct H - | [ |- context[match ?x with pair _ _ => _ end] ] - => rewrite (surjective_pairing x); simpl - end. - - Hint Rewrite divmod_div divmod_mod join_add_correct scmul_correct : rew_db. - Hint Rewrite length_add length_divmod_div length_scmul length_join : rew_db. - Hint Rewrite Max.max_idempotent : rew_db. - - Lemma redc_body_small A_S - : to_Z (snd A_S) < to_Z N + to_Z B - -> to_Z (snd (redc_body A_S)) < to_Z N + to_Z B. - Proof. - start; repeat autorewrite with rew_db. - Admitted. +Section WordByWordMontgomery. + Context + {T : Type} + {eval : T -> Z} + {numlimbs : T -> nat} + {zero : nat -> T} + {divmod : T -> T * Z} (* returns lowest limb and all-but-lowest-limb *) + {r : positive} + {r_big : r > 1} + {eval_zero : forall n, eval (zero n) = 0} + {eval_div : forall v, eval (fst (divmod v)) = eval v / r} + {eval_mod : forall v, snd (divmod v) = eval v mod r} + {scmul : Z -> T -> T} (* uses double-output multiply *) + {eval_scmul: forall a v, eval (scmul a v) = a * eval v} + {R : positive} + {R_big : R > 3} (* needed for [(N + B - 1) / R <= 1] *). + Local Notation bn := (r * R) (only parsing). + Context + {add : T -> T -> T} (* joins carry *) + {eval_add : forall a b, eval (add a b) = eval a + eval b} + {eval_nonneg : forall v, 0 <= eval v} + (N : T) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) + (N_small : eval N < R) + (B : T) + (B_small : eval B < R) + ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)). + Context (k : Z) (k_correct : k * eval N mod r = -1). - Lemma fst_redc_body_length A_S - : length (fst (redc_body A_S)) = pred (length (fst A_S)). - Proof. - start; autorewrite with rew_db; reflexivity. - Qed. - Lemma snd_redc_body_length A_S - : length (snd A_S) = S (max (length B) (length N)) - -> length (snd (redc_body A_S)) = S (max (length B) (length N)). - Proof. - apply Max.max_case_strong; intro Hm; - start; intro H; - repeat first [ progress autorewrite with rew_db - | rewrite H - | reflexivity - | apply Z.mod_pos_bound; assumption - | match goal with - | [ |- context[max ?x ?y] ] - => first [ rewrite (Max.max_l x y) by omega - | rewrite (Max.max_r x y) by omega ] - end ]. - Qed. + Create HintDb push_eval discriminated. + Hint Rewrite + eval_zero + eval_div + eval_mod + eval_add + eval_scmul + : push_eval. + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration. + Context (A S : T) + (S_small : eval S / R <= 1). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + + Local Coercion eval : T >-> Z. + + Local Notation a := (@WordByWord.Definition.a T divmod A). + Local Notation S3 := (@WordByWord.Definition.S3 T divmod r scmul add N B k A S). + Local Notation S1 := (@WordByWord.Definition.S1 T divmod scmul add B A S). + Local Notation cS2 := (@WordByWord.Definition.cS2 T divmod r scmul add N B k A S). + + Lemma S3_bound + : 0 <= eval S < eval N + eval B + -> 0 <= eval S3 < eval N + eval B. + Proof. + assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) + by (intros x y; pose proof (Z_mod_lt x y); omega). + intro HS. + unfold S3, WordByWord.Definition.cS2, WordByWord.Definition.S1. + autorewrite with push_eval. + split; + [ solve + [ autounfold with word_by_word_montgomery; + unfold Let_In; autorewrite with push_eval; + Z.zero_bounds ] + | ]. + eapply Z.le_lt_trans. + { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); + [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. + Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; + autounfold with word_by_word_montgomery; + unfold Let_In; autorewrite with push_eval; + try Z.zero_bounds; + auto with lia. } + rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. + autorewrite with zsimplify. + omega. + Qed. + + Lemma S1_eq : eval S1 = S + a*B. + Proof. + cbv [S1 a WordByWord.Definition.A']. + repeat autorewrite with push_eval. + reflexivity. + Qed. + + Lemma cS2_mod_N : (eval cS2) mod N = (S + a*B) mod N. + Proof. + assert (bn_large : bn >= r) by (unfold bn; nia). + cbv [cS2 WordByWord.Definition.q WordByWord.Definition.s]; autorewrite with push_eval zsimplify. rewrite S1_eq. reflexivity. + Qed. + + Lemma cS2_mod_r : cS2 mod r = 0. + cbv [cS2 WordByWord.Definition.q WordByWord.Definition.s]; autorewrite with push_eval. + assert (r > 0) by lia. + assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). + { destruct (Z.eq_dec r 1) as [H'|H']. + { rewrite H'; split; reflexivity. } + { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } + autorewrite with pull_Zmod. + replace 0 with (0 mod r) by apply Zmod_0_l. + eapply F.eq_of_Z_iff. + repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. + rewrite <-Algebra.Hierarchy.associative. + replace ((F.of_Z r k * F.of_Z r (eval N))%F) with (F.opp (m:=r) F.one). + { cbv [F.of_Z F.add]; simpl. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. + simpl. + rewrite (proj1 Hr), Z.mul_sub_distr_l. + push_Zmod; pull_Zmod. + autorewrite with zsimplify; reflexivity. } + { rewrite <- F.of_Z_mul. + rewrite F.of_Z_mod. + rewrite k_correct. + cbv [F.of_Z F.add F.opp F.one]; simpl. + change (-(1)) with (-1) in *. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. + rewrite (proj1 Hr), (proj2 Hr); reflexivity. } + Qed. + + Lemma S3_mod_N + : S3 mod N = (S + a*B)*ri mod N. + Proof. + assert (r_div_bn : (r | bn)) by apply Z.divide_factor_l. + cbv [S3]; autorewrite with push_eval cancel_pair. + pose proof fun a => Z.div_to_inv_modulo N a r ri eq_refl ri_correct as HH; + cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. + etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| + rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. + rewrite <-cS2_mod_N; repeat (f_equal; []); autorewrite with push_eval. + autorewrite with push_Zmod; + replace (bn mod r) with 0 + by (symmetry; apply Z.mod_divide; try assumption; try lia); + rewrite cS2_mod_r; + autorewrite with zsimplify. + reflexivity. + Qed. + + Lemma small_from_bound + : forall x, 0 <= x < eval N + eval B -> x / R <= 1. + Proof. + clear -R_big N_small B_small. + intros x Hbound. + cut ((N + B - 1) / R <= 1); + [ Z.div_mod_to_quot_rem; subst; nia | ]. + transitivity (((R-1) + (R-1) - 1) / R); + [ Z.peel_le; omega | ]. + autorewrite with zsimplify. + reflexivity. + Qed. + End Iteration. + + Local Notation redc_body := (@redc_body T divmod r scmul add N B k). + Local Notation redc_loop := (@redc_loop T divmod r scmul add N B k). + Local Notation redc A := (@redc T numlimbs zero divmod r scmul add N A B k). Fixpoint redc_loop_rev (count : nat) : T * T -> T * T := match count with @@ -101,7 +186,130 @@ Section montgomery. simpl; intro; rewrite <- IHcount, redc_loop_comm_body; reflexivity. Qed. -(* Print WordByWord.Definition.redc. - Lemma redc_correct i - : *) -End montgomery. + Section body. + Context (A_S : T * T). + Let A':=fst A_S. + Let S:=snd A_S. + Let A_a:=divmod A'. + Let a:=snd A_a. + Context (S_small : eval S / R <= 1). + + Lemma redc_body_mod_N + : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). + Proof. + destruct A_S; apply S3_mod_N; assumption. + Qed. + + Lemma fst_redc_body + : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. + Proof. + destruct A_S; simpl; unfold WordByWord.Definition.A', WordByWord.Definition.A_a, Let_In, a, A_a, A'; simpl. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma fst_redc_body_mod_N + : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). + Proof. + rewrite fst_redc_body. + etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. + unfold a, A_a, A'. + autorewrite with push_eval. + reflexivity. + Qed. + + Lemma redc_body_bound + : 0 <= eval S < eval N + eval B + -> 0 <= eval (snd (redc_body A_S)) < eval N + eval B. + Proof. + destruct A_S; apply S3_bound. + Qed. + End body. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.of_nat !_. + Lemma redc_loop_bound A_S count + : 0 <= eval (snd A_S) < eval N + eval B + -> 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. + rewrite redc_loop__redc_loop_rev. + induction count as [|count IHcount]. + { simpl; trivial. } + { simpl; intro; + repeat first [ apply redc_body_bound + | apply IHcount + | assumption + | apply small_from_bound ]. } + Qed. + + Lemma fst_redc_loop A_S count + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). + Proof. + rewrite redc_loop__redc_loop_rev. + induction count as [|count IHcount]. + { simpl; autorewrite with zsimplify; reflexivity. } + { simpl @redc_loop_rev. + rewrite fst_redc_body, IHcount + by (apply small_from_bound; rewrite <-redc_loop__redc_loop_rev; apply redc_loop_bound; auto). + rewrite Zdiv_Zdiv by Z.zero_bounds. + rewrite <- (Z.pow_1_r r) at 2. + rewrite <- Z.pow_add_r by lia. + replace (Z.of_nat count + 1) with (Z.of_nat (S count)) by (simpl; lia). + reflexivity. } + Qed. + + Lemma fst_redc_loop_mod_N A_S count + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) mod (eval N) = eval (fst A_S) * ri^(Z.of_nat count) mod (eval N). + Proof. + rewrite fst_redc_loop by assumption. + destruct count. + { simpl; autorewrite with zsimplify; reflexivity. } + { etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption | ]. + Focus 2. + erewrite <- Z.pow_mul_l, <- Z.pow_1_l. + Admitted. + + Lemma redc_loop_mod_N A_S count + (S_bound : 0 <= eval (snd A_S) < eval N + eval B) + : (eval (snd (redc_loop count A_S))) mod (eval N) = ((eval (snd A_S) + (eval (fst A_S) mod ri^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). + Proof. + rewrite redc_loop__redc_loop_rev. + induction count as [|count IHcount]. + { simpl; autorewrite with zsimplify; reflexivity. } + { simpl; rewrite redc_body_mod_N + by (apply small_from_bound; rewrite <- redc_loop__redc_loop_rev; apply redc_loop_bound; auto). + push_Zmod; rewrite IHcount; pull_Zmod. + autorewrite with push_eval. + rewrite <- redc_loop__redc_loop_rev, fst_redc_loop by omega. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + destruct A_S as [A S]. + cbn [fst snd]. + change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). + rewrite !Z.mul_add_distr_r. + rewrite <- !Z.mul_assoc. + replace (ri^(Z.of_nat count) * ri) with (ri^(Z.of_nat (Datatypes.S count))) + by (change (Datatypes.S count) with (1 + count)%nat; + autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). + rewrite <- !Z.add_assoc. + apply Z.add_mod_Proper; [ reflexivity | ]. + Unset Printing Coercions. + Coercion eval : T >-> Z. + Coercion Z.of_nat : nat >-> Z. + Notation "x '.+1'" := (Datatypes.S x) (format "x '.+1'", at level 10). + Infix "≡" := (Z.equiv_modulo _) (at level 70). + Admitted. + + Lemma redc_mod_N A + : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). + Proof. + unfold redc. + rewrite redc_loop_mod_N; cbn [fst snd]; + autorewrite with push_eval zsimplify; + [ | rewrite Npos_correct; pose proof (eval_nonneg B); lia ]. + Admitted. +End WordByWordMontgomery. |