aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-06-15 21:50:44 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-06-15 21:50:46 -0400
commit8079cc40a7e523c1c62ee88d949f2e905d7cad73 (patch)
tree8c01c23e68f487a4f40389938f90a76109698f15 /src
parent06d3a5f4cffdf615f209677f6ffccd3e8b23a03b (diff)
CPSify montgomery wbw reduction
I didn't want to bother redoing all of the proofs that I'd already done, so instead I prove the cps'ified version equal to the non-cps version, and transfer over the proofs that way.
Diffstat (limited to 'src')
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v54
-rw-r--r--src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v91
-rw-r--r--src/Arithmetic/Saturated.v2
3 files changed, 135 insertions, 12 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
index bc95d7091..c2048889d 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v
@@ -1,12 +1,56 @@
(*** Word-By-Word Montgomery Multiplication *)
(** This file implements Montgomery Form, Montgomery Reduction, and
- Montgomery Multiplication on an abstract [list ℤ]. *)
-Require Import Coq.ZArith.BinInt.
+ Montgomery Multiplication on an abstract [list ℤ]. See
+ https://github.com/mit-plv/fiat-crypto/issues/157 for a discussion
+ of the algorithm; note that it may be that none of the algorithms
+ there exactly match what we're doing here. *)
+Require Import Coq.ZArith.ZArith.
Require Import Crypto.Arithmetic.Saturated.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definition.
+Require Import Crypto.Util.Notations.
+Require Import Crypto.Util.LetIn.
-Section redc.
+Local Open Scope Z_scope.
+
+Section WordByWordMontgomery.
+ Local Coercion Z.pos : positive >-> Z.
(** TODO: pick better names for the arguments to this definition. *)
- Definition redc {r : positive} {R_numlimbs : nat} (N A B : T) (k : Z) : T
+ Context
+ {r : positive}
+ {R_numlimbs : nat}
+ (N : T).
+
+ Definition redc_body_no_cps (B : T) (k : Z) (A_S : T * T) : T * T
+ := @redc_body T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k A_S.
+ Definition redc_loop_no_cps (B : T) (k : Z) (count : nat) (A_S : T * T) : T * T
+ := @redc_loop T divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S.
+ Definition redc_no_cps (A B : T) (k : Z) : T
:= @redc T numlimbs zero divmod r (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N A B k.
-End redc.
+
+ Definition redc_body_cps (A B : T) (k : Z) (S' : T) {cpsT} (rest : T * T -> cpsT) : cpsT
+ := divmod_cps A (fun '(A, a) =>
+ @scmul_cps r a B _ (fun aB => @add_cps r S' aB _ (fun S1 =>
+ divmod_cps S1 (fun '(_, s) =>
+ dlet q := s * k mod r in
+ @scmul_cps r q N _ (fun qN => @add_cps r S1 qN _ (fun S2 =>
+ divmod_cps S2 (fun '(S3, _) =>
+ @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))).
+
+ Section loop.
+ Context (A B : T) (k : Z) {cpsT : Type}.
+ Fixpoint redc_loop_cps (count : nat) (rest : T * T -> cpsT) : T * T -> cpsT
+ := match count with
+ | O => rest
+ | S count' => fun '(A, S') => redc_body_cps A B k S' (redc_loop_cps count' rest)
+ end.
+
+ Definition redc_cps (rest : T -> cpsT) : cpsT
+ := redc_loop_cps (numlimbs A) (fun '(A, S') => rest S') (A, zero (1 + numlimbs B)).
+ End loop.
+
+ Definition redc_body (A B : T) (k : Z) (S' : T) : T * T := redc_body_cps A B k S' id.
+ Definition redc_loop (B : T) (k : Z) (count : nat) : T * T -> T * T := redc_loop_cps B k count id.
+ Definition redc (A B : T) (k : Z) : T := redc_cps A B k id.
+End WordByWordMontgomery.
+
+Hint Opaque redc redc_body redc_loop : uncps.
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
index 324b68511..96646c61c 100644
--- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
+++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v
@@ -6,6 +6,7 @@ Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Definit
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Abstract.Proofs.
Require Import Crypto.Arithmetic.MontgomeryReduction.WordByWord.Definition.
Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
Local Open Scope Z_scope.
Local Coercion Z.pos : positive >-> Z.
@@ -70,20 +71,98 @@ Section WordByWordMontgomery.
rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity.
Qed.
+ Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N).
+ Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N).
+ Local Notation redc_body := (@redc_body r R_numlimbs N).
+ Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k).
+ Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k).
+ Local Notation redc_loop := (@redc_loop r R_numlimbs N B k).
+ Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A B k).
+ Local Notation redc_cps := (@redc_cps r R_numlimbs N A B k).
Local Notation redc := (@redc r R_numlimbs N A B k).
- Definition redc_bound : 0 <= eval redc < eval N + eval B
+ Definition redc_no_cps_bound : 0 <= eval redc_no_cps < eval N + eval B
:= @redc_bound T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A.
- Definition numlimbs_redc_gen
- : numlimbs redc
+ Definition numlimbs_redc_no_cps_gen
+ : numlimbs redc_no_cps
= match numlimbs A with
| O => S (numlimbs B)
| _ => S R_numlimbs
end
:= @numlimbs_redc_gen T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_le.
- Definition numlimbs_redc : numlimbs redc = S (numlimbs B)
+ Definition numlimbs_redc_no_cps : numlimbs redc_no_cps = S (numlimbs B)
:= @numlimbs_redc T eval numlimbs zero divmod r r_big small eval_zero numlimbs_zero eval_div eval_mod small_div numlimbs_div scmul eval_scmul numlimbs_scmul R R_numlimbs R_correct add eval_add small_add numlimbs_add drop_high eval_drop_high numlimbs_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A small_A Hnumlimbs_eq.
- Definition redc_mod_N
- : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N)
+ Definition redc_no_cps_mod_N
+ : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N)
:= @redc_mod_N T eval numlimbs zero divmod r r_big small eval_zero eval_div eval_mod small_div scmul eval_scmul R R_numlimbs R_correct add eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A small_A A_bound.
+
+ Lemma redc_body_cps_id (A' S' : T) {cpsT} f
+ : @redc_body_cps A' B k S' cpsT f = f (redc_body A' B k S').
+ Proof.
+ unfold redc_body, redc_body_cps, LetIn.Let_In.
+ repeat first [ reflexivity
+ | break_innermost_match_step
+ | progress autorewrite with uncps ].
+ Qed.
+
+ Lemma redc_loop_cps_id (count : nat) (A_S : T * T) {cpsT} f
+ : @redc_loop_cps cpsT count f A_S = f (redc_loop count A_S).
+ Proof.
+ unfold redc_loop.
+ revert A_S f.
+ induction count as [|count IHcount].
+ { reflexivity. }
+ { intros [A' S']; simpl; intros.
+ etransitivity; rewrite @redc_body_cps_id; [ rewrite IHcount | ]; reflexivity. }
+ Qed.
+ Lemma redc_cps_id {cpsT} f : @redc_cps cpsT f = f redc.
+ Proof.
+ unfold redc, redc_cps.
+ etransitivity; rewrite redc_loop_cps_id; [ | reflexivity ]; break_innermost_match; reflexivity.
+ Qed.
+
+ Lemma redc_body_id_no_cps A' S'
+ : redc_body A' B k S' = redc_body_no_cps B k (A', S').
+ Proof.
+ unfold redc_body, redc_body_cps, redc_body_no_cps, Abstract.Definition.redc_body, LetIn.Let_In, id.
+ repeat autounfold with word_by_word_montgomery.
+ repeat first [ reflexivity
+ | progress cbn [fst snd id]
+ | progress autorewrite with uncps
+ | break_innermost_match_step
+ | f_equal; [] ].
+ Qed.
+ Lemma redc_loop_cps_id_no_cps count A_S
+ : redc_loop count A_S = redc_loop_no_cps count A_S.
+ Proof.
+ unfold redc_loop_no_cps, id.
+ revert A_S.
+ induction count as [|count IHcount]; simpl; [ reflexivity | ].
+ intros [A' S']; unfold redc_loop; simpl.
+ rewrite redc_body_cps_id, redc_loop_cps_id, IHcount, redc_body_id_no_cps.
+ reflexivity.
+ Qed.
+ Lemma redc_cps_id_no_cps : redc = redc_no_cps.
+ Proof.
+ unfold redc, redc_no_cps, redc_cps, Abstract.Definition.redc.
+ rewrite redc_loop_cps_id, (surjective_pairing (redc_loop _ _)).
+ rewrite redc_loop_cps_id_no_cps; reflexivity.
+ Qed.
+
+ Lemma redc_bound : 0 <= eval redc < eval N + eval B.
+ Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_bound. Qed.
+ Lemma numlimbs_redc_gen
+ : numlimbs redc
+ = match numlimbs A with
+ | O => S (numlimbs B)
+ | _ => S R_numlimbs
+ end.
+ Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps_gen. Qed.
+ Lemma numlimbs_redc : numlimbs redc = S (numlimbs B).
+ Proof. rewrite redc_cps_id_no_cps; apply numlimbs_redc_no_cps. Qed.
+ Lemma redc_mod_N
+ : (eval redc) mod (eval N) = (eval A * eval B * ri^(Z.of_nat (numlimbs A))) mod (eval N).
+ Proof. rewrite redc_cps_id_no_cps; apply redc_no_cps_mod_N. Qed.
End WordByWordMontgomery.
+
+Hint Rewrite redc_body_cps_id redc_loop_cps_id redc_cps_id : uncps.
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
index 3e33a136a..f7f623198 100644
--- a/src/Arithmetic/Saturated.v
+++ b/src/Arithmetic/Saturated.v
@@ -899,7 +899,7 @@ Section API.
End Proofs.
End API.
-
+Hint Rewrite divmod_id drop_high_id scmul_id add_id : uncps.
(*
(* Just some pretty-printing *)