diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Arithmetic.v')
-rw-r--r-- | src/Experiments/NewPipeline/Arithmetic.v | 1327 |
1 files changed, 1284 insertions, 43 deletions
diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 398fd17bc..4bf3c55e6 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -10,8 +10,8 @@ Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. Require Import Crypto.Arithmetic.BarrettReduction.Generalized. -Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. -Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.Tactics.RunTacticAsConstr. Require Import Crypto.Util.Tactics.Head. @@ -20,6 +20,7 @@ Require Import Crypto.Util.OptionList. Require Import Crypto.Util.Prod. Require Import Crypto.Util.Sum. Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Sigma. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. Require Import Crypto.Util.ZUtil.Hints.PullPush. @@ -43,6 +44,7 @@ Require Import Crypto.Util.ZUtil.EquivModulo. Require Import Crypto.Util.Prod. Require Import Crypto.Util.CPSNotations. Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.SetEvars. Import ListNotations. Local Open Scope Z_scope. Module Associational. @@ -499,6 +501,10 @@ Module Positional. Section Positional. autorewrite with push_eval cancel_pair; ring. Qed. + Lemma eval_snoc_S n x y : n = length x -> eval (S n) (x ++ [y]) = eval n x + weight n * y. + Proof. intros; erewrite eval_snoc; eauto. Qed. + Hint Rewrite eval_snoc_S : push_eval. + (* SKIP over this: zeros, add_to_nth *) Local Ltac push := autorewrite with push_eval push_map distr_length push_flat_map push_fold_right push_nth_default cancel_pair natsimplify. @@ -534,6 +540,9 @@ Module Positional. Section Positional. end; lia. Qed. Hint Rewrite @eval_add_to_nth eval_zeros eval_combine_zeros : push_eval. + Lemma zeros_ext_map {A} n (p : list A) : length p = n -> zeros n = map (fun _ => 0) p. + Proof. cbv [zeros]; intro; subst; induction p; cbn; congruence. Qed. + Definition place (t:Z*Z) (i:nat) : nat * Z := nat_rect (fun _ => unit -> (nat * Z)%type) @@ -584,11 +593,52 @@ Module Positional. Section Positional. push; omega. Qed. Hint Rewrite eval_extend_to_length : push_eval. - Lemma length_eval_extend_to_length n_in n_out p : + Lemma length_extend_to_length n_in n_out p : length p = n_in -> (n_in <= n_out)%nat -> length (extend_to_length n_in n_out p) = n_out. Proof. cbv [extend_to_length]; intros; distr_length. Qed. - Hint Rewrite length_eval_extend_to_length : distr_length. + Hint Rewrite length_extend_to_length : distr_length. + + Definition drop_high_to_length (n : nat) (p:list Z) : list Z := + firstn n p. + (* + Lemma eval_drop_high_to_length n m p : + (forall i, weight (S i) mod weight i = 0) -> length p = m -> (n <= m)%nat -> + eval n (drop_high_to_length n p) mod weight n + = eval m p mod weight n. + Proof. + cbv [eval drop_high_to_length to_associational]; intros. + replace m with (n + (m - n))%nat in * by (f_equal; omega). + generalize dependent (m - n)%nat; clear m; intro m; intros H' H''. + rewrite seq_add, map_app, <- (firstn_skipn n p), combine_app_samelength, firstn_skipn, Associational.eval_app; + push; try omega **. + rewrite <- (Z.add_0_r (Associational.eval _)) at 1. + apply Z.add_mod_Proper; [ reflexivity | cbv [Z.equiv_modulo] ]. + generalize (skipn_length n p); rewrite H', minus_plus. + generalize (skipn n p); clear dependent p; clear H''; intros p Hp. + rewrite Zmod_0_l. + subst. + cbv [Associational.eval]. + revert n; induction p as [|p ps IHps]; intro; [ reflexivity | ]. + cbn in *. + push_Zmod; pull_Zmod; autorewrite with zsimplify_const. + rewrite <- IHps. + { cbn; reflexivity + Search (0 mod _). + rewrite Z.mod_0_l + Search (?x + ?y - ?x)%nat. + Search Z.equiv_modulo Proper. + pose proof H as H''. + rewrite <- (firstn_skipn n p) in H''. + distr_length. + + rewrite Nat.min_l in H'' by omega. + Qed. + Hint Rewrite eval_drop_high_to_length : push_eval.*) + Lemma length_drop_high_to_length n p : + length (drop_high_to_length n p) = Nat.min n (length p). + Proof. cbv [drop_high_to_length]; intros; distr_length. Qed. + Hint Rewrite length_drop_high_to_length : distr_length. Section mulmod. Context (s:Z) (s_nz:s <> 0) @@ -646,6 +696,7 @@ Module Positional. Section Positional. Lemma length_carry n m index p : length (carry n m index p) = m. Proof. cbv [carry]; distr_length. Qed. + Hint Rewrite length_carry : distr_length. Lemma eval_carry n m i p: (n <> 0%nat) -> (m <> 0%nat) -> weight (S i) / weight i <> 0 -> eval m (carry n m i p) = eval n p. @@ -712,6 +763,12 @@ Module Positional. Section Positional. apply fold_right_invariant; [|intro; rewrite <-in_rev]; intros; push; auto. Qed. Hint Rewrite @eval_chained_carries_no_reduce : push_eval. + Lemma length_chained_carries_no_reduce n p idxs + : length p = n -> length (@chained_carries_no_reduce n p idxs) = n. + Proof. + intros; cbv [chained_carries_no_reduce]; induction (rev idxs) as [|x xs IHxs]; + cbn [fold_right]; distr_length. + Qed. Hint Rewrite @length_chained_carries_no_reduce : distr_length. (* Reverse of [eval]; translate from Z to basesystem by putting everything in first digit and then carrying. *) @@ -727,9 +784,22 @@ Module Positional. Section Positional. : length (encode n s c x) = n. Proof. cbv [encode]; repeat distr_length. Qed. + (* Reverse of [eval]; translate from Z to basesystem by putting + everything in first digit and then carrying, but without reduction. *) + Definition encode_no_reduce n (x : Z) : list Z := + chained_carries_no_reduce n (from_associational n [(1,x)]) (seq 0 n). + Lemma eval_encode_no_reduce n x : + (n <> 0%nat) -> + (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> + eval n (encode_no_reduce n x) = x. + Proof using Type*. cbv [encode_no_reduce]; intros; push; auto; f_equal; omega. Qed. + Lemma length_encode_no_reduce n x + : length (encode_no_reduce n x) = n. + Proof. cbv [encode_no_reduce]; repeat distr_length. Qed. + End Carries. - Hint Rewrite @eval_encode : push_eval. - Hint Rewrite @length_encode : distr_length. + Hint Rewrite @eval_encode @eval_encode_no_reduce @eval_carry @eval_carry_reduce @eval_chained_carries @eval_chained_carries_no_reduce : push_eval. + Hint Rewrite @length_encode @length_encode_no_reduce @length_carry @length_carry_reduce @length_chained_carries @length_chained_carries_no_reduce : distr_length. Section sub. Context (n:nat) @@ -840,8 +910,8 @@ Module Positional. Section Positional. End select. End Positional. (* Hint Rewrite disappears after the end of a section *) -Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min : distr_length. -Hint Rewrite @eval_select @eval_zselect : push_eval. +Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length. +Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length (*@eval_drop_high_to_length*) : push_eval. Section Positional_nonuniform. Context (weight weight' : nat -> Z). @@ -2099,7 +2169,15 @@ Module Rows. (* Subtract q if and only if p >= q. *) Definition conditional_sub n (p q:list Z) := let '(v, c) := sub n p q in - Positional.select c v p. + Positional.select (-c) v p. + + (* the carry will be 0 unless we underflow--we do the addition only + in the underflow case *) + Definition sub_then_maybe_add n mask (p q r:list Z) := + let '(p_minus_q, c) := sub n p q in + let rr := Positional.zselect mask (-c) r in + let '(res, c') := add n p_minus_q rr in + (res, c' - c). Hint Rewrite eval_cons eval_nil using solve [auto] : push_eval. @@ -2189,6 +2267,11 @@ Module Rows. fst (mul base n m p q) = partition m (Positional.eval weight n p * Positional.eval weight n q). Proof using wprops. solver. Qed. + Lemma mul_div base n m p q : + base <> 0 -> n <> 0%nat -> m <> 0%nat -> length p = n -> length q = n -> + snd (mul base n m p q) = (Positional.eval weight n p * Positional.eval weight n q) / weight m. + Proof using wprops. solver. Qed. + Lemma length_mul base n m p q : length p = n -> length q = n -> length (fst (mul base n m p q)) = m. @@ -2227,11 +2310,53 @@ Module Rows. rewrite <-Z.div_mod'' by auto. autorewrite with push_eval; reflexivity. Qed. + + (* returns all-but-lowest-limb and lowest limb *) + Definition divmod (p : list Z) : list Z * Z + := (tl p, hd 0 p). + (* + Lemma eval_divmod n (p : list Z) : + length p = S n -> + (forall i, weight i = weight 1 ^ Z.of_nat i) -> + (forall i, (i <= n)%nat -> + nth_default 0 p i = (Positional.eval weight (S n) p mod weight (S i)) / (weight i)) -> + let pv := Positional.eval weight (S n) p in + Positional.eval (fun i => weight (S i) / weight 1) n (fst (divmod p)) = pv / weight 1 + /\ snd (divmod p) = pv mod weight 1. + Proof. + cbv [is_div_mod divmod]; destruct p; cbn [fst snd hd tl length]; [ omega | ]. + intros Hlen Hsmall. + split. + { rewrite Positional.eval_cons, weight_0 by (assumption || omega). + autorewrite with zsimplify_const. + symmetry; erewrite Positional.eval_weight_mul. + Print Positional. + 2: { + (etransitivity; [ exact (Hsmall 0%nat ltac:(omega)) | ]). + rewrite weight_0 by assumption; autorewrite with zsimplify_const; reflexivity. + } + + + revert H0. + push_Zmod. +hd 0 p). + Lemma eval_divmod n (p : list Z) : + length p = S n -> p = partition (S n) (Positional.eval weight (S n) p) -> + is_div_mod (Positional.eval (fun i => weight (S i) / weight 1) n) + (divmod p) + (Positional.eval weight (S n) p) + (weight 1). + Proof. + cbv [is_div_mod divmod]; destruct p; cbn [fst snd hd tl length]; [ omega | ]. + intros. + rewrite eval_ + *) End Ops. End Rows. Hint Rewrite length_from_columns using eassumption : distr_length. Hint Rewrite length_sum_rows using solve [ reflexivity | eassumption | distr_length; eauto ] : distr_length. Hint Rewrite length_fst_extract_row length_snd_extract_row length_flatten length_flatten' length_partition length_fst_from_columns' length_snd_from_columns' : distr_length. + Hint Rewrite @eval_partition : push_eval. End Rows. Module BaseConversion. @@ -2609,7 +2734,7 @@ Section freeze_mod_ops. (m_enc_correct : Positional.eval weight n m_enc = m) (Hm_enc_len : length m_enc = n). - Definition wprops_bytes := (@wprops 8 1 ltac:(lia)). + Definition wprops_bytes := (@wprops 8 1 ltac:(clear; lia)). Local Notation wprops := (@wprops limbwidth_num limbwidth_den limbwidth_good). Local Hint Immediate (wprops). @@ -2624,7 +2749,9 @@ Section freeze_mod_ops. Local Hint Immediate (weight_divides wprops_bytes). Local Hint Resolve Z.positive_is_nonzero Z.lt_gt. - Definition bytes_n := (1 + (Z.to_nat (Z.log2_up (weight n) / 8)))%nat. + Definition bytes_n + := Eval cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv Pos.mul] + in Z.to_nat (Qceiling (Z.log2_up (weight n) / 8)). Definition to_bytes' (v : list Z) := BaseConversion.convert_bases weight bytes_weight n bytes_n v. @@ -2638,54 +2765,118 @@ Section freeze_mod_ops. that the result partitions. See https://github.com/JasonGross/fiat-crypto/tree/zzz-wip-better-arith-proofs for some partial work in this direction. *) - Definition to_bytesmod (f : list Z) : list Z - := let v := to_bytes' (freeze weight n (Z.ones bitwidth) m_enc f) in + Definition to_bytes (f : list Z) : list Z + := let v := to_bytes' f in fst (Rows.flatten bytes_weight bytes_n (Rows.from_associational bytes_weight bytes_n (Positional.to_associational bytes_weight bytes_n v))). + Definition freeze_to_bytesmod (f : list Z) : list Z + := to_bytes (freeze weight n (Z.ones bitwidth) m_enc f). + + Definition to_bytesmod (f : list Z) : list Z + := to_bytes f. + Definition from_bytesmod (f : list Z) : list Z := from_bytes f. + Lemma bytes_nz : bytes_n <> 0%nat. + Proof using limbwidth_good Hn_nz. + clear -limbwidth_good Hn_nz. + cbv [bytes_n]. + cbv [Qceiling Qdiv inject_Z Qfloor Qmult Qopp Qnum Qden Qinv]. + autorewrite with zsimplify_const. + change (Z.pos (1*8)) with 8. + cbv [weight]. + rewrite Z.log2_up_pow2 by (Z.div_mod_to_quot_rem; nia). + autorewrite with zsimplify_fast. + rewrite <- Z2Nat.inj_0, Z2Nat.inj_iff by (Z.div_mod_to_quot_rem; nia). + Z.div_mod_to_quot_rem; nia. + Qed. + + Lemma bytes_n_big : weight n <= bytes_weight bytes_n. + Proof using limbwidth_good Hn_nz. + clear -limbwidth_good Hn_nz. + cbv [bytes_n bytes_weight]. + Z.peel_le. + rewrite Z.log2_up_pow2 by (Z.div_mod_to_quot_rem; nia). + autorewrite with zsimplify_fast. + rewrite Z2Nat.id by (Z.div_mod_to_quot_rem; nia). + Z.div_mod_to_quot_rem; nia. + Qed. + + Lemma eval_to_bytes_mod + : forall (f : list Z) + (Hf : length f = n), + eval bytes_weight bytes_n (to_bytes f) = eval weight n f mod (bytes_weight bytes_n). + Proof using limbwidth_good Hn_nz. + generalize wprops wprops_bytes; clear -Hn_nz limbwidth_good. + intros. + cbv [to_bytes]. + rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. + rewrite Rows.eval_from_associational by eauto using bytes_nz with omega. + rewrite eval_to_associational. + cbv [to_bytes']. + rewrite BaseConversion.eval_convert_bases + by (auto using bytes_nz; distr_length; auto using wprops). + reflexivity. + Qed. + + Lemma eval_to_bytes + : forall (f : list Z) + (Hf : length f = n) + (Hf_small : 0 <= eval weight n f < weight n), + eval bytes_weight bytes_n (to_bytes f) = eval weight n f. + Proof using Hn_nz limbwidth_good. + generalize bytes_n_big. clear -Hn_nz limbwidth_good. + intros; rewrite eval_to_bytes_mod by assumption. + rewrite Z.mod_small by omega; reflexivity. + Qed. + + Lemma to_bytes_partitions + : forall (f : list Z) + (Hf : length f = n), + to_bytes f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f). + Proof using Hn_nz limbwidth_good. + clear -Hn_nz limbwidth_good. + intros; cbv [to_bytes]. + rewrite Rows.flatten_partitions' by eauto using wprops, Rows.length_from_associational. + rewrite Rows.eval_from_associational by eauto using bytes_nz with omega. + rewrite eval_to_associational. + cbv [to_bytes']. + rewrite BaseConversion.eval_convert_bases + by (auto using wprops_bytes, bytes_nz; distr_length; auto using wprops). + reflexivity. + Qed. + Lemma eval_to_bytesmod : forall (f : list Z) + (Hf : length f = n) + (Hf_small : 0 <= eval weight n f < weight n), + eval bytes_weight bytes_n (to_bytesmod f) = eval weight n f + /\ to_bytesmod f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f). + Proof using Hn_nz limbwidth_good. + split; apply eval_to_bytes || apply to_bytes_partitions; assumption. + Qed. + + Lemma eval_freeze_to_bytesmod + : forall (f : list Z) (Hf : length f = n) (Hf_bounded : 0 <= eval weight n f < 2 * m), - (eval bytes_weight bytes_n (to_bytesmod f)) = (eval weight n f) mod m - /\ to_bytesmod f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f mod m). + (eval bytes_weight bytes_n (freeze_to_bytesmod f)) = (eval weight n f) mod m + /\ freeze_to_bytesmod f = Rows.partition bytes_weight bytes_n (Positional.eval weight n f mod m). Proof. - intros; subst m s; split. - { cbv [to_bytesmod]. - rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. - rewrite Rows.eval_from_associational by (cbv [bytes_n]; eauto with omega). - rewrite eval_to_associational. - cbv [to_bytes']. - rewrite BaseConversion.eval_convert_bases - by (cbv [bytes_n]; auto using wprops_bytes; distr_length; auto using wprops). - erewrite eval_freeze by eauto using wprops. - rewrite (Z.mod_small (_ mod _)); [ reflexivity | ]. - split; [ | eapply Z.lt_le_trans ]; [ apply Z.mod_pos_bound; omega.. | ]. - transitivity (weight n); [ omega | ]. - cbv [weight bytes_n]. - Z.peel_le. - rewrite Z.log2_up_pow2 by (Z.div_mod_to_quot_rem_in_goal; nia). - autorewrite with push_Zof_nat. - rewrite Z2Nat.id by (Z.div_mod_to_quot_rem_in_goal; nia). - Z.div_mod_to_quot_rem_in_goal; nia. } - { cbv [to_bytesmod]. - rewrite Rows.flatten_partitions' by eauto using wprops, Rows.length_from_associational. - rewrite Rows.eval_from_associational by (cbv [bytes_n]; eauto with omega). - rewrite eval_to_associational. - cbv [to_bytes']. - rewrite BaseConversion.eval_convert_bases - by (cbv [bytes_n]; auto using wprops_bytes; distr_length; auto using wprops). - erewrite eval_freeze by eauto using wprops. - reflexivity. } + intros; subst m s. + cbv [freeze_to_bytesmod]. + rewrite eval_to_bytes, to_bytes_partitions; + erewrite ?eval_freeze by eauto using wprops; + autorewrite with distr_length; eauto. + Z.div_mod_to_quot_rem; nia. Qed. Lemma eval_from_bytesmod : forall (f : list Z) (Hf : length f = bytes_n), eval weight n (from_bytesmod f) = eval bytes_weight bytes_n f. - Proof. + Proof using Hn_nz limbwidth_good. cbv [from_bytesmod from_bytes]; intros. rewrite BaseConversion.eval_convert_bases by eauto using wprops. reflexivity. @@ -2729,3 +2920,1053 @@ Section primitives. | progress Z.rewrite_mod_small ]. Qed. End primitives. + +Module UniformWeight. + Definition uweight (lgr : Z) : nat -> Z + := weight lgr 1. + Definition uwprops lgr (Hr : 0 < lgr) : @weight_properties (uweight lgr). + Proof. apply wprops; omega. Qed. + Lemma uweight_eq_alt' lgr n : uweight lgr n = 2^(lgr*Z.of_nat n). + Proof. now cbv [uweight weight]; autorewrite with zsimplify_fast. Qed. + Lemma uweight_eq_alt lgr (Hr : 0 <= lgr) n : uweight lgr n = (2^lgr)^Z.of_nat n. + Proof. now rewrite uweight_eq_alt', Z.pow_mul_r by lia. Qed. +End UniformWeight. + +Module WordByWordMontgomery. + Section with_args. + Context (lgr : Z) + (m : Z). + Local Notation weight := (UniformWeight.uweight lgr). + Let T (n : nat) := list Z. + Let r := (2^lgr). + Definition eval {n} : T n -> Z := Positional.eval weight n. + Let zero {n} : T n := Positional.zeros n. + Let divmod {n} : T (S n) -> T n * Z := Rows.divmod. + Let scmul {n} (c : Z) (p : T n) : T (S n) (* uses double-output multiply *) + := let '(v, c) := Rows.mul weight r n (S n) (Positional.extend_to_length 1 n [c]) p in + v. + Let addT {n} (p q : T n) : T (S n) (* joins carry *) + := let '(v, c) := Rows.add weight n p q in + v ++ [c]. + Let drop_high_addT' {n} (p : T (S n)) (q : T n) : T (S n) (* drops carry *) + := fst (Rows.add weight (S n) p (Positional.extend_to_length n (S n) q)). + Let conditional_sub {n} (arg : T (S n)) (N : T n) : T n (* computes [arg - N] if [N <= arg], and drops high bit *) + := Positional.drop_high_to_length n (Rows.conditional_sub weight (S n) arg (Positional.extend_to_length n (S n) N)). + Context (R_numlimbs : nat) + (N : T R_numlimbs). (* encoding of m *) + Let sub_then_maybe_add (a b : T R_numlimbs) : T R_numlimbs (* computes [a - b + if (a - b) <? 0 then N else 0] *) + := fst (Rows.sub_then_maybe_add weight R_numlimbs (r-1) a b N). + Local Opaque T. + Section Iteration. + Context (pred_A_numlimbs : nat) + (B : T R_numlimbs) (k : Z) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + Local Definition A_a := dlet p := @divmod _ A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. + Local Definition S1 := @addT _ S (@scmul _ a B). + Local Definition s := snd (@divmod _ S1). + Local Definition q := fst (Z.mul_split r s k). + Local Definition S2 := @drop_high_addT' _ S1 (@scmul _ q N). + Local Definition S3' := fst (@divmod _ S2). + + Local Definition A'_S3 + := dlet A_a := @divmod _ A in + dlet A' := fst A_a in + dlet a := snd A_a in + dlet S1 := @addT _ S (@scmul _ a B) in + dlet s := snd (@divmod _ S1) in + dlet q := fst (Z.mul_split r s k) in + dlet S2 := @drop_high_addT' _ S1 (@scmul _ q N) in + dlet S3 := fst (@divmod _ S2) in + (A', S3). + + Lemma A'_S3_alt : A'_S3 = (A', S3'). + Proof. cbv [A'_S3 A' S3' Let_In S2 q s S1 A' a A_a]; reflexivity. Qed. + End Iteration. + + Section loop. + Context (A_numlimbs : nat) + (A : T A_numlimbs) + (B : T R_numlimbs) + (k : Z) + (S' : T (S R_numlimbs)). + + Definition redc_body {pred_A_numlimbs} : T (S pred_A_numlimbs) * T (S R_numlimbs) + -> T pred_A_numlimbs * T (S R_numlimbs) + := fun '(A, S') => A'_S3 _ B k A S'. + + Definition redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs) + := nat_rect + (fun count => T count * _ -> _) + (fun A_S => A_S) + (fun count' redc_loop_count' A_S + => redc_loop_count' (redc_body A_S)) + count. + + Definition pre_redc : T (S R_numlimbs) + := snd (redc_loop A_numlimbs (A, @zero (1 + R_numlimbs)%nat)). + + Definition redc : T R_numlimbs + := conditional_sub pre_redc N. + End loop. + + Create HintDb word_by_word_montgomery. + Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. + + Definition add (A B : T R_numlimbs) : T R_numlimbs + := conditional_sub (@addT _ A B) N. + Definition sub (A B : T R_numlimbs) : T R_numlimbs + := sub_then_maybe_add A B. + Definition opp (A : T R_numlimbs) : T R_numlimbs + := sub (@zero _) A. + Definition nonzero (A : list Z) : Z + := fold_right Z.lor 0 A. + + Context (lgr_big : 0 < lgr) + (R_numlimbs_nz : R_numlimbs <> 0%nat). + Let R := (r^Z.of_nat R_numlimbs). + Transparent T. + Definition small {n} (v : T n) : Prop + := v = Rows.partition weight n (eval v). + Context (small_N : small N) + (N_lt_R : eval N < R) + (N_nz : 0 < eval N) + (B : T R_numlimbs) + (B_bounds : 0 <= eval B < R) + (small_B : small B) + ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) + (k : Z) (k_correct : k * eval N mod r = (-1) mod r). + + Local Lemma r_big : r > 1. + Proof using lgr_big. clear -lgr_big; subst r. auto with zarith. Qed. + Local Notation wprops := (@UniformWeight.uwprops lgr lgr_big). + + Local Hint Immediate (wprops). + Local Hint Immediate (weight_0 wprops). + Local Hint Immediate (weight_positive wprops). + Local Hint Immediate (weight_multiples wprops). + Local Hint Immediate (weight_divides wprops). + Local Hint Immediate r_big. + + Lemma length_small {n v} : @small n v -> length v = n. + Proof using Type. clear; cbv [small]; intro H; rewrite H; autorewrite with distr_length; reflexivity. Qed. + + Let partition_Proper := (@Rows.partition_Proper _ wprops). + Local Existing Instance partition_Proper. + Lemma eval_nonzero n A : @small n A -> nonzero A = 0 <-> @eval n A = 0. + Proof. + cbv [nonzero eval small]; intro Heq. + do 2 rewrite Heq. + rewrite !Rows.eval_partition, Z.mod_mod by auto. + generalize (Positional.eval weight n A); clear Heq A. + induction n as [|n IHn]. + { cbn; rewrite weight_0 by auto; intros; autorewrite with zsimplify_const; omega. } + { intro; rewrite Rows.partition_step. + rewrite fold_right_snoc, Z.lor_comm, <- fold_right_push, Z.lor_eq_0_iff by auto using Z.lor_assoc. + assert (Heq : Z.equiv_modulo (weight n) (z mod weight (S n)) (z mod (weight n))). + { cbv [Z.equiv_modulo]. + generalize (weight_multiples ltac:(auto) n). + generalize (weight_positive ltac:(auto) n). + generalize (weight_positive ltac:(auto) (S n)). + generalize (weight (S n)) (weight n); clear; intros wsn wn. + clear; intros. + Z.div_mod_to_quot_rem; subst. + autorewrite with zsimplify_const in *. + Z.linear_substitute_all. + apply Zminus_eq; ring_simplify. + rewrite <- !Z.add_opp_r, !Z.mul_opp_comm, <- !Z.mul_opp_r, <- !Z.mul_assoc. + rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0. + nia. } + rewrite Heq at 1; rewrite IHn. + rewrite Z.mod_mod by auto. + generalize (weight_multiples ltac:(auto) n). + generalize (weight_positive ltac:(auto) n). + generalize (weight_positive ltac:(auto) (S n)). + generalize (weight (S n)) (weight n); clear; intros wsn wn; intros. + Z.div_mod_to_quot_rem. + repeat (intro || apply conj); destruct_head'_or; try omega; destruct_head'_and; subst; autorewrite with zsimplify_const in *; try nia; + Z.linear_substitute_all. + all: apply Zminus_eq; ring_simplify. + all: rewrite <- ?Z.add_opp_r, ?Z.mul_opp_comm, <- ?Z.mul_opp_r, <- ?Z.mul_assoc. + all: rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0. + all: nia. } + Qed. + + Local Ltac push_step := + first [ progress eta_expand + | rewrite Rows.mul_partitions + | rewrite Rows.mul_div + | rewrite Rows.add_partitions + | rewrite Rows.add_div + | progress autorewrite with push_eval distr_length + | match goal with + | [ H : ?v = _ |- context[length ?v] ] => erewrite length_small by eassumption + | [ H : small ?v |- context[length ?v] ] => erewrite length_small by eassumption + end + | rewrite Positional.eval_cons + | rewrite (weight_0 wprops) + | rewrite <- Z.div_mod'' by auto with omega + | solve [ trivial ] ]. + Local Ltac push := repeat push_step. + + Local Ltac t_step := + match goal with + | [ H := _ |- _ ] => progress cbv [H] in * + | _ => progress push_step + | _ => progress autorewrite with zsimplify_const + | _ => solve [ auto with omega ] + end. + + Local Hint Unfold eval zero small divmod scmul drop_high_addT' addT R : loc. + Local Lemma eval_zero : forall n, eval (@zero n) = 0. + Proof using Type. + clear; autounfold with loc; intros; autorewrite with push_eval; auto. + cbv -[Z.pow Z.mul Z.opp Z.div]; autorewrite with zsimplify_const; reflexivity. + Qed. + Local Lemma small_zero : forall n, small (@zero n). + Proof using Type. + etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [Rows.partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. + Qed. + Local Hint Immediate small_zero. + Local Axiom eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r. + Local Axiom eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r. + Local Axiom small_div : forall n v, small v -> small (fst (@divmod n v)). + Local Lemma eval_scmul: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> eval (@scmul n a v) = a * eval v. + Proof using lgr_big. + generalize (@length_small); clear -lgr_big; intro. + autounfold with loc; intro n; destruct (zerop n). + { cbn; intros; subst; cbn; rewrite Z.add_with_get_carry_full_mod; cbn; omega. } + intros; repeat t_step. + repeat first [ reflexivity + | rewrite UniformWeight.uweight_eq_alt by omega + | progress autorewrite with push_Zof_nat + | rewrite Z.pow_succ_r by lia + | progress Z.rewrite_mod_small ]. + Qed. + Local Lemma small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> small (@scmul n a v). + Proof using lgr_big. + intros n a v Hpart. + generalize (length_small Hpart). + generalize eval_scmul. + clear -Hpart lgr_big. + destruct (zerop n). + { destruct v; subst; cbn; try congruence; cbv [small]; cbn. + rewrite Z.add_with_get_carry_full_mod; cbn; autorewrite with zsimplify_const; reflexivity. } + { cbv [small]; intros eval_scmul; intros. + rewrite eval_scmul by auto. + cbv [scmul]; eta_expand. + rewrite Rows.mul_partitions by (auto with omega; autorewrite with distr_length; auto with omega). + autorewrite with push_eval; auto with omega. + rewrite Positional.eval_cons, Positional.eval_nil by reflexivity. + rewrite weight_0 by auto; autorewrite with zsimplify_const; reflexivity. } + Qed. + Local Lemma eval_addT : forall n a b, small a -> small b -> eval (@addT n a b) = eval a + eval b. + Proof using lgr_big. + intros n a b Ha Hb; generalize (length_small Ha); generalize (length_small Hb). + clear -lgr_big Ha Hb. + autounfold with loc; destruct (zerop n); subst. + { destruct a, b; cbn; try omega. } + { eta_expand; intros; repeat t_step. } + Qed. + Local Axiom small_addT : forall n a b, small a -> small b -> small (@addT n a b). + Local Lemma eval_drop_high_addT' : forall n a b, small a -> small b -> eval (@drop_high_addT' n a b) = (eval a + eval b) mod (r^Z.of_nat (S n)). + Proof using lgr_big. + intros n a b Ha Hb; generalize (length_small Ha); generalize (length_small Hb). + clear -lgr_big Ha Hb. + autounfold with loc in *; destruct (zerop n); subst. + { destruct a as [| ? [|] ], b; cbn; try omega. + cbv [Rows.partition seq eval map] in Ha. + cbn in Ha. + rewrite (weight_0 wprops) in *. + rewrite Z.add_with_get_carry_full_mod. + subst r. + rewrite UniformWeight.uweight_eq_alt in * by omega. + autorewrite with zsimplify_const in *. + inversion Ha as [Ha']; clear Ha. + rewrite <- !Ha'. + reflexivity. } + { eta_expand; intros; repeat t_step. + rewrite UniformWeight.uweight_eq_alt by omega. + reflexivity. } + Qed. + Local Lemma small_drop_high_addT' : forall n a b, small a -> small b -> small (@drop_high_addT' n a b). + Proof using lgr_big. + intros n a b Ha Hb; generalize (length_small Ha); generalize (length_small Hb); generalize (@eval_drop_high_addT' n a b Ha). + clear -lgr_big Ha Hb. + cbv [small]. + intro Heq; rewrite Heq; autounfold with loc in *. + rewrite Ha, Hb. + repeat t_step. + rewrite !UniformWeight.uweight_eq_alt by omega. + autorewrite with push_Zof_nat zsimplify_fast. + rewrite Z.pow_succ_r by omega. + Admitted. + Local Axiom eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0. + Local Axiom small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v N). + Local Axiom eval_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b <? 0 then eval N else 0. + Local Axiom small_sub_then_maybe_add : forall a b, small (sub_then_maybe_add a b). + + Local Opaque T addT drop_high_addT' divmod zero scmul conditional_sub sub_then_maybe_add. + Create HintDb push_mont_eval discriminated. + Create HintDb word_by_word_montgomery. + Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. + Let r_big' := r_big. (* to put it in the context *) + Local Ltac t_small := + repeat first [ assumption + | apply small_addT + | apply small_drop_high_addT' + | apply small_div + | apply small_zero + | apply small_scmul + | apply small_conditional_sub + | apply small_sub_then_maybe_add + | apply Z_mod_lt + | rewrite Z.mul_split_mod + | solve [ auto with zarith ] + | lia + | progress autorewrite with push_mont_eval + | progress autounfold with word_by_word_montgomery + | match goal with + | [ H : and _ _ |- _ ] => destruct H + end ]. + Hint Rewrite + eval_zero + eval_div + eval_mod + eval_addT + eval_drop_high_addT' + eval_scmul + eval_conditional_sub + eval_sub_then_maybe_add + using (repeat autounfold with word_by_word_montgomery; t_small) + : push_mont_eval. + + Local Arguments eval {_} _. + Local Arguments small {_} _. + Local Arguments divmod {_} _. + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + Section Iteration_proofs. + Context (pred_A_numlimbs : nat) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)) + (small_A : small A) + (small_S : small S) + (S_nonneg : 0 <= eval S). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + + Local Coercion eval : T >-> Z. + + Local Notation a := (@a pred_A_numlimbs A). + Local Notation A' := (@A' pred_A_numlimbs A). + Local Notation S1 := (@S1 pred_A_numlimbs B A S). + Local Notation s := (@s pred_A_numlimbs B A S). + Local Notation q := (@q pred_A_numlimbs B k A S). + Local Notation S2 := (@S2 pred_A_numlimbs B k A S). + Local Notation S3 := (@S3' pred_A_numlimbs B k A S). + + Local Notation eval_pre_S3 := ((S + a * B + q * N) / r). + + Lemma eval_S3_eq : eval S3 = eval_pre_S3 mod (r * r ^ Z.of_nat R_numlimbs). + Proof. + unfold S3, S2, S1. + autorewrite with push_mont_eval push_Zof_nat; []. + rewrite !Z.pow_succ_r, <- ?Z.mul_assoc by omega. + rewrite Z.mod_pull_div by Z.zero_bounds. + do 2 f_equal; nia. + Qed. + + Lemma pre_S3_bound + : eval S < eval N + eval B + -> eval_pre_S3 < eval N + eval B. + Proof. + assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) + by (intros x y; pose proof (Z_mod_lt x y); omega). + intro HS. + eapply Z.le_lt_trans. + { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); + [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. + Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; + repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; + autorewrite with push_mont_eval; + try Z.zero_bounds; + auto with lia. } + rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. + autorewrite with zsimplify. + simpl; omega. + Qed. + + Lemma pre_S3_nonneg : 0 <= eval_pre_S3. + Proof. + repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; + autorewrite with push_mont_eval; []. + rewrite ?Npos_correct; Z.zero_bounds; lia. + Qed. + + Lemma small_A' + : small A'. + Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. + + Lemma small_S3 + : small S3. + Proof. repeat autounfold with word_by_word_montgomery; t_small. Qed. + + Lemma S3_nonneg : 0 <= eval S3. + Proof. rewrite eval_S3_eq; Z.zero_bounds. Qed. + + Lemma S3_bound + : eval S < eval N + eval B + -> eval S3 < eval N + eval B. + Proof. + rewrite eval_S3_eq. + intro H; pose proof (pre_S3_bound H); pose proof pre_S3_nonneg. + subst R. + rewrite Z.mod_small by nia. + assumption. + Qed. + + Lemma S1_eq : eval S1 = S + a*B. + Proof. + cbv [S1 a A']. + repeat autorewrite with push_mont_eval. + reflexivity. + Qed. + + Lemma S2_mod_r_helper : (S + a*B + q * N) mod r = 0. + Proof. + cbv [S2 q s]; autorewrite with push_mont_eval; rewrite S1_eq. + assert (r > 0) by lia. + assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). + { destruct (Z.eq_dec r 1) as [H'|H']. + { rewrite H'; split; reflexivity. } + { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } + autorewrite with pull_Zmod. + replace 0 with (0 mod r) by apply Zmod_0_l. + pose (Z.to_pos r) as r'. + replace r with (Z.pos r') by (subst r'; rewrite Z2Pos.id; lia). + eapply F.eq_of_Z_iff. + rewrite Z.mul_split_mod. + repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. + rewrite <-!Algebra.Hierarchy.associative. + replace ((F.of_Z r' k * F.of_Z r' (eval N))%F) with (F.opp (m:=r') F.one). + { cbv [F.of_Z F.add]; simpl. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. + simpl. + subst r'; rewrite Z2Pos.id by lia. + rewrite (proj1 Hr), Z.mul_sub_distr_l. + push_Zmod; pull_Zmod. + apply (f_equal2 Z.modulo); omega. } + { rewrite <- F.of_Z_mul. + rewrite F.of_Z_mod. + subst r'; rewrite Z2Pos.id by lia. + rewrite k_correct. + cbv [F.of_Z F.add F.opp F.one]; simpl. + change (-(1)) with (-1) in *. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. + rewrite Z2Pos.id by lia. + rewrite (proj1 Hr), (proj2 Hr); Z.rewrite_mod_small; reflexivity. } + Qed. + + Lemma pre_S3_mod_N + : eval_pre_S3 mod N = (S + a*B)*ri mod N. + Proof. + pose proof fun a => Z.div_to_inv_modulo N a r ri ltac:(lia) ri_correct as HH; + cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. + etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| + rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. + rewrite S2_mod_r_helper. + push_Zmod; pull_Zmod; autorewrite with zsimplify_const. + reflexivity. + Qed. + + Lemma S3_mod_N + (Hbound : eval S < eval N + eval B) + : S3 mod N = (S + a*B)*ri mod N. + Proof. + rewrite eval_S3_eq. + pose proof (pre_S3_bound Hbound); pose proof pre_S3_nonneg. + rewrite (Z.mod_small _ (r * _)) by (subst R; nia). + apply pre_S3_mod_N. + Qed. + End Iteration_proofs. + + Section redc_proofs. + Local Notation redc_body := (@redc_body B k). + Local Notation redc_loop := (@redc_loop B k). + Local Notation pre_redc A := (@pre_redc _ A B k). + Local Notation redc A := (@redc _ A B k). + + Section body. + Context (pred_A_numlimbs : nat) + (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)). + Let A:=fst A_S. + Let S:=snd A_S. + Let A_a:=divmod A. + Let a:=snd A_a. + Context (small_A : small A) + (small_S : small S) + (S_bound : 0 <= eval S < eval N + eval B). + + Lemma small_fst_redc_body : small (fst (redc_body A_S)). + Proof. destruct A_S; apply small_A'; assumption. Qed. + Lemma small_snd_redc_body : small (snd (redc_body A_S)). + Proof. destruct A_S; unfold redc_body; apply small_S3; assumption. Qed. + Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). + Proof. destruct A_S; apply S3_nonneg; assumption. Qed. + + Lemma snd_redc_body_mod_N + : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). + Proof. destruct A_S; apply S3_mod_N; auto; omega. Qed. + + Lemma fst_redc_body + : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. + Proof. + destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl. + autorewrite with push_mont_eval. + reflexivity. + Qed. + + Lemma fst_redc_body_mod_N + : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). + Proof. + rewrite fst_redc_body. + etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. + unfold a, A_a, A. + autorewrite with push_mont_eval. + reflexivity. + Qed. + + Lemma redc_body_bound + : eval S < eval N + eval B + -> eval (snd (redc_body A_S)) < eval N + eval B. + Proof. + destruct A_S; apply S3_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. + Qed. + End body. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.of_nat !_. + Local Ltac induction_loop count IHcount + := induction count as [|count IHcount]; intros; cbn [redc_loop nat_rect] in *; [ | (*rewrite redc_loop_comm_body in * *) ]. + Lemma redc_loop_good count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : (small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S))) + /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. + induction_loop count IHcount; auto; []. + change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *). + destruct_head'_and. + repeat first [ apply conj + | apply small_fst_redc_body + | apply small_snd_redc_body + | apply redc_body_bound + | apply snd_redc_body_nonneg + | apply IHcount + | solve [ auto ] ]. + Qed. + + Lemma small_redc_loop count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)). + Proof. apply redc_loop_good; assumption. Qed. + + Lemma redc_loop_bound count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B. + Proof. apply redc_loop_good; assumption. Qed. + + Local Ltac handle_IH_small := + repeat first [ apply redc_loop_good + | apply small_fst_redc_body + | apply small_snd_redc_body + | apply redc_body_bound + | apply snd_redc_body_nonneg + | apply conj + | progress cbn [fst snd] + | progress destruct_head' and + | solve [ auto ] ]. + + Lemma fst_redc_loop count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count). + Proof. + cbv [redc_loop]; induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { rewrite IHcount, fst_redc_body by handle_IH_small. + change (1 + R_numlimbs)%nat with (S R_numlimbs) in *. + rewrite Zdiv_Zdiv by Z.zero_bounds. + rewrite <- (Z.pow_1_r r) at 1. + rewrite <- Z.pow_add_r by lia. + replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia. + reflexivity. } + Qed. + + Lemma fst_redc_loop_mod_N count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : eval (fst (redc_loop count A_S)) mod (eval N) + = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) + * ri^(Z.of_nat count) mod (eval N). + Proof. + rewrite fst_redc_loop by assumption. + destruct count. + { simpl; autorewrite with zsimplify; reflexivity. } + { etransitivity; + [ eapply Z.div_to_inv_modulo; + try solve [ eassumption + | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] + | ]. + { erewrite <- Z.pow_mul_l, <- Z.pow_1_l. + { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. } + { lia. } } + reflexivity. } + Qed. + + Local Arguments Z.pow : simpl never. + Lemma snd_redc_loop_mod_N count A_S + (Hsmall : small (fst A_S) /\ small (snd A_S)) + (Hbound : 0 <= eval (snd A_S) < eval N + eval B) + : (eval (snd (redc_loop count A_S))) mod (eval N) + = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N). + Proof. + cbv [redc_loop]. + induction_loop count IHcount. + { simpl; autorewrite with zsimplify; reflexivity. } + { rewrite IHcount by handle_IH_small. + push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod. + autorewrite with push_mont_eval; []. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + destruct A_S as [A S]. + cbn [fst snd]. + change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)). + rewrite !Z.mul_add_distr_r. + rewrite <- !Z.mul_assoc. + replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count))) + by (change (Datatypes.S count) with (1 + count)%nat; + autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia). + rewrite <- !Z.add_assoc. + apply Z.add_mod_Proper; [ reflexivity | ]. + unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)). + rewrite Z.mod_pull_div by auto with zarith lia. + push_Zmod. + erewrite Z.div_to_inv_modulo; + [ + | apply Z.lt_gt; lia + | eassumption ]. + pull_Zmod. + match goal with + | [ |- ?x mod ?N = ?y mod ?N ] + => change (Z.equiv_modulo N x y) + end. + repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia + | rewrite (Z.mul_comm _ ri) + | rewrite (Z.mul_assoc _ ri _) + | rewrite (Z.mul_comm _ (ri^_)) + | rewrite (Z.mul_assoc _ (ri^_) _) ]. + repeat first [ rewrite <- Z.mul_assoc + | rewrite <- Z.mul_add_distr_l + | rewrite (Z.mul_comm _ (eval B)) + | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; + rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds) + | rewrite Zplus_minus + | rewrite (Z.mul_comm r (r^_)) + | reflexivity ]. } + Qed. + + Lemma pre_redc_bound A_numlimbs (A : T A_numlimbs) + (small_A : small A) + : 0 <= eval (pre_redc A) < eval N + eval B. + Proof. + unfold pre_redc. + apply redc_loop_good; simpl; autorewrite with push_mont_eval; + rewrite ?Npos_correct; auto; lia. + Qed. + + Lemma small_pre_redc A_numlimbs (A : T A_numlimbs) + (small_A : small A) + : small (pre_redc A). + Proof. + unfold pre_redc. + apply redc_loop_good; simpl; autorewrite with push_mont_eval; + rewrite ?Npos_correct; auto; lia. + Qed. + + Lemma pre_redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) + : (eval (pre_redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). + Proof. + unfold pre_redc. + rewrite snd_redc_loop_mod_N; cbn [fst snd]; + autorewrite with push_mont_eval zsimplify; + [ | rewrite ?Npos_correct; auto; lia.. ]. + Z.rewrite_mod_small. + reflexivity. + Qed. + + Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) + : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N). + Proof. + pose proof (@small_pre_redc _ A small_A). + pose proof (@pre_redc_bound _ A small_A). + unfold redc. + autorewrite with push_mont_eval; []. + break_innermost_match; + try rewrite Z.add_opp_r, Zminus_mod, Z_mod_same_full; + autorewrite with zsimplify_fast; + apply pre_redc_mod_N; auto. + Qed. + + Lemma redc_bound_tight A_numlimbs (A : T A_numlimbs) + (small_A : small A) + : 0 <= eval (redc A) < eval N + eval B + if eval N <=? eval (pre_redc A) then -eval N else 0. + Proof. + pose proof (@small_pre_redc _ A small_A). + pose proof (@pre_redc_bound _ A small_A). + unfold redc. + rewrite eval_conditional_sub by t_small. + break_innermost_match; Z.ltb_to_lt; omega. + Qed. + + Lemma redc_bound_N A_numlimbs (A : T A_numlimbs) + (small_A : small A) + : eval B < eval N -> 0 <= eval (redc A) < eval N. + Proof. + pose proof (@small_pre_redc _ A small_A). + pose proof (@pre_redc_bound _ A small_A). + unfold redc. + rewrite eval_conditional_sub by t_small. + break_innermost_match; Z.ltb_to_lt; omega. + Qed. + + Lemma redc_bound A_numlimbs (A : T A_numlimbs) + (small_A : small A) + (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) + : 0 <= eval (redc A) < R. + Proof. + pose proof (@small_pre_redc _ A small_A). + pose proof (@pre_redc_bound _ A small_A). + unfold redc. + rewrite eval_conditional_sub by t_small. + break_innermost_match; Z.ltb_to_lt; try omega. + Qed. + + Lemma small_redc A_numlimbs (A : T A_numlimbs) + (small_A : small A) + (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) + : small (redc A). + Proof. + pose proof (@small_pre_redc _ A small_A). + pose proof (@pre_redc_bound _ A small_A). + unfold redc. + apply small_conditional_sub; [ apply small_pre_redc | .. ]; auto; omega. + Qed. + End redc_proofs. + + Section add_sub. + Context (Av Bv : T R_numlimbs) + (small_Av : small Av) + (small_Bv : small Bv) + (Av_bound : 0 <= eval Av < eval N) + (Bv_bound : 0 <= eval Bv < eval N). + + Local Ltac do_clear := + clear dependent B; clear dependent k; clear dependent ri. + + Lemma small_add : small (add Av Bv). + Proof. unfold add; t_small. Qed. + Lemma small_sub : small (sub Av Bv). + Proof. unfold sub; t_small. Qed. + Lemma small_opp : small (opp Av). + Proof. unfold opp, sub; t_small. Qed. + + Lemma eval_add : eval (add Av Bv) = eval Av + eval Bv + if (eval N <=? eval Av + eval Bv) then -eval N else 0. + Proof. unfold add; autorewrite with push_mont_eval; reflexivity. Qed. + Lemma eval_sub : eval (sub Av Bv) = eval Av - eval Bv + if (eval Av - eval Bv <? 0) then eval N else 0. + Proof. unfold sub; autorewrite with push_mont_eval; reflexivity. Qed. + Lemma eval_opp : eval (opp Av) = (if (eval Av =? 0) then 0 else eval N) - eval Av. + Proof. + unfold opp, sub; autorewrite with push_mont_eval. + break_innermost_match; Z.ltb_to_lt; lia. + Qed. + + Local Ltac t_mod_N := + repeat first [ progress break_innermost_match + | reflexivity + | let H := fresh in intro H; rewrite H; clear H + | progress autorewrite with zsimplify_const + | rewrite Z.add_opp_r + | progress (push_Zmod; pull_Zmod) ]. + + Lemma eval_add_mod_N : eval (add Av Bv) mod eval N = (eval Av + eval Bv) mod eval N. + Proof. generalize eval_add; clear. t_mod_N. Qed. + Lemma eval_sub_mod_N : eval (sub Av Bv) mod eval N = (eval Av - eval Bv) mod eval N. + Proof. generalize eval_sub; clear. t_mod_N. Qed. + Lemma eval_opp_mod_N : eval (opp Av) mod eval N = (-eval Av) mod eval N. + Proof. generalize eval_opp; clear; t_mod_N. Qed. + + Lemma add_bound : 0 <= eval (add Av Bv) < eval N. + Proof. generalize eval_add; break_innermost_match; Z.ltb_to_lt; lia. Qed. + Lemma sub_bound : 0 <= eval (sub Av Bv) < eval N. + Proof. generalize eval_sub; break_innermost_match; Z.ltb_to_lt; lia. Qed. + Lemma opp_bound : 0 <= eval (opp Av) < eval N. + Proof. generalize eval_opp; break_innermost_match; Z.ltb_to_lt; lia. Qed. + End add_sub. + End with_args. + + Section modops. + Context (bitwidth : Z) + (n : nat) + (m : Z). + Let r := 2^bitwidth. + Local Notation weight := (UniformWeight.uweight bitwidth). + Local Notation eval := (@eval bitwidth n). + Let m_enc := Rows.partition weight n m. + Local Coercion Z.of_nat : nat >-> Z. + Context (r' : Z) + (m' : Z) + (r'_correct : (r * r') mod m = 1) + (m'_correct : (m * m') mod r = (-1) mod r) + (bitwidth_big : 0 < bitwidth) + (m_big : 1 < m) + (n_nz : n <> 0%nat) + (m_small : m < r^n). + + Local Notation wprops := (@UniformWeight.uwprops bitwidth bitwidth_big). + Local Notation small := (@small bitwidth n). + + Local Hint Immediate (wprops). + Local Hint Immediate (weight_0 wprops). + Local Hint Immediate (weight_positive wprops). + Local Hint Immediate (weight_multiples wprops). + Local Hint Immediate (weight_divides wprops). + + Local Lemma m_enc_correct_montgomery : m = eval m_enc. + Proof. + cbv [eval m_enc]; autorewrite with push_eval; auto. + rewrite UniformWeight.uweight_eq_alt by omega. + Z.rewrite_mod_small; reflexivity. + Qed. + Local Lemma r'_pow_correct : (r'^n * r^n) mod (eval m_enc) = 1. + Proof. + rewrite <- Z.pow_mul_l, Z.mod_pow_full, ?(Z.mul_comm r'), <- m_enc_correct_montgomery, r'_correct. + autorewrite with zsimplify_const; auto with omega. + Z.rewrite_mod_small; omega. + Qed. + Local Lemma small_m_enc : small m_enc. + Proof. + cbv [m_enc small eval]; autorewrite with push_eval; auto. + rewrite UniformWeight.uweight_eq_alt by omega. + Z.rewrite_mod_small; reflexivity. + Qed. + + Local Ltac t_fin := + repeat match goal with + | _ => assumption + | [ |- ?x = ?x ] => reflexivity + | [ |- and _ _ ] => split + | _ => rewrite <- !m_enc_correct_montgomery + | _ => rewrite !r'_correct + | _ => rewrite !Z.mod_1_l by assumption; reflexivity + | _ => rewrite !(Z.mul_comm m' m) + | _ => lia + | _ => exact small_m_enc + | [ H : small ?x |- context[eval ?x] ] + => rewrite H; cbv [eval]; rewrite Rows.eval_partition by auto + | [ |- context[weight _] ] => rewrite UniformWeight.uweight_eq_alt by auto with omega + | _=> progress Z.rewrite_mod_small + | _ => progress Z.zero_bounds + | [ |- _ mod ?x < ?x ] => apply Z.mod_pos_bound + end. + + Definition mulmod (a b : list Z) : list Z := @redc bitwidth n m_enc n a b m'. + Definition squaremod (a : list Z) : list Z := mulmod a a. + Definition addmod (a b : list Z) : list Z := @add bitwidth n m_enc a b. + Definition submod (a b : list Z) : list Z := @sub bitwidth n m_enc a b. + Definition oppmod (a : list Z) : list Z := @opp bitwidth n m_enc a. + Definition nonzeromod (a : list Z) : Z := @nonzero a. + Definition to_bytesmod (a : list Z) : list Z := @to_bytesmod bitwidth 1 n a. + + Definition valid (a : list Z) := small a /\ 0 <= eval a < m. + + Lemma mulmod_correct0 + : forall a b : list Z, + small a -> small b + -> small (mulmod a b) + /\ (eval b < m -> 0 <= eval (mulmod a b) < m) + /\ (eval (mulmod a b) mod m = (eval a * eval b * r'^n) mod m). + Proof. + intros a b Ha Hb; repeat apply conj; cbv [small mulmod eval]; + [ eapply small_redc + | rewrite m_enc_correct_montgomery; eapply redc_bound_N + | rewrite !m_enc_correct_montgomery; eapply redc_mod_N ]; + t_fin. + Qed. + + Definition onemod : list Z := Rows.partition weight n 1. + + Definition onemod_correct : eval onemod = 1 /\ valid onemod. + Proof. cbv [valid small onemod eval]; autorewrite with push_eval; t_fin. Qed. + + Definition R2mod : list Z := Rows.partition weight n ((r^n * r^n) mod m). + + Definition R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod. + Proof. + cbv [valid small R2mod eval]; autorewrite with push_eval; t_fin; + rewrite !(Z.mod_small (_ mod m)) by (Z.div_mod_to_quot_rem; subst r; lia); + t_fin. + Qed. + + Definition from_montgomery_mod (v : list Z) : list Z + := mulmod v onemod. + + Lemma from_montgomery_mod_correct (v : list Z) + : valid v -> eval (from_montgomery_mod v) mod m = (eval v * r'^n) mod m + /\ valid (from_montgomery_mod v). + Proof. + intro Hv; cbv [from_montgomery_mod valid] in *; destruct_head'_and. + replace (eval v * r'^n) with (eval v * eval onemod * r'^n) by (rewrite (proj1 onemod_correct); lia). + repeat apply conj; apply mulmod_correct0; auto; try apply onemod_correct; rewrite (proj1 onemod_correct); omega. + Qed. + + Lemma eval_from_montgomery_mod (v : list Z) : valid v -> eval (from_montgomery_mod v) mod m = (eval v * r'^n) mod m. + Proof. intros; apply from_montgomery_mod_correct; assumption. Qed. + Lemma valid_from_montgomery_mod (v : list Z) + : valid v -> valid (from_montgomery_mod v). + Proof. intros; apply from_montgomery_mod_correct; assumption. Qed. + + Lemma mulmod_correct + : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomery_mod (mulmod a b)) mod m + = (eval (from_montgomery_mod a) * eval (from_montgomery_mod b)) mod m) + /\ (forall a (_ : valid a) b (_ : valid b), valid (mulmod a b)). + Proof. + repeat apply conj; intros; + push_Zmod; rewrite ?eval_from_montgomery_mod; pull_Zmod; repeat apply conj; + try apply mulmod_correct0; cbv [valid] in *; destruct_head'_and; auto; []. + rewrite !Z.mul_assoc. + apply Z.mul_mod_Proper; [ | reflexivity ]. + cbv [Z.equiv_modulo]; etransitivity; [ apply mulmod_correct0 | apply f_equal2; lia ]; auto. + Qed. + + Lemma squaremod_correct + : (forall a (_ : valid a), eval (from_montgomery_mod (squaremod a)) mod m + = (eval (from_montgomery_mod a) * eval (from_montgomery_mod a)) mod m) + /\ (forall a (_ : valid a), valid (squaremod a)). + Proof. + split; intros; cbv [squaremod]; apply mulmod_correct; assumption. + Qed. + + Definition encodemod (v : Z) : list Z + := mulmod (Rows.partition weight n v) R2mod. + + Local Ltac t_valid v := + cbv [valid]; repeat apply conj; + auto; cbv [small eval]; autorewrite with push_eval; auto; + rewrite ?UniformWeight.uweight_eq_alt by omega; + Z.rewrite_mod_small; + rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia); + rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia); + try apply Z.mod_pos_bound; subst r; try lia; try reflexivity. + Lemma encodemod_correct + : (forall v, 0 <= v < m -> eval (from_montgomery_mod (encodemod v)) mod m = v mod m) + /\ (forall v, 0 <= v < m -> valid (encodemod v)). + Proof. + split; intros v ?; cbv [encodemod R2mod]; [ rewrite (proj1 mulmod_correct) | apply mulmod_correct ]; + [ | now t_valid v.. ]. + push_Zmod; rewrite !eval_from_montgomery_mod; [ | now t_valid v.. ]. + cbv [eval]; autorewrite with push_eval; auto. + rewrite ?UniformWeight.uweight_eq_alt by omega. + rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia). + rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia). + pull_Zmod. + rewrite <- !Z.mul_assoc; autorewrite with pull_Zpow. + generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const. + rewrite (Z.mul_comm r' r); generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const. + Z.rewrite_mod_small. + reflexivity. + Qed. + + Lemma addmod_correct + : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomery_mod (addmod a b)) mod m + = (eval (from_montgomery_mod a) + eval (from_montgomery_mod b)) mod m) + /\ (forall a (_ : valid a) b (_ : valid b), valid (addmod a b)). + Proof. + repeat apply conj; intros; + push_Zmod; rewrite ?eval_from_montgomery_mod; pull_Zmod; repeat apply conj; + cbv [valid addmod] in *; destruct_head'_and; auto; + try rewrite m_enc_correct_montgomery; + try (eapply small_add || eapply add_bound); rewrite <- ?m_enc_correct_montgomery; eauto with omega; []. + push_Zmod; erewrite eval_add by (rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery. + break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia. + Qed. + + Lemma submod_correct + : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomery_mod (submod a b)) mod m + = (eval (from_montgomery_mod a) - eval (from_montgomery_mod b)) mod m) + /\ (forall a (_ : valid a) b (_ : valid b), valid (submod a b)). + Proof. + repeat apply conj; intros; + push_Zmod; rewrite ?eval_from_montgomery_mod; pull_Zmod; repeat apply conj; + cbv [valid submod] in *; destruct_head'_and; auto; + try rewrite m_enc_correct_montgomery; + try (eapply small_sub || eapply sub_bound); rewrite <- ?m_enc_correct_montgomery; eauto with omega; []. + push_Zmod; erewrite eval_sub by (rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery. + break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia. + Qed. + + Lemma oppmod_correct + : (forall a (_ : valid a), eval (from_montgomery_mod (oppmod a)) mod m + = (-eval (from_montgomery_mod a)) mod m) + /\ (forall a (_ : valid a), valid (oppmod a)). + Proof. + repeat apply conj; intros; + push_Zmod; rewrite ?eval_from_montgomery_mod; pull_Zmod; repeat apply conj; + cbv [valid oppmod] in *; destruct_head'_and; auto; + try rewrite m_enc_correct_montgomery; + try (eapply small_opp || eapply opp_bound); rewrite <- ?m_enc_correct_montgomery; eauto with omega; []. + push_Zmod; erewrite eval_opp by (rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery. + break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia. + Qed. + + Lemma nonzeromod_correct + : (forall a (_ : valid a), (nonzeromod a = 0) <-> ((eval (from_montgomery_mod a)) mod m = 0)). + Proof. + intros a Ha; rewrite eval_from_montgomery_mod by assumption. + cbv [nonzeromod valid] in *; destruct_head'_and. + rewrite eval_nonzero; try eassumption; [ | subst r; apply conj; try eassumption; omega.. ]. + split; intro H'; [ rewrite H'; autorewrite with zsimplify_const; reflexivity | ]. + assert (H'' : ((eval a * r'^n) * r^n) mod m = 0) + by (revert H'; push_Zmod; intro H'; rewrite H'; autorewrite with zsimplify_const; reflexivity). + rewrite <- Z.mul_assoc in H''. + autorewrite with pull_Zpow push_Zmod in H''. + rewrite (Z.mul_comm r' r), r'_correct in H''. + autorewrite with zsimplify_const pull_Zmod in H''; [ | lia.. ]. + clear H'. + generalize dependent (eval a); clear. + intros z ???. + assert (z / m = 0) by (Z.div_mod_to_quot_rem; nia). + Z.div_mod_to_quot_rem; nia. + Qed. + + Lemma to_bytesmod_correct + : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) + = eval a mod m) + /\ (forall a (_ : valid a), to_bytesmod a = Rows.partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)). + Proof. + generalize (@length_small bitwidth n); + cbv [valid small to_bytesmod eval]; split; intros; (etransitivity; [ apply eval_to_bytesmod | ]); + fold weight in *; fold (UniformWeight.uweight 8) in *; subst r; + try solve [ intuition eauto with omega ]. + all: repeat first [ rewrite UniformWeight.uweight_eq_alt by omega + | omega + | reflexivity + | progress Z.rewrite_mod_small ]. + Qed. + End modops. +End WordByWordMontgomery. |