diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-15 21:50:44 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-15 21:50:46 -0400 |
commit | 8079cc40a7e523c1c62ee88d949f2e905d7cad73 (patch) | |
tree | 8c01c23e68f487a4f40389938f90a76109698f15 /src/Arithmetic/MontgomeryReduction/WordByWord | |
parent | 06d3a5f4cffdf615f209677f6ffccd3e8b23a03b (diff) |
CPSify montgomery wbw reduction
I didn't want to bother redoing all of the proofs that I'd already done,
so instead I prove the cps'ified version equal to the non-cps version,
and transfer over the proofs that way.
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction/WordByWord')
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v | 54 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v | 91 |
2 files changed, 134 insertions, 11 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index bc95d7091..c2048889d 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -1,12 +1,56 @@ (*** Word-By-Word Montgomery Multiplication *) (** This file implements Montgomery Form, Montgomery Reduction, and - Montgomery Multiplication on an abstract [list ℤ]. *) -Require Import Coq.ZArith.BinInt. + Montgomery Multiplication on an abstract [list ℤ]. See + https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion + of the algorithm; note that it may be that none of the algorithms + there exactly match what we're doing here. *) +Require Import Coq.ZArith.ZArith. Require Import Crypto.Arithmetic.Saturated. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. -Section redc. +Local Open Scope Z_scope. + +Section WordByWordMontgomery. + Local Coercion Z.pos : positive >-> Z. (** TODO: pick better names for the arguments to this definition. *) - Definition redc {r : positive} {R_numlimbs : nat} (N A B : T) (k : Z) : T + Context + {r : positive} + {R_numlimbs : nat} + (N : T). + + Definition redc_body_no_cps (B : T) (k : Z) (A_S : T * T) : T * T + := @redc_body T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k A_S. + Definition redc_loop_no_cps (B : T) (k : Z) (count : nat) (A_S : T * T) : T * T + := @redc_loop T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S. + Definition redc_no_cps (A B : T) (k : Z) : T := @redc T numlimbs zero divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N A B k. -End redc. + + Definition redc_body_cps (A B : T) (k : Z) (S' : T) {cpsT} (rest : T * T -> cpsT) : cpsT + := divmod_cps A (fun '(A, a) => + @scmul_cps r a B _ (fun aB => @add_cps r S' aB _ (fun S1 => + divmod_cps S1 (fun '(_, s) => + dlet q := s * k mod r in + @scmul_cps r q N _ (fun qN => @add_cps r S1 qN _ (fun S2 => + divmod_cps S2 (fun '(S3, _) => + @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))). + + Section loop. + Context (A B : T) (k : Z) {cpsT : Type}. + Fixpoint redc_loop_cps (count : nat) (rest : T * T -> cpsT) : T * T -> cpsT + := match count with + | O => rest + | S count' => fun '(A, S') => redc_body_cps A B k S' (redc_loop_cps count' rest) + end. + + Definition redc_cps (rest : T -> cpsT) : cpsT + := redc_loop_cps (numlimbs A) (fun '(A, S') => rest S') (A, zero (1 + numlimbs B)). + End loop. + + Definition redc_body (A B : T) (k : Z) (S' : T) : T * T := redc_body_cps A B k S' id. + Definition redc_loop (B : T) (k : Z) (count : nat) : T * T -> T * T := redc_loop_cps B k count id. + Definition redc (A B : T) (k : Z) : T := redc_cps A B k id. +End WordByWordMontgomery. + +Hint Opaque redc redc_body redc_loop : uncps. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index 324b68511..96646c61c 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -6,6 +6,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definit Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Proofs. Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition. Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. Local Open Scope Z_scope. Local Coercion Z.pos : positive >-> Z. @@ -70,20 +71,98 @@ Section WordByWordMontgomery. rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity. Qed. + Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N). + Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N). + Local Notation redc_body := (@redc_body r R_numlimbs N). + Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k). + Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k). + Local Notation redc_loop := (@redc_loop r R_numlimbs N B k). + Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A B k). + Local Notation redc_cps := (@redc_cps r R_numlimbs N A B k). Local Notation redc := (@redc r R_numlimbs N A B k). - Definition redc_bound : 0 <= eval redc < eval N + eval B + Definition redc_no_cps_bound : 0 <= eval redc_no_cps < eval N + eval B := @redc_bound T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A. - Definition numlimbs_redc_gen - : numlimbs redc + Definition numlimbs_redc_no_cps_gen + : numlimbs redc_no_cps = match numlimbs A with | O => S (numlimbs B) | _ => S R_numlimbs end := @numlimbs_redc_gen T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_le. - Definition numlimbs_redc : numlimbs redc = S (numlimbs B) + Definition numlimbs_redc_no_cps : numlimbs redc_no_cps = S (numlimbs B) := @numlimbs_redc T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_eq. - Definition redc_mod_N - : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N) + Definition redc_no_cps_mod_N + : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N) := @redc_mod_N T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A small_A A_bound. + + Lemma redc_body_cps_id (A' S' : T) {cpsT} f + : @redc_body_cps A' B k S' cpsT f = f (redc_body A' B k S'). + Proof. + unfold redc_body, redc_body_cps, LetIn.Let_In. + repeat first [ reflexivity + | break_innermost_match_step + | progress autorewrite with uncps ]. + Qed. + + Lemma redc_loop_cps_id (count : nat) (A_S : T * T) {cpsT} f + : @redc_loop_cps cpsT count f A_S = f (redc_loop count A_S). + Proof. + unfold redc_loop. + revert A_S f. + induction count as [|count IHcount]. + { reflexivity. } + { intros [A' S']; simpl; intros. + etransitivity; rewrite @redc_body_cps_id; [ rewrite IHcount | ]; reflexivity. } + Qed. + Lemma redc_cps_id {cpsT} f : @redc_cps cpsT f = f redc. + Proof. + unfold redc, redc_cps. + etransitivity; rewrite redc_loop_cps_id; [ | reflexivity ]; break_innermost_match; reflexivity. + Qed. + + Lemma redc_body_id_no_cps A' S' + : redc_body A' B k S' = redc_body_no_cps B k (A', S'). + Proof. + unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Definition.redc_body, LetIn.Let_In, id. + repeat autounfold with word_by_word_montgomery. + repeat first [ reflexivity + | progress cbn [fst snd id] + | progress autorewrite with uncps + | break_innermost_match_step + | f_equal; [] ]. + Qed. + Lemma redc_loop_cps_id_no_cps count A_S + : redc_loop count A_S = redc_loop_no_cps count A_S. + Proof. + unfold redc_loop_no_cps, id. + revert A_S. + induction count as [|count IHcount]; simpl; [ reflexivity | ]. + intros [A' S']; unfold redc_loop; simpl. + rewrite redc_body_cps_id, redc_loop_cps_id, IHcount, redc_body_id_no_cps. + reflexivity. + Qed. + Lemma redc_cps_id_no_cps : redc = redc_no_cps. + Proof. + unfold redc, redc_no_cps, redc_cps, Abstract.Definition.redc. + rewrite redc_loop_cps_id, (surjective_pairing (redc_loop _ _)). + rewrite redc_loop_cps_id_no_cps; reflexivity. + Qed. + + Lemma redc_bound : 0 <= eval redc < eval N + eval B. + Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound. Qed. + Lemma numlimbs_redc_gen + : numlimbs redc + = match numlimbs A with + | O => S (numlimbs B) + | _ => S R_numlimbs + end. + Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps_gen. Qed. + Lemma numlimbs_redc : numlimbs redc = S (numlimbs B). + Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps. Qed. + Lemma redc_mod_N + : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N). + Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_mod_N. Qed. End WordByWordMontgomery. + +Hint Rewrite redc_body_cps_id redc_loop_cps_id redc_cps_id : uncps. |