diff options
author | Jason Gross <jgross@mit.edu> | 2017-06-17 18:09:01 -0400 |
---|---|---|
committer | Jason Gross <jgross@mit.edu> | 2017-06-17 18:09:01 -0400 |
commit | 755f09db00c197f7cbbe295ace9bfd97be41c0fb (patch) | |
tree | 8f1b3d7ac5d4a9cb61598a9d0027811ede89f6b8 /src/Arithmetic/MontgomeryReduction | |
parent | 75bca9a1c48254dcc058fcab78a239711bd80f35 (diff) |
Make use of non-uniform tuple-based add
Maybe it'll result in better output code with fewer zeros?
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
4 files changed, 42 insertions, 38 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v index 11b73b6a8..83fc63777 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Definition.v @@ -23,8 +23,9 @@ Section WordByWordMontgomery. {R_numlimbs : nat} {scmul : forall {n}, Z -> T n -> T (S n)} (* uses double-output multiply *) {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *) + {add' : forall {n}, T (S n) -> T n -> T (S (S n))} (* joins carry *) {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) - (N : T (S R_numlimbs)). + (N : T R_numlimbs). (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) Section Iteration. @@ -37,7 +38,7 @@ Section WordByWordMontgomery. Local Definition S1 := add _ S (scmul _ a B). Local Definition s := snd (divmod _ S1). Local Definition q := fst (Z.mul_split r s k). - Local Definition S2 := add _ S1 (scmul _ q N). + Local Definition S2 := add' _ S1 (scmul _ q N). Local Definition S3 := fst (divmod _ S2). Local Definition S4 := drop_high S3. End Iteration. diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v index 253b56b40..8ce8aaf77 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Abstract/Dependent/Proofs.v @@ -36,9 +36,12 @@ Section WordByWordMontgomery. {add : forall {n}, T n -> T n -> T (S n)} (* joins carry *) {eval_add : forall n a b, eval (@add n a b) = eval a + eval b} {small_add : forall n a b, small (@add n a b)} + {add' : forall {n}, T (S n) -> T n -> T (S (S n))} (* joins carry *) + {eval_add' : forall n a b, eval (@add' n a b) = eval a + eval b} + {small_add' : forall n a b, small (@add' n a b)} {drop_high : T (S (S R_numlimbs)) -> T (S R_numlimbs)} (* drops the highest limb *) {eval_drop_high : forall v, small v -> eval (drop_high v) = eval v mod (r * r^Z.of_nat R_numlimbs)} - (N : T (S R_numlimbs)) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) + (N : T R_numlimbs) (Npos : positive) (Npos_correct: eval N = Z.pos Npos) (N_lt_R : eval N < R) (B : T R_numlimbs) (B_bounds : 0 <= eval B < R) @@ -49,6 +52,7 @@ Section WordByWordMontgomery. Local Ltac t_small := repeat first [ assumption | apply small_add + | apply small_add' | apply small_div | apply Z_mod_lt | solve [ auto ] @@ -59,6 +63,7 @@ Section WordByWordMontgomery. eval_div eval_mod eval_add + eval_add' eval_scmul eval_drop_high using (repeat autounfold with word_by_word_montgomery; t_small) @@ -84,9 +89,9 @@ Section WordByWordMontgomery. Local Notation S1 := (@WordByWord.Abstract.Dependent.Definition.S1 T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S). Local Notation s := (@WordByWord.Abstract.Dependent.Definition.s T (@divmod) R_numlimbs scmul add pred_A_numlimbs B A S). Local Notation q := (@WordByWord.Abstract.Dependent.Definition.q T (@divmod) r R_numlimbs scmul add pred_A_numlimbs B k A S). - Local Notation S2 := (@WordByWord.Abstract.Dependent.Definition.S2 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S). - Local Notation S3 := (@WordByWord.Abstract.Dependent.Definition.S3 T (@divmod) r R_numlimbs scmul add N pred_A_numlimbs B k A S). - Local Notation S4 := (@WordByWord.Abstract.Dependent.Definition.S4 T (@divmod) r R_numlimbs scmul add drop_high N pred_A_numlimbs B k A S). + Local Notation S2 := (@WordByWord.Abstract.Dependent.Definition.S2 T (@divmod) r R_numlimbs scmul add add' N pred_A_numlimbs B k A S). + Local Notation S3 := (@WordByWord.Abstract.Dependent.Definition.S3 T (@divmod) r R_numlimbs scmul add add' N pred_A_numlimbs B k A S). + Local Notation S4 := (@WordByWord.Abstract.Dependent.Definition.S4 T (@divmod) r R_numlimbs scmul add add' drop_high N pred_A_numlimbs B k A S). Lemma S3_bound : eval S < eval N + eval B @@ -209,9 +214,9 @@ Section WordByWordMontgomery. Qed. End Iteration. - Local Notation redc_body := (@redc_body T (@divmod) r R_numlimbs scmul add drop_high N B k). - Local Notation redc_loop := (@redc_loop T (@divmod) r R_numlimbs scmul add drop_high N B k). - Local Notation redc A := (@redc T zero (@divmod) r R_numlimbs scmul add drop_high N _ A B k). + Local Notation redc_body := (@redc_body T (@divmod) r R_numlimbs scmul add add' drop_high N B k). + Local Notation redc_loop := (@redc_loop T (@divmod) r R_numlimbs scmul add add' drop_high N B k). + Local Notation redc A := (@redc T zero (@divmod) r R_numlimbs scmul add add' drop_high N _ A B k). (*Lemma redc_loop_comm_body count : forall A_S, redc_loop count (redc_body A_S) = redc_body (redc_loop count A_S). diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v index 0b4666856..2785eb37c 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Definition.v @@ -19,32 +19,29 @@ Section WordByWordMontgomery. Context {r : positive} {R_numlimbs : nat} - (N' : T R_numlimbs). - (** TODO(andreser): Add a comment here about why we take in [N : T - R_numlimbs]; we need [N : T (S R_numlimbs)] so that the limb - arithmetic works out exactly (so that we can add [q * N] of - length [S (S R_numlimbs)] to [S1] of length [S (S R_numlimbs)] - and then do [divmod] to get something of length [S - R_numlimbs]. *) - Local Notation N := (join0 N'). + (N : T R_numlimbs). + + Local Notation scmul := (@scmul (Z.pos r)). + Local Notation add' := (fun n => @add (Z.pos r) (S n) n (S n)). + Local Notation add := (fun n => @add (Z.pos r) n n n). Definition redc_body_no_cps (B : T R_numlimbs) (k : Z) {pred_A_numlimbs} (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)) : T pred_A_numlimbs * T (S R_numlimbs) - := @redc_body T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k _ A_S. + := @redc_body T (@divmod) r R_numlimbs (@scmul) add add' (@drop_high (S R_numlimbs)) N B k _ A_S. Definition redc_loop_no_cps (B : T R_numlimbs) (k : Z) (count : nat) (A_S : T count * T (S R_numlimbs)) : T 0 * T (S R_numlimbs) - := @redc_loop T (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N B k count A_S. + := @redc_loop T (@divmod) r R_numlimbs (@scmul) add add' (@drop_high (S R_numlimbs)) N B k count A_S. Definition redc_no_cps {A_numlimbs} (A : T A_numlimbs) (B : T R_numlimbs) (k : Z) : T (S R_numlimbs) - := @redc T (@zero) (@divmod) r R_numlimbs (@scmul (Z.pos r)) (@add (Z.pos r)) (@drop_high (S R_numlimbs)) N _ A B k. + := @redc T (@zero) (@divmod) r R_numlimbs (@scmul) add add' (@drop_high (S R_numlimbs)) N _ A B k. Definition redc_body_cps {pred_A_numlimbs} (A : T (S pred_A_numlimbs)) (B : T R_numlimbs) (k : Z) (S' : T (S R_numlimbs)) {cpsT} (rest : T pred_A_numlimbs * T (S R_numlimbs) -> cpsT) : cpsT := divmod_cps A (fun '(A, a) => - @scmul_cps r _ a B _ (fun aB => @add_cps r _ S' aB _ (fun S1 => + @scmul_cps r _ a B _ (fun aB => @add_cps r _ _ (S R_numlimbs) S' aB _ (fun S1 => divmod_cps S1 (fun '(_, s) => dlet q := fst (Z.mul_split r s k) in - @scmul_cps r _ q N _ (fun qN => @add_cps r _ S1 qN _ (fun S2 => + @scmul_cps r _ q N _ (fun qN => @add_cps r _ _ _ S1 qN _ (fun S2 => divmod_cps S2 (fun '(S3, _) => @drop_high_cps (S R_numlimbs) S3 _ (fun S4 => rest (A, S4))))))))). diff --git a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v index 6dfd6a10a..3710a3dd6 100644 --- a/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v +++ b/src/Arithmetic/MontgomeryReduction/WordByWord/Proofs.v @@ -16,7 +16,8 @@ Section WordByWordMontgomery. (R_numlimbs : nat). Local Notation small := (@small (Z.pos r)). Local Notation eval := (@eval (Z.pos r)). - Local Notation add := (@add (Z.pos r)). + Local Notation add' := (fun n => @add (Z.pos r) (S n) n (S n)). + Local Notation add := (fun n => @add (Z.pos r) n n n). Local Notation scmul := (@scmul (Z.pos r)). Local Notation eval_zero := (@eval_zero (Z.pos r)). Local Notation eval_join0 := (@eval_zero (Z.pos r) (Zorder.Zgt_pos_0 _)). @@ -24,15 +25,15 @@ Section WordByWordMontgomery. Local Notation eval_mod := (@eval_mod (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation small_div := (@small_div (Z.pos r) (Zorder.Zgt_pos_0 _)). Local Notation eval_scmul := (@eval_scmul (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation eval_add := (@eval_add (Z.pos r) (Zorder.Zgt_pos_0 _)). - Local Notation small_add := (@small_add (Z.pos r) (Zorder.Zgt_pos_0 _)). + Local Notation eval_add := (@eval_add_same (Z.pos r) (Zorder.Zgt_pos_0 _)). + Local Notation eval_add' := (@eval_add_S1 (Z.pos r) (Zorder.Zgt_pos_0 _)). + Local Notation small_add := (fun n => @small_add (Z.pos r) (Zorder.Zgt_pos_0 _) _ _ _). Local Notation drop_high := (@drop_high (S R_numlimbs)). Context (A_numlimbs : nat) - (N' : T R_numlimbs) + (N : T R_numlimbs) (A : T A_numlimbs) (B : T R_numlimbs) (k : Z). - Local Notation N := (join0 N'). Context ri (r_big : r > 1) (small_A : small A) @@ -70,21 +71,21 @@ Section WordByWordMontgomery. rewrite Znat.Nat2Z.inj_succ, Z.pow_succ_r by lia; reflexivity. Qed. - Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N'). - Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N'). - Local Notation redc_body := (@redc_body r R_numlimbs N'). - Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N' B k). - Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N' B k). - Local Notation redc_loop := (@redc_loop r R_numlimbs N' B k). - Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N' A_numlimbs A B k). - Local Notation redc_cps := (@redc_cps r R_numlimbs N' A_numlimbs A B k). - Local Notation redc := (@redc r R_numlimbs N' A_numlimbs A B k). + Local Notation redc_body_no_cps := (@redc_body_no_cps r R_numlimbs N). + Local Notation redc_body_cps := (@redc_body_cps r R_numlimbs N). + Local Notation redc_body := (@redc_body r R_numlimbs N). + Local Notation redc_loop_no_cps := (@redc_loop_no_cps r R_numlimbs N B k). + Local Notation redc_loop_cps := (@redc_loop_cps r R_numlimbs N B k). + Local Notation redc_loop := (@redc_loop r R_numlimbs N B k). + Local Notation redc_no_cps := (@redc_no_cps r R_numlimbs N A_numlimbs A B k). + Local Notation redc_cps := (@redc_cps r R_numlimbs N A_numlimbs A B k). + Local Notation redc := (@redc r R_numlimbs N A_numlimbs A B k). Definition redc_no_cps_bound : 0 <= eval redc_no_cps < eval N + eval B - := @redc_bound T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A_numlimbs A small_A. + := @redc_bound T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add (@add') eval_add' small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri k A_numlimbs A small_A. Definition redc_no_cps_mod_N : (eval redc_no_cps) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N) - := @redc_mod_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A_numlimbs A small_A A_bound. + := @redc_mod_N T (@eval) (@zero) (@divmod) r r_big R R_numlimbs R_correct (@small) eval_zero eval_div eval_mod small_div (@scmul) eval_scmul (@add) eval_add small_add (@add') eval_add' small_add drop_high eval_drop_high N Npos Npos_correct N_lt_R B B_bound ri ri_correct k k_correct A_numlimbs A small_A A_bound. Lemma redc_body_cps_id pred_A_numlimbs (A' : T (S pred_A_numlimbs)) (S' : T (S R_numlimbs)) {cpsT} f : @redc_body_cps pred_A_numlimbs A' B k S' cpsT f = f (redc_body A' B k S'). |