(* TODO: prune these *) Require Import Crypto.Algebra.Nsatz. Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. Require Import Coq.Sorting.Permutation. Require Import Coq.derive.Derive. Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *) Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *) Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. Require Import Crypto.Arithmetic.BarrettReduction.Generalized. Require Import Crypto.Arithmetic.ModularArithmeticTheorems. Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.Tactics.RunTacticAsConstr. Require Import Crypto.Util.Tactics.Head. Require Import Crypto.Util.Option. Require Import Crypto.Util.OptionList. Require Import Crypto.Util.Prod. Require Import Crypto.Util.Sum. Require Import Crypto.Util.Bool. Require Import Crypto.Util.Sigma. Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. Require Import Crypto.Util.ZUtil.Tactics.PeelLe. Require Import Crypto.Util.ZUtil.Tactics.LinearSubstitute. Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. Require Import Crypto.Util.ZUtil.Modulo.PullPush. Require Import Crypto.Util.ZUtil.Opp. Require Import Crypto.Util.ZUtil.Log2. Require Import Crypto.Util.ZUtil.Le. Require Import Crypto.Util.ZUtil.Hints.PullPush. Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. 
Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem.
Require Import Crypto.Util.Tactics.SpecializeBy.
Require Import Crypto.Util.Tactics.SplitInContext.
Require Import Crypto.Util.Tactics.SubstEvars.
Require Import Crypto.Util.Notations.
Require Import Crypto.Util.ZUtil.Definitions.
Require Import Crypto.Util.ZUtil.Sorting.
Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi.
Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo.
Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit.
Require Import Crypto.Util.ZUtil.Hints.Core.
Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div.
Require Import Crypto.Util.ZUtil.Hints.PullPush.
Require Import Crypto.Util.ZUtil.EquivModulo.
Require Import Crypto.Util.Prod.
Require Import Crypto.Util.CPSNotations.
Require Import Crypto.Util.Equality.
Require Import Crypto.Util.Tactics.SetEvars.
Import Coq.Lists.List ListNotations.
Local Open Scope Z_scope.

(** * Word-by-word Montgomery arithmetic on lists of limbs *)
Module WordByWordMontgomery.
Import Partition.
Local Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg.
Section with_args.
(** [lgr] is the number of bits per limb.  [m] is the modulus; the
    context variable [N] declared further down is commented as its limb
    encoding. *)
Context (lgr : Z)
        (m : Z).
(** Limb weights for [lgr]-bit limbs (see [UniformWeight.uweight]). *)
Local Notation weight := (UniformWeight.uweight lgr).
(** [T n] is meant to hold [n] limbs.  The length appears only in the type
    index; [list Z] itself does not enforce it. *)
Let T (n : nat) := list Z.
(** The limb radix. *)
Let r := (2^lgr).
(** Integer value of a limb list, read with the first [n] weights. *)
Definition eval {n} : T n -> Z := Positional.eval weight n.
(** [n] zero limbs. *)
Let zero {n} : T n := Positional.zeros n.
(** Split off the lowest limb, via [Rows.divmod].  Lemmas [eval_div] and
    [eval_mod] later in the file show the pair is
    [(eval v / r, eval v mod r)] for small inputs. *)
Let divmod {n} : T (S n) -> T n * Z := Rows.divmod.
(** Scalar product [c * p] into [S n] limbs; the carry output of
    [Rows.mul] is discarded. *)
Let scmul {n} (c : Z) (p : T n) : T (S n) (* uses double-output multiply *)
  := let '(v, c) := Rows.mul weight r n (S n) (Positional.extend_to_length 1 n [c]) p in
     v.
(** Sum of two [n]-limb values, with the final carry appended as an extra
    limb. *)
Let addT {n} (p q : T n) : T (S n) (* joins carry *)
  := let '(v, c) := Rows.add weight n p q in
     v ++ [c].
(** Sum of an [S n]-limb value and an [n]-limb value (zero-extended), with
    the final carry dropped. *)
Let drop_high_addT' {n} (p : T (S n)) (q : T n) : T (S n) (* drops carry *)
  := fst (Rows.add weight (S n) p (Positional.extend_to_length n (S n) q)).
(** Subtract [N] when [N <= arg], then keep only the low [n] limbs. *)
Let conditional_sub {n} (arg : T (S n)) (N : T n) : T n (* computes [arg - N] if [N <= arg], and drops high bit *)
  := Positional.drop_high_to_length n (Rows.conditional_sub weight (S n) arg (Positional.extend_to_length n (S n) N)).
Context (R_numlimbs : nat)
        (N : T R_numlimbs). (* encoding of m *)
(** Modular subtraction: [a - b], plus [N] when the difference is
    negative.  The [r-1] argument is the per-limb mask passed to
    [Rows.sub_then_maybe_add]. *)
Let sub_then_maybe_add (a b : T R_numlimbs) : T R_numlimbs (* computes [a - b + if (a - b) <? 0 then N else 0] *)
  := fst (Rows.sub_then_maybe_add weight R_numlimbs (r-1) a b N).
Local Opaque T.
(** One iteration of word-by-word REDC.  It consumes the lowest limb [a]
    of [A] and updates the accumulator [S]. *)
Section Iteration.
Context (pred_A_numlimbs : nat)
        (B : T R_numlimbs) (k : Z)
        (A : T (S pred_A_numlimbs))
        (S : T (S R_numlimbs)).
(* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *)
(** [A'] is [A] without its lowest limb, and [a] is that lowest limb. *)
Local Definition A_a := dlet p := @divmod _ A in p.
Local Definition A' := fst A_a.
Local Definition a := snd A_a.
(** [S1] is [S + a * B]. *)
Local Definition S1 := @addT _ S (@scmul _ a B).
(** [s] is the lowest limb of [S1]. *)
Local Definition s := snd (@divmod _ S1).
(** [q] is the low word of [s * k] (cf. [Z.mul_split_mod]). *)
Local Definition q := fst (Z.mul_split r s k).
(** [S2] is [S1 + q * N], with the carry out of the top limb dropped. *)
Local Definition S2 := @drop_high_addT' _ S1 (@scmul _ q N).
(** [S3'] is [S2] shifted down by one limb. *)
Local Definition S3' := fst (@divmod _ S2).

(** The same computation as the pair [(A', S3')], written as a single
    [dlet] chain that shares its intermediate results. *)
Local Definition A'_S3
  := dlet A_a := @divmod _ A in
     dlet A' := fst A_a in
     dlet a := snd A_a in
     dlet S1 := @addT _ S (@scmul _ a B) in
     dlet s := snd (@divmod _ S1) in
     dlet q := fst (Z.mul_split r s k) in
     dlet S2 := @drop_high_addT' _ S1 (@scmul _ q N) in
     dlet S3 := fst (@divmod _ S2) in
     (A', S3).

Lemma A'_S3_alt : A'_S3 = (A', S3').
Proof using Type. cbv [A'_S3 A' S3' Let_In S2 q s S1 A' a A_a]; reflexivity. Qed.
End Iteration.

(** The REDC driver:
    - [redc_body] runs one iteration on a pair (remaining A, accumulator);
    - [redc_loop count] runs [count] iterations;
    - [pre_redc] starts from a zero accumulator and runs one iteration per
      limb of [A];
    - [redc] finishes with a conditional subtraction of [N]. *)
Section loop.
Context (A_numlimbs : nat)
        (A : T A_numlimbs)
        (B : T R_numlimbs)
        (k : Z)
        (S' : T (S R_numlimbs)).

Definition redc_body {pred_A_numlimbs} : T (S pred_A_numlimbs) * T (S R_numlimbs)
                                         -> T pred_A_numlimbs * T (S R_numlimbs)
  := fun '(A, S')
     => A'_S3 _ B k A S'.

Definition redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs)
  := nat_rect
       (fun count => T count * _ -> _)
       (fun A_S => A_S)
       (fun count' redc_loop_count' A_S
        => redc_loop_count' (redc_body A_S))
       count.

Definition pre_redc : T (S R_numlimbs)
  := snd (redc_loop A_numlimbs (A, @zero (1 + R_numlimbs)%nat)).

Definition redc : T R_numlimbs
  := conditional_sub pre_redc N.
End loop.

Create HintDb word_by_word_montgomery.
Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery.

(** Modular addition, subtraction and negation on [R_numlimbs]-limb values. *)
Definition add (A B : T R_numlimbs) : T R_numlimbs
  := conditional_sub (@addT _ A B) N.
Definition sub (A B : T R_numlimbs) : T R_numlimbs
  := sub_then_maybe_add A B.
Definition opp (A : T R_numlimbs) : T R_numlimbs
  := sub (@zero _) A.
(** Bitwise OR of all limbs.  For small inputs this is [0] exactly when the
    value is [0] (lemma [eval_nonzero]). *)
Definition nonzero (A : list Z) : Z
  := fold_right Z.lor 0 A.
Context (lgr_big : 0 < lgr)
        (R_numlimbs_nz : R_numlimbs <> 0%nat).
(** The Montgomery radix [R = r^R_numlimbs]. *)
Let R := (r^Z.of_nat R_numlimbs).
Transparent T.
(** A limb list is [small] when it equals the canonical [partition] of its
    own value. *)
Definition small {n} (v : T n) : Prop
  := v = partition weight n (eval v).
(** REDC hypotheses:
    - [N] is small, positive and below [R];
    - [B] is small and below [R];
    - [ri] inverts [r] modulo [eval N];
    - [k * eval N] is congruent to [-1] modulo [r]. *)
Context (small_N : small N)
        (N_lt_R : eval N < R)
        (N_nz : 0 < eval N)
        (B : T R_numlimbs)
        (B_bounds : 0 <= eval B < R)
        (small_B : small B)
        ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N))
        (k : Z) (k_correct : k * eval N mod r = (-1) mod r).

Local Lemma r_big : r > 1.
Proof using lgr_big. clear -lgr_big; subst r. auto with zarith. Qed.
Local Notation wprops := (@UniformWeight.uwprops lgr lgr_big).

Local Hint Immediate (wprops).
Local Hint Immediate (weight_0 wprops).
Local Hint Immediate (weight_positive wprops).
Local Hint Immediate (weight_multiples wprops).
Local Hint Immediate (weight_divides wprops).
Local Hint Immediate r_big.

(** A [small] list of type [T n] has exactly [n] limbs. *)
Lemma length_small {n v}
  : @small n v -> length v = n.
Proof using Type.
  clear; cbv [small]; intro H; rewrite H; autorewrite with distr_length; reflexivity.
Qed.
(** The value of a [small] list lies in [0, weight n). *)
Lemma small_bound {n v}
  : @small n v -> 0 <= eval v < weight n.
Proof using lgr_big.
  clear - lgr_big; cbv [small eval]; intro H; rewrite H; autorewrite with push_eval; auto with zarith.
Qed.

Lemma R_plusR_le : R + R <= weight (S R_numlimbs).
Proof using lgr_big.
  clear - lgr_big.
  etransitivity; [ | apply UniformWeight.uweight_double_le; omega ].
  rewrite UniformWeight.uweight_eq_alt by omega.
  subst r R; omega.
Qed.

(** Masking every limb of a partition with [r - 1] leaves it unchanged.
    This is used by [sub_then_maybe_add_correct]. *)
Lemma mask_r_sub1 n x :
  map (Z.land (r - 1)) (partition weight n x) = partition weight n x.
Proof using lgr_big.
  clear - lgr_big. cbv [partition].
  rewrite map_map. apply map_ext; intros.
  rewrite UniformWeight.uweight_S by omega.
  rewrite <-Z.mod_pull_div by auto with zarith.
  replace (r - 1) with (Z.ones lgr) by (rewrite Z.ones_equiv; subst r; reflexivity).
  rewrite <-Z.land_comm, Z.land_ones by omega.
  auto with zarith.
Qed.

Let partition_Proper := (@partition_Proper _ wprops).
Local Existing Instance partition_Proper.
(** For small inputs, the OR of all limbs is zero iff the value is zero. *)
Lemma eval_nonzero n A : @small n A -> nonzero A = 0 <-> @eval n A = 0.
Proof using lgr_big.
  clear -lgr_big partition_Proper.
  cbv [nonzero eval small]; intro Heq.
  do 2 rewrite Heq.
  rewrite !eval_partition, Z.mod_mod by auto.
  generalize (Positional.eval weight n A); clear Heq A.
  induction n as [|n IHn].
  { cbn; rewrite weight_0 by auto; intros; autorewrite with zsimplify_const; omega. }
  { intro; rewrite partition_step.
    rewrite fold_right_snoc, Z.lor_comm, <- fold_right_push, Z.lor_eq_0_iff by auto using Z.lor_assoc.
    assert (Heq : Z.equiv_modulo (weight n) (z mod weight (S n)) (z mod (weight n))).
    { cbv [Z.equiv_modulo].
      generalize (weight_multiples ltac:(auto) n).
      generalize (weight_positive ltac:(auto) n).
      generalize (weight_positive ltac:(auto) (S n)).
      generalize (weight (S n)) (weight n); clear; intros wsn wn.
      clear; intros.
      Z.div_mod_to_quot_rem; subst.
      autorewrite with zsimplify_const in *.
      Z.linear_substitute_all.
      apply Zminus_eq; ring_simplify.
      rewrite <- !Z.add_opp_r, !Z.mul_opp_comm, <- !Z.mul_opp_r, <- !Z.mul_assoc.
      rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0.
      nia. }
    rewrite Heq at 1; rewrite IHn.
    rewrite Z.mod_mod by auto.
    generalize (weight_multiples ltac:(auto) n).
    generalize (weight_positive ltac:(auto) n).
    generalize (weight_positive ltac:(auto) (S n)).
    generalize (weight (S n)) (weight n); clear; intros wsn wn; intros.
    Z.div_mod_to_quot_rem.
    repeat (intro || apply conj); destruct_head'_or; try omega; destruct_head'_and; subst; autorewrite with zsimplify_const in *; try nia; Z.linear_substitute_all.
    all: apply Zminus_eq; ring_simplify.
    all: rewrite <- ?Z.add_opp_r, ?Z.mul_opp_comm, <- ?Z.mul_opp_r, <- ?Z.mul_assoc.
    all: rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0.
    all: nia. }
Qed.

(** One simplification step for row/partition goals. *)
Local Ltac push_step :=
  first [ progress eta_expand
        | rewrite Rows.mul_partitions
        | rewrite Rows.mul_div
        | rewrite Rows.add_partitions
        | rewrite Rows.add_div
        | progress autorewrite with push_eval distr_length
        | match goal with
          | [ H : ?v = _ |- context[length ?v] ] => erewrite length_small by eassumption
          | [ H : small ?v |- context[length ?v] ] => erewrite length_small by eassumption
          end
        | rewrite Positional.eval_cons by distr_length
        | progress rewrite ?weight_0, ?UniformWeight.uweight_1 by auto;
          autorewrite with zsimplify_fast
        | rewrite (weight_0 wprops)
        | rewrite <- Z.div_mod'' by auto with omega
        | solve [ trivial ] ].
Local Ltac push := repeat push_step.

Local Ltac t_step :=
  match goal with
  | [ H := _ |- _ ] => progress cbv [H] in *
  | _ => progress push_step
  | _ => progress autorewrite with zsimplify_const
  | _ => solve [ auto with omega ]
  end.

Local Hint Unfold eval zero small divmod scmul drop_high_addT' addT R : loc.

Local Lemma eval_zero : forall n, eval (@zero n) = 0.
Proof using Type.
  clear; autounfold with loc; intros; autorewrite with push_eval; auto.
Qed.
Local Lemma small_zero : forall n, small (@zero n).
Proof using Type.
  etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity.
Qed.
Local Hint Immediate small_zero.

Ltac push_recursive_partition :=
  repeat match goal with
         | _ => progress cbn [recursive_partition]
         | H : small _ |- _ => rewrite H; clear H
         | _ => rewrite recursive_partition_equiv by auto using wprops
         | _ => rewrite UniformWeight.uweight_eval_shift by distr_length
         | _ => progress push
         end.

(** [divmod] splits a small value into [(value / r, value mod r)]. *)
Lemma eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r.
Proof using lgr_big.
  pose proof r_big as r_big.
  clear - r_big lgr_big; intros; autounfold with loc.
  push_recursive_partition; cbn [Rows.divmod fst tl].
  autorewrite with zsimplify; reflexivity.
Qed.
Lemma eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r.
Proof using lgr_big.
  clear - lgr_big; intros; autounfold with loc.
  push_recursive_partition; cbn [Rows.divmod snd hd].
  autorewrite with zsimplify; reflexivity.
Qed.
Lemma small_div : forall n v, small v -> small (fst (@divmod n v)).
Proof using lgr_big.
  pose proof r_big as r_big.
  clear - r_big lgr_big. intros; autounfold with loc.
  push_recursive_partition. cbn [Rows.divmod fst tl].
  rewrite <-recursive_partition_equiv by auto.
  rewrite <-UniformWeight.uweight_recursive_partition_equiv with (i:=1%nat) by omega.
  push.
  apply Partition.partition_Proper; [ solve [auto] | ].
  cbv [Z.equiv_modulo]. autorewrite with zsimplify.
  reflexivity.
Qed.

(** [v] is the canonical [n]-limb encoding of [x], and [x] fits in [n]
    limbs. *)
Definition canon_rep {n} x (v : T n) : Prop :=
  (v = partition weight n x) /\ (0 <= x < weight n).
Lemma eval_canon_rep n x v : @canon_rep n x v -> eval v = x.
Proof using lgr_big.
  clear - lgr_big.
  cbv [canon_rep eval]; intros [Hv Hx].
  rewrite Hv. autorewrite with push_eval.
  auto using Z.mod_small.
Qed.
Lemma small_canon_rep n x v : @canon_rep n x v -> small v.
Proof using lgr_big.
  clear - lgr_big.
  cbv [canon_rep eval small]; intros [Hv Hx].
  rewrite Hv. autorewrite with push_eval.
  apply partition_eq_mod; auto; [ ].
  Z.rewrite_mod_small; reflexivity.
Qed.

(** Canonical-representation facts for each primitive; the [eval_*] and
    [small_*] corollaries below follow from these. *)
Local Lemma scmul_correct: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> canon_rep (a * eval v) (@scmul n a v).
Proof using lgr_big.
  pose proof r_big as r_big.
  clear - lgr_big r_big.
  autounfold with loc; intro n; destruct (zerop n); intros until 0; intro Hsmall; intros.
  { intros; subst; cbn; rewrite Z.add_with_get_carry_full_mod.
    split; cbn; autorewrite with zsimplify_fast; auto with zarith. }
  { rewrite (surjective_pairing (Rows.mul _ _ _ _ _ _)).
    rewrite Rows.mul_partitions by (try rewrite Hsmall; auto using length_partition, Positional.length_extend_to_length with omega).
    autorewrite with push_eval.
    rewrite Positional.eval_cons by reflexivity.
    rewrite weight_0 by auto.
    autorewrite with push_eval zsimplify_fast.
    split; [reflexivity | ].
    rewrite UniformWeight.uweight_S, UniformWeight.uweight_eq_alt by omega.
    subst r; nia. }
Qed.

Local Lemma addT_correct : forall n a b, small a -> small b -> canon_rep (eval a + eval b) (@addT n a b).
Proof using lgr_big.
  intros n a b Ha Hb.
  generalize (length_small Ha); generalize (length_small Hb).
  generalize (small_bound Ha); generalize (small_bound Hb).
  clear -lgr_big Ha Hb.
  autounfold with loc; destruct (zerop n); subst.
  { destruct a, b; cbn; try omega; split; auto with zarith. }
  { pose proof (UniformWeight.uweight_double_le lgr ltac:(omega) n).
    eta_expand; split; [ | lia ].
    rewrite Rows.add_partitions, Rows.add_div by auto.
    rewrite partition_step.
    Z.rewrite_mod_small; reflexivity. }
Qed.

Local Lemma drop_high_addT'_correct : forall n a b, small a -> small b -> canon_rep ((eval a + eval b) mod (r^Z.of_nat (S n))) (@drop_high_addT' n a b).
Proof using lgr_big.
  intros n a b Ha Hb; generalize (length_small Ha); generalize (length_small Hb).
  clear -lgr_big Ha Hb.
  autounfold with loc in *; subst; intros.
  rewrite Rows.add_partitions by auto using Positional.length_extend_to_length.
  autorewrite with push_eval.
  split; try apply partition_eq_mod; auto; rewrite UniformWeight.uweight_eq_alt by omega; subst r; Z.rewrite_mod_small; auto with zarith.
Qed.

Local Lemma conditional_sub_correct : forall v, small v -> 0 <= eval v < eval N + R -> canon_rep (eval v + if eval N <=? eval v then -eval N else 0) (conditional_sub v N).
Proof using small_N lgr_big N_nz N_lt_R.
  pose proof R_plusR_le as R_plusR_le.
  clear - small_N lgr_big N_nz N_lt_R R_plusR_le.
  intros; autounfold with loc; cbv [conditional_sub].
  repeat match goal with H : small _ |- _ => rewrite H; clear H end.
  autorewrite with push_eval.
  assert (weight R_numlimbs < weight (S R_numlimbs)) by (rewrite !UniformWeight.uweight_eq_alt by omega; autorewrite with push_Zof_nat; auto with zarith).
  assert (eval N mod weight R_numlimbs < weight (S R_numlimbs)) by (pose proof (Z.mod_pos_bound (eval N) (weight R_numlimbs) ltac:(auto)); omega).
  rewrite Rows.conditional_sub_partitions by (repeat (autorewrite with distr_length push_eval; auto using partition_eq_mod with zarith)).
  rewrite drop_high_to_length_partition by omega.
  autorewrite with push_eval.
  assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst R; reflexivity).
  Z.rewrite_mod_small.
  break_match; autorewrite with zsimplify_fast; Z.ltb_to_lt.
  { split; [ reflexivity | ].
    rewrite Z.add_opp_r. fold (eval N).
    auto using Z.mod_small with lia. }
  { split; auto using Z.mod_small with lia. }
Qed.

(** For reduced inputs, [sub_then_maybe_add] gives the canonical encoding
    of [a - b], plus [N] when the difference is negative. *)
Local Lemma sub_then_maybe_add_correct : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> canon_rep (eval a - eval b + if eval a - eval b <? 0 then eval N else 0) (sub_then_maybe_add a b).
Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R.
  (* [mask_r_sub1] mentions section variables, so pose it before [clear]. *)
  pose proof mask_r_sub1 as mask_r_sub1.
  clear - small_N lgr_big R_numlimbs_nz N_nz N_lt_R mask_r_sub1.
  intros; autounfold with loc; cbv [sub_then_maybe_add].
  repeat match goal with H : small _ |- _ => rewrite H; clear H end.
  rewrite Rows.sub_then_maybe_add_partitions by (autorewrite with push_eval distr_length; auto with zarith).
  autorewrite with push_eval.
  assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst r R; reflexivity).
  Z.rewrite_mod_small.
  split; [ reflexivity | ].
  break_match; Z.ltb_to_lt; lia.
Qed.

Local Lemma eval_scmul: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> eval (@scmul n a v) = a * eval v.
Proof using lgr_big. eauto using scmul_correct, eval_canon_rep. Qed.
Local Lemma small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> small (@scmul n a v).
Proof using lgr_big. eauto using scmul_correct, small_canon_rep. Qed.
Local Lemma eval_addT : forall n a b, small a -> small b -> eval (@addT n a b) = eval a + eval b.
Proof using lgr_big. eauto using addT_correct, eval_canon_rep. Qed.
Local Lemma small_addT : forall n a b, small a -> small b -> small (@addT n a b).
Proof using lgr_big. eauto using addT_correct, small_canon_rep. Qed.
Local Lemma eval_drop_high_addT' : forall n a b, small a -> small b -> eval (@drop_high_addT' n a b) = (eval a + eval b) mod (r^Z.of_nat (S n)).
Proof using lgr_big. eauto using drop_high_addT'_correct, eval_canon_rep. Qed.
Local Lemma small_drop_high_addT' : forall n a b, small a -> small b -> small (@drop_high_addT' n a b).
Proof using lgr_big. eauto using drop_high_addT'_correct, small_canon_rep. Qed.
Local Lemma eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0.
Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using conditional_sub_correct, eval_canon_rep. Qed.
Local Lemma small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v N).
Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using conditional_sub_correct, small_canon_rep. Qed.
Local Lemma eval_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b <? 0 then eval N else 0.
Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using sub_then_maybe_add_correct, eval_canon_rep. Qed.
Local Lemma small_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> small (sub_then_maybe_add a b).
Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using sub_then_maybe_add_correct, small_canon_rep. Qed.

(** From here on the primitives are handled only through the lemmas
    above. *)
Local Opaque T addT drop_high_addT' divmod zero scmul conditional_sub sub_then_maybe_add.
Create HintDb push_mont_eval discriminated.
(* NOTE(review): this repeats the [Create HintDb] / [Hint Unfold] pair
   issued right after [End loop]; the same unfold hints are registered
   again immediately. *)
Create HintDb word_by_word_montgomery.
Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery.
Let r_big' := r_big. (* to put it in the context *)

(** Discharge [small]-ness and bound side goals. *)
Local Ltac t_small :=
  repeat first [ assumption
               | apply small_addT
               | apply small_drop_high_addT'
               | apply small_div
               | apply small_zero
               | apply small_scmul
               | apply small_conditional_sub
               | apply small_sub_then_maybe_add
               | apply Z_mod_lt
               | rewrite Z.mul_split_mod
               | solve [ auto with zarith ]
               | lia
               | progress autorewrite with push_mont_eval
               | progress autounfold with word_by_word_montgomery
               | match goal with
                 | [ H : and _ _ |- _ ] => destruct H
                 end ].
Hint Rewrite
     eval_zero
     eval_div
     eval_mod
     eval_addT
     eval_drop_high_addT'
     eval_scmul
     eval_conditional_sub
     eval_sub_then_maybe_add
     using (repeat autounfold with word_by_word_montgomery; t_small)
  : push_mont_eval.

Local Arguments eval {_} _.
Local Arguments small {_} _.
Local Arguments divmod {_} _.

(* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *)
Section Iteration_proofs.
Context (pred_A_numlimbs : nat)
        (A : T (S pred_A_numlimbs))
        (S : T (S R_numlimbs))
        (small_A : small A)
        (small_S : small S)
        (S_nonneg : 0 <= eval S).
(* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *)
Local Coercion eval : T >-> Z.
Local Notation a := (@a pred_A_numlimbs A).
Local Notation A' := (@A' pred_A_numlimbs A).
Local Notation S1 := (@S1 pred_A_numlimbs B A S).
Local Notation s := (@s pred_A_numlimbs B A S).
Local Notation q := (@q pred_A_numlimbs B k A S).
Local Notation S2 := (@S2 pred_A_numlimbs B k A S).
Local Notation S3 := (@S3' pred_A_numlimbs B k A S).
(** The exact quotient that one iteration computes, before any
    truncation. *)
Local Notation eval_pre_S3 := ((S + a * B + q * N) / r).
(** [S3] evaluates to the pre-reduction quotient taken modulo
    [r * r^R_numlimbs]. *)
Lemma eval_S3_eq : eval S3 = eval_pre_S3 mod (r * r ^ Z.of_nat R_numlimbs).
Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big.
  clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big.
  unfold S3, S2, S1.
  autorewrite with push_mont_eval push_Zof_nat; [].
  rewrite !Z.pow_succ_r, <- ?Z.mul_assoc by omega.
  rewrite Z.mod_pull_div by Z.zero_bounds.
  do 2 f_equal; nia.
Qed.

(** If [S < N + B] before the iteration, the pre-reduction quotient also
    stays below [N + B]. *)
Lemma pre_S3_bound : eval S < eval N + eval B -> eval_pre_S3 < eval N + eval B.
Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big.
  clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big.
  assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) by (intros x y; pose proof (Z_mod_lt x y); omega).
  intro HS.
  eapply Z.le_lt_trans.
  { (* Bound each summand: a and q are at most r - 1, and S is at most N + B - 1. *)
    transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ].
    Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; autorewrite with push_mont_eval; try Z.zero_bounds; auto with lia. }
  rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia.
  autorewrite with zsimplify.
  simpl; omega.
Qed.

(** The pre-reduction quotient is non-negative. *)
Lemma pre_S3_nonneg : 0 <= eval_pre_S3.
Proof using N_nz B_bounds small_B small_A small_S S_nonneg lgr_big.
  clear -N_nz B_bounds small_B partition_Proper r_big' small_A small_S S_nonneg.
  repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; autorewrite with push_mont_eval; [].
  rewrite ?Npos_correct; Z.zero_bounds; lia.
Qed.

(** Both outputs of one iteration are [small]. *)
Lemma small_A' : small A'.
Proof using small_A lgr_big. repeat autounfold with word_by_word_montgomery; t_small. Qed.

Lemma small_S3 : small S3.
Proof using small_A small_S small_N N_lt_R N_nz B_bounds small_B lgr_big.
  clear -small_A small_S small_N N_lt_R N_nz B_bounds small_B partition_Proper r_big'.
  repeat autounfold with word_by_word_montgomery; t_small.
Qed.
(** [S3] is non-negative. *)
Lemma S3_nonneg : 0 <= eval S3.
Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big.
  clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big sub_then_maybe_add.
  rewrite eval_S3_eq; Z.zero_bounds.
Qed.

(** One iteration preserves the invariant [S < N + B]. *)
Lemma S3_bound : eval S < eval N + eval B -> eval S3 < eval N + eval B.
Proof using N_nz B_bounds small_B small_A small_S S_nonneg B_bounds N_nz N_lt_R small_N lgr_big.
  clear -N_nz B_bounds small_B small_A small_S S_nonneg B_bounds N_nz N_lt_R small_N lgr_big partition_Proper r_big' sub_then_maybe_add.
  rewrite eval_S3_eq.
  intro H; pose proof (pre_S3_bound H); pose proof pre_S3_nonneg.
  subst R.
  rewrite Z.mod_small by nia.
  assumption.
Qed.

(** The first partial sum evaluates to [S + a * B]. *)
Lemma S1_eq : eval S1 = S + a*B.
Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S.
  clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper.
  cbv [S1 a A'].
  repeat autorewrite with push_mont_eval.
  reflexivity.
Qed.

(** Adding [q * N] makes the partial sum divisible by [r].  The proof uses
    [k_correct] after moving to [F.of_Z] residues modulo [r]. *)
Lemma S2_mod_r_helper : (S + a*B + q * N) mod r = 0.
Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct.
  clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct.
  cbv [S2 q s]; autorewrite with push_mont_eval; rewrite S1_eq.
  assert (r > 0) by lia.
  assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1).
  { destruct (Z.eq_dec r 1) as [H'|H'].
    { rewrite H'; split; reflexivity. }
    { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } }
  autorewrite with pull_Zmod.
  replace 0 with (0 mod r) by apply Zmod_0_l.
  pose (Z.to_pos r) as r'.
  replace r with (Z.pos r') by (subst r'; rewrite Z2Pos.id; lia).
  eapply F.eq_of_Z_iff.
  rewrite Z.mul_split_mod.
  repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod.
  rewrite <-!Algebra.Hierarchy.associative.
  replace ((F.of_Z r' k * F.of_Z r' (eval N))%F) with (F.opp (m:=r') F.one).
  { cbv [F.of_Z F.add]; simpl.
    apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ].
    simpl.
    subst r'; rewrite Z2Pos.id by lia.
    rewrite (proj1 Hr), Z.mul_sub_distr_l.
    push_Zmod; pull_Zmod.
    apply (f_equal2 Z.modulo); omega. }
  { rewrite <- F.of_Z_mul.
    rewrite F.of_Z_mod.
    subst r'; rewrite Z2Pos.id by lia.
    rewrite k_correct.
    cbv [F.of_Z F.add F.opp F.one]; simpl.
    change (-(1)) with (-1) in *.
    apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl.
    rewrite Z2Pos.id by lia.
    rewrite (proj1 Hr), (proj2 Hr); Z.rewrite_mod_small; reflexivity. }
Qed.

(** The pre-reduction quotient is congruent to [(S + a * B) * ri]
    modulo [N]. *)
Lemma pre_S3_mod_N : eval_pre_S3 mod N = (S + a*B)*ri mod N.
Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct ri_correct.
  clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct ri_correct sub_then_maybe_add.
  pose proof fun a => Z.div_to_inv_modulo N a r ri ltac:(lia) ri_correct as HH; cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH.
  etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| rewrite (fun a => Z.mul_mod_l a ri N); reflexivity].
  rewrite S2_mod_r_helper.
  push_Zmod; pull_Zmod; autorewrite with zsimplify_const.
  reflexivity.
Qed.

(** Under the invariant [S < N + B], [S3] itself is congruent to
    [(S + a * B) * ri] modulo [N]. *)
Lemma S3_mod_N (Hbound : eval S < eval N + eval B) : S3 mod N = (S + a*B)*ri mod N.
Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct ri_correct small_N N_lt_R N_nz S_nonneg.
  clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct ri_correct N_nz N_lt_R small_N sub_then_maybe_add Hbound S_nonneg.
  rewrite eval_S3_eq.
  pose proof (pre_S3_bound Hbound); pose proof pre_S3_nonneg.
  rewrite (Z.mod_small _ (r * _)) by (subst R; nia).
  apply pre_S3_mod_N.
Qed.
End Iteration_proofs.

(** Correctness of the full REDC loop and its final conditional
    subtraction. *)
Section redc_proofs.
Local Notation redc_body := (@redc_body B k).
Local Notation redc_loop := (@redc_loop B k).
Local Notation pre_redc A := (@pre_redc _ A B k).
Local Notation redc A := (@redc _ A B k).

(** Facts about one [redc_body] step on a pair [A_S]. *)
Section body.
Context (pred_A_numlimbs : nat)
        (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)).
Let A:=fst A_S.
Let S:=snd A_S.
Let A_a:=divmod A.
Let a:=snd A_a.
Context (small_A : small A)
        (small_S : small S)
        (S_bound : 0 <= eval S < eval N + eval B).

(** One step keeps both components [small]. *)
Lemma small_fst_redc_body : small (fst (redc_body A_S)).
Proof using S_bound small_A small_S lgr_big. destruct A_S; apply small_A'; assumption. Qed.
Lemma small_snd_redc_body : small (snd (redc_body A_S)).
Proof using small_S small_N small_B small_A lgr_big S_bound B_bounds N_nz N_lt_R. destruct A_S; unfold redc_body; apply small_S3; assumption. Qed.
Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)).
Proof using small_S small_N small_B small_A lgr_big S_bound N_nz N_lt_R B_bounds. destruct A_S; apply S3_nonneg; assumption. Qed.

(** The new accumulator is congruent to [(S + a * B) * ri] modulo [N]. *)
Lemma snd_redc_body_mod_N
  : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N).
Proof using small_S small_N small_B small_A ri_correct lgr_big k_correct S_bound R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_S small_N small_B small_A ri_correct k_correct S_bound R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add r_big' partition_Proper.
  destruct A_S; apply S3_mod_N; auto; omega.
Qed.

(** The remaining multiplicand is the old one divided by [r]. *)
Lemma fst_redc_body
  : (eval (fst (redc_body A_S))) = eval (fst A_S) / r.
Proof using small_S small_A S_bound lgr_big.
  destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl.
  autorewrite with push_mont_eval.
  reflexivity.
Qed.

Lemma fst_redc_body_mod_N
  : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N).
Proof using small_S small_A ri_correct lgr_big S_bound.
  rewrite fst_redc_body.
  etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ].
  unfold a, A_a, A.
  autorewrite with push_mont_eval.
  reflexivity.
Qed.

(** One step preserves the bound [S < N + B]. *)
Lemma redc_body_bound
  : eval S < eval N + eval B
    -> eval (snd (redc_body A_S)) < eval N + eval B.
Proof using small_S small_N small_B small_A lgr_big S_bound N_nz N_lt_R B_bounds.
  clear -small_S small_N small_B small_A S_bound N_nz N_lt_R B_bounds r_big' partition_Proper sub_then_maybe_add.
  destruct A_S; apply S3_bound; unfold S in *; cbn [snd] in *; try assumption; try omega.
Qed.
End body.

Local Arguments Z.pow !_ !_.
Local Arguments Z.of_nat !_.
(** Induction on the iteration count, unfolding one [redc_loop] step. *)
Local Ltac induction_loop count IHcount
  := induction count as [|count IHcount]; intros; cbn [redc_loop nat_rect] in *; [ | (*rewrite redc_loop_comm_body in * *) ].

(** Any number of iterations keeps both components [small] and the
    accumulator in [0, N + B). *)
Lemma redc_loop_good count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : (small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)))
    /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
  induction_loop count IHcount; auto; [].
  change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *).
  destruct_head'_and.
  repeat first [ apply conj
               | apply small_fst_redc_body
               | apply small_snd_redc_body
               | apply redc_body_bound
               | apply snd_redc_body_nonneg
               | apply IHcount
               | solve [ auto ] ].
Qed.

Lemma small_redc_loop count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)).
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds. apply redc_loop_good; assumption. Qed.

Lemma redc_loop_bound count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds. apply redc_loop_good; assumption. Qed.

(** Discharge the [small] and bound side conditions of the loop lemmas. *)
Local Ltac handle_IH_small :=
  repeat first [ apply redc_loop_good
               | apply small_fst_redc_body
               | apply small_snd_redc_body
               | apply redc_body_bound
               | apply snd_redc_body_nonneg
               | apply conj
               | progress cbn [fst snd]
               | progress destruct_head' and
               | solve [ auto ] ].
(** After [count] iterations, the remaining multiplicand is
    [fst A_S / r^count]. *)
Lemma fst_redc_loop count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count).
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add Hsmall Hbound.
  cbv [redc_loop]; induction_loop count IHcount.
  { simpl; autorewrite with zsimplify; reflexivity. }
  { rewrite IHcount, fst_redc_body by handle_IH_small.
    change (1 + R_numlimbs)%nat with (S R_numlimbs) in *.
    rewrite Zdiv_Zdiv by Z.zero_bounds.
    rewrite <- (Z.pow_1_r r) at 1.
    rewrite <- Z.pow_add_r by lia.
    replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia.
    reflexivity. }
Qed.

(** The same fact restated modulo [N], with the division replaced by
    multiplication by [ri^count]. *)
Lemma fst_redc_loop_mod_N count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : eval (fst (redc_loop count A_S)) mod (eval N)
    = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count) * ri^(Z.of_nat count) mod (eval N).
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds ri_correct.
  clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add Hsmall Hbound ri_correct.
  rewrite fst_redc_loop by assumption.
  destruct count.
  { simpl; autorewrite with zsimplify; reflexivity. }
  { etransitivity; [ eapply Z.div_to_inv_modulo; try solve [ eassumption | apply Z.lt_gt, Z.pow_pos_nonneg; lia ] | ].
    { erewrite <- Z.pow_mul_l, <- Z.pow_1_l.
      { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. }
      { lia. } }
    reflexivity. }
Qed.

Local Arguments Z.pow : simpl never.
(** Main loop invariant: after [count] iterations the accumulator is
    congruent, modulo [N], to
    [(S + (A mod r^count) * B) * ri^count]. *)
Lemma snd_redc_loop_mod_N count A_S
      (Hsmall : small (fst A_S) /\ small (snd A_S))
      (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
  : (eval (snd (redc_loop count A_S))) mod (eval N)
    = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N).
Proof using small_N small_B ri_correct lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds k_correct.
  clear -small_N small_B ri_correct r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add k_correct Hsmall Hbound.
  cbv [redc_loop].
  induction_loop count IHcount.
  { simpl; autorewrite with zsimplify; reflexivity. }
  { rewrite IHcount by handle_IH_small.
    push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod.
    autorewrite with push_mont_eval; [].
    match goal with
    | [ |- ?x mod ?N = ?y mod ?N ]
      => change (Z.equiv_modulo N x y)
    end.
    destruct A_S as [A S].
    cbn [fst snd].
    change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)).
    rewrite !Z.mul_add_distr_r.
    rewrite <- !Z.mul_assoc.
    replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count))) by (change (Datatypes.S count) with (1 + count)%nat; autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia).
    rewrite <- !Z.add_assoc.
    apply Z.add_mod_Proper; [ reflexivity | ].
    unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)).
    rewrite Z.mod_pull_div by auto with zarith lia.
    push_Zmod.
    erewrite Z.div_to_inv_modulo; [ | apply Z.lt_gt; lia | eassumption ].
    pull_Zmod.
    match goal with
    | [ |- ?x mod ?N = ?y mod ?N ]
      => change (Z.equiv_modulo N x y)
    end.
    repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia
                 | rewrite (Z.mul_comm _ ri)
                 | rewrite (Z.mul_assoc _ ri _)
                 | rewrite (Z.mul_comm _ (ri^_))
                 | rewrite (Z.mul_assoc _ (ri^_) _) ].
    repeat first [ rewrite <- Z.mul_assoc
                 | rewrite <- Z.mul_add_distr_l
                 | rewrite (Z.mul_comm _ (eval B))
                 | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia; rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds)
                 | rewrite Zplus_minus
                 | rewrite (Z.mul_comm r (r^_))
                 | reflexivity ]. }
Qed.

(** [pre_redc] lies in [0, N + B) and is [small]. *)
Lemma pre_redc_bound A_numlimbs (A : T A_numlimbs) (small_A : small A)
  : 0 <= eval (pre_redc A) < eval N + eval B.
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds sub_then_maybe_add small_A.
  unfold pre_redc.
  apply redc_loop_good; simpl; autorewrite with push_mont_eval; rewrite ?Npos_correct; auto; lia.
Qed.

Lemma small_pre_redc A_numlimbs (A : T A_numlimbs) (small_A : small A)
  : small (pre_redc A).
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds sub_then_maybe_add small_A.
  unfold pre_redc.
  apply redc_loop_good; simpl; autorewrite with push_mont_eval; rewrite ?Npos_correct; auto; lia.
Qed.

(** Montgomery correctness: [pre_redc A] is congruent to
    [A * B * ri^A_numlimbs] modulo [N]. *)
Lemma pre_redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
  : (eval (pre_redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds R_numlimbs_nz ri_correct k_correct.
  clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds R_numlimbs_nz ri_correct k_correct sub_then_maybe_add small_A A_bound.
  unfold pre_redc.
  rewrite snd_redc_loop_mod_N; cbn [fst snd]; autorewrite with push_mont_eval zsimplify; [ | rewrite ?Npos_correct; auto; lia.. ].
  Z.rewrite_mod_small.
  reflexivity.
Qed.

(** The final conditional subtraction does not change the residue modulo
    [N]. *)
Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
  : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
Proof using small_N small_B ri_correct lgr_big k_correct R_numlimbs_nz N_nz N_lt_R B_bounds.
  pose proof (@small_pre_redc _ A small_A).
  pose proof (@pre_redc_bound _ A small_A).
  unfold redc.
  autorewrite with push_mont_eval; [].
  break_innermost_match;
    try rewrite Z.add_opp_r, Zminus_mod, Z_mod_same_full;
    autorewrite with zsimplify_fast;
    apply pre_redc_mod_N; auto.
Qed.
(** Tight bound on [redc A]: below [eval N + eval B], further lowered by
    [eval N] when [eval N <= eval (pre_redc A)]. *)
Lemma redc_bound_tight A_numlimbs (A : T A_numlimbs) (small_A : small A) : 0 <= eval (redc A) < eval N + eval B + if eval N <=? eval (pre_redc A) then -eval N else 0.
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds r_big' partition_Proper small_A sub_then_maybe_add.
  pose proof (@small_pre_redc _ A small_A).
  pose proof (@pre_redc_bound _ A small_A).
  unfold redc.
  rewrite eval_conditional_sub by t_small.
  break_innermost_match; Z.ltb_to_lt; omega.
Qed.
(** When [eval B < eval N], [redc A] is fully reduced:
    [0 <= eval (redc A) < eval N]. *)
Lemma redc_bound_N A_numlimbs (A : T A_numlimbs) (small_A : small A) : eval B < eval N -> 0 <= eval (redc A) < eval N.
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add.
  pose proof (@small_pre_redc _ A small_A).
  pose proof (@pre_redc_bound _ A small_A).
  unfold redc.
  rewrite eval_conditional_sub by t_small.
  break_innermost_match; Z.ltb_to_lt; omega.
Qed.
(** [redc A] evaluates below [R].  NOTE(review): [R] is defined outside this
    chunk. *)
Lemma redc_bound A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) : 0 <= eval (redc A) < R.
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add A_bound.
  pose proof (@small_pre_redc _ A small_A).
  pose proof (@pre_redc_bound _ A small_A).
  unfold redc.
  rewrite eval_conditional_sub by t_small.
  break_innermost_match; Z.ltb_to_lt; try omega.
Qed.
(** [redc A] is [small]. *)
Lemma small_redc A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs) : small (redc A).
Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
  clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add A_bound.
  pose proof (@small_pre_redc _ A small_A).
  pose proof (@pre_redc_bound _ A small_A).
  unfold redc.
  apply small_conditional_sub; [ apply small_pre_redc | .. ]; auto; omega.
Qed.
End redc_proofs.
(** Proofs about [add], [sub] and [opp] on operands that are [small] and
    satisfy [0 <= _ < eval N]. *)
Section add_sub.
Context (Av Bv : T R_numlimbs) (small_Av : small Av) (small_Bv : small Bv) (Av_bound : 0 <= eval Av < eval N) (Bv_bound : 0 <= eval Bv < eval N).
(* Clears [B], [k], [ri] and everything depending on them.  NOTE(review):
   not invoked by the lemmas visible in this chunk. *)
Local Ltac do_clear := clear dependent B; clear dependent k; clear dependent ri.
(** [add] preserves [small]. *)
Lemma small_add : small (add Av Bv).
Proof using small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound small_N ri k R_numlimbs_nz N_nz B_bounds B.
  clear -small_Bv small_Av N_lt_R Bv_bound Av_bound partition_Proper r_big' small_N ri k R_numlimbs_nz N_nz B_bounds B sub_then_maybe_add.
  unfold add; t_small.
Qed.
(** [sub] preserves [small]. *)
Lemma small_sub : small (sub Av Bv).
Proof using small_N small_Bv small_Av partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R Bv_bound Av_bound.
  unfold sub; t_small.
Qed.
(** [opp] preserves [small]. *)
Lemma small_opp : small (opp Av).
Proof using small_N small_Bv small_Av partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R Av_bound.
  unfold opp, sub; t_small.
Qed.
(** [add] returns [eval Av + eval Bv], minus [eval N] when that sum is at
    least [eval N]. *)
Lemma eval_add : eval (add Av Bv) = eval Av + eval Bv + if (eval N <=? eval Av + eval Bv) then -eval N else 0.
Proof using small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound small_N ri k R_numlimbs_nz N_nz B_bounds B.
  clear -small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound partition_Proper r_big' small_N ri k R_numlimbs_nz N_nz B_bounds B sub_then_maybe_add.
  unfold add; autorewrite with push_mont_eval; reflexivity.
Qed.
(* NOTE(review): the statement of [eval_sub] below is truncated mid-term: its
   [if] has no [then]/[else] branch.  The text that should follow it is also
   missing.  Presumably that text contains the rest of this section and the
   opening of the [modops] section that introduces [bitwidth], [n], [m], [r],
   [m_enc], [weight] and [eval], which are used below and closed by
   [End modops].  It looks like a span starting at a [<] character was lost.
   Restore it from upstream before building. *)
Lemma eval_sub : eval (sub Av Bv) = eval Av - eval Bv + if (eval Av - eval Bv -> Z.
(* Montgomery hypotheses for the [modops] operations:
   - [r'_correct]: [r'] is an inverse of [r] modulo [m].
   - [m'_correct]: [m * m'] is congruent to [-1] modulo [r].
   - [m_small]: [m] fits below [r^n].
   NOTE(review): [r] is defined in the missing span above; presumably it is
   [2^bitwidth]. *)
Context (r' : Z) (m' : Z) (r'_correct : (r * r') mod m = 1) (m'_correct : (m * m') mod r = (-1) mod r) (bitwidth_big : 0 < bitwidth) (m_big : 1 < m) (n_nz : n <> 0%nat) (m_small : m < r^n).
Local Notation wprops := (@UniformWeight.uwprops bitwidth bitwidth_big).
Local Notation small := (@small bitwidth n).
(* Make the uniform-weight facts available to [auto]/[eauto]. *)
Local Hint Immediate (wprops).
Local Hint Immediate (weight_0 wprops).
Local Hint Immediate (weight_positive wprops).
Local Hint Immediate (weight_multiples wprops).
Local Hint Immediate (weight_divides wprops).
(** The limb encoding [m_enc] evaluates back to [m].  NOTE(review): [m_enc]
    is defined outside this chunk, presumably as a partition of [m]. *)
Local Lemma m_enc_correct_montgomery : m = eval m_enc.
Proof using m_small m_big bitwidth_big.
  clear -m_small m_big bitwidth_big.
  cbv [eval m_enc]; autorewrite with push_eval; auto.
  rewrite UniformWeight.uweight_eq_alt by omega.
  Z.rewrite_mod_small; reflexivity.
Qed.
(** [r'^n] is an inverse of [r^n] modulo [eval m_enc]; this follows from
    [r'_correct]. *)
Local Lemma r'_pow_correct : (r'^n * r^n) mod (eval m_enc) = 1.
Proof using r'_correct m_small m_big bitwidth_big.
  clear -r'_correct m_small m_big bitwidth_big.
  rewrite <- Z.pow_mul_l, Z.mod_pow_full, ?(Z.mul_comm r'), <- m_enc_correct_montgomery, r'_correct.
  autorewrite with zsimplify_const; auto with omega.
  Z.rewrite_mod_small; omega.
Qed.
(** [m_enc] is [small]. *)
Local Lemma small_m_enc : small m_enc.
Proof using m_small m_big bitwidth_big.
  clear -m_small m_big bitwidth_big.
  cbv [m_enc small eval]; autorewrite with push_eval; auto.
  rewrite UniformWeight.uweight_eq_alt by omega.
  Z.rewrite_mod_small; reflexivity.
Qed.
(* Finishing tactic for the correctness lemmas below.  It repeatedly tries,
   in order: assumption, reflexivity, splitting conjunctions, and rewriting
   with [m_enc_correct_montgomery], [r'_correct], [Z.mod_1_l] and
   [Z.mul_comm m' m].  It then tries [lia], [small_m_enc], unfolding [small]
   hypotheses through [eval_partition], [UniformWeight.uweight_eq_alt],
   removing redundant [mod]s, [Z.zero_bounds], and [Z.mod_pos_bound]. *)
Local Ltac t_fin :=
  repeat match goal with
         | _ => assumption
         | [ |- ?x = ?x ] => reflexivity
         | [ |- and _ _ ] => split
         | _ => rewrite <- !m_enc_correct_montgomery
         | _ => rewrite !r'_correct
         | _ => rewrite !Z.mod_1_l by assumption; reflexivity
         | _ => rewrite !(Z.mul_comm m' m)
         | _ => lia
         | _ => exact small_m_enc
         | [ H : small ?x |- context[eval ?x] ] => rewrite H; cbv [eval]; rewrite eval_partition by auto
         | [ |- context[weight _] ] => rewrite UniformWeight.uweight_eq_alt by auto with omega
         | _=> progress Z.rewrite_mod_small
         | _ => progress Z.zero_bounds
         | [ |- _ mod ?x < ?x ] => apply Z.mod_pos_bound
         end.
(** Montgomery multiplication: [redc] over [n] limbs with modulus encoding
    [m_enc] and constant [m'].  See [mulmod_correct0] for its specification. *)
Definition mulmod (a b : list Z) : list Z := @redc bitwidth n m_enc n a b m'.
(** Squaring, defined as [mulmod a a]. *)
Definition squaremod (a : list Z) : list Z := mulmod a a.
(** Modular addition: [add] instantiated with [m_enc]. *)
Definition addmod (a b : list Z) : list Z := @add bitwidth n m_enc a b.
(** Modular subtraction: [sub] instantiated with [m_enc]. *)
Definition submod (a b : list Z) : list Z := @sub bitwidth n m_enc a b.
(** Modular negation: [opp] instantiated with [m_enc]. *)
Definition oppmod (a : list Z) : list Z := @opp bitwidth n m_enc a.
(** Wrapper around [nonzero]; see [nonzeromod_correct]. *)
Definition nonzeromod (a : list Z) : Z := @nonzero a.
(** Byte serialization: the generic [to_bytesmod] applied with arguments
    [bitwidth], [1] and [n]. *)
Definition to_bytesmod (a : list Z) : list Z := @to_bytesmod bitwidth 1 n a.
(** A limb list is [valid] when it is [small] and evaluates into
    [0 <= _ < m]. *)
Definition valid (a : list Z) := small a /\ 0 <= eval a < m.
(** Core specification of [mulmod] on [small] inputs:
    - the result is [small];
    - the result is below [m] whenever [eval b] is;
    - modulo [m], the result equals [eval a * eval b * r'^n], the Montgomery
      product. *)
Lemma mulmod_correct0 : forall a b : list Z, small a -> small b -> small (mulmod a b) /\ (eval b < m -> 0 <= eval (mulmod a b) < m) /\ (eval (mulmod a b) mod m = (eval a * eval b * r'^n) mod m).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  (* One [redc] lemma per conjunct; [t_fin] closes the side conditions. *)
  intros a b Ha Hb; repeat apply conj; cbv [small mulmod eval]; [ eapply small_redc | rewrite m_enc_correct_montgomery; eapply redc_bound_N | rewrite !m_enc_correct_montgomery; eapply redc_mod_N ]; t_fin.
Qed.
(** The constant [1] as [n] limbs. *)
Definition onemod : list Z := partition weight n 1.
(** [onemod] evaluates to [1] and is [valid].  Declared as an opaque [Lemma],
    like every other proof in this file. *)
Lemma onemod_correct : eval onemod = 1 /\ valid onemod.
Proof using n_nz m_big bitwidth_big.
  clear -n_nz m_big bitwidth_big.
  cbv [valid small onemod eval]; autorewrite with push_eval; t_fin.
Qed.
(** First projection of [onemod_correct]. *)
Lemma eval_onemod : eval onemod = 1.
Proof. apply onemod_correct. Qed.
(** [(r^n * r^n) mod m] as [n] limbs.  [encodemod] multiplies by it with
    [mulmod] to move a value into Montgomery form. *)
Definition R2mod : list Z := partition weight n ((r^n * r^n) mod m).
(** [R2mod] evaluates to [(r^n * r^n) mod m] modulo [m] and is [valid].
    Declared as an opaque [Lemma], like every other proof in this file. *)
Lemma R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod.
Proof using n_nz m_small m_big m'_correct bitwidth_big.
  clear -n_nz m_small m_big m'_correct bitwidth_big.
  cbv [valid small R2mod eval]; autorewrite with push_eval; t_fin; rewrite !(Z.mod_small (_ mod m)) by (Z.div_mod_to_quot_rem; subst r; lia); t_fin.
Qed.
(** Leave Montgomery form by multiplying with [onemod].  By
    [from_montgomerymod_correct], this multiplies the value by [r'^n]
    modulo [m]. *)
Definition from_montgomerymod (v : list Z) : list Z := mulmod v onemod.
(** Decoding a [valid] value multiplies it by [r'^n] modulo [m] and yields a
    [valid] result. *)
Lemma from_montgomerymod_correct (v : list Z) : valid v -> eval (from_montgomerymod v) mod m = (eval v * r'^n) mod m /\ valid (from_montgomerymod v).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  clear -r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  intro Hv; cbv [from_montgomerymod valid] in *; destruct_head'_and.
  (* Insert the factor [eval onemod] (which is [1]) so that [mulmod_correct0] applies. *)
  replace (eval v * r'^n) with (eval v * eval onemod * r'^n) by (rewrite (proj1 onemod_correct); lia).
  repeat apply conj; apply mulmod_correct0; auto; try apply onemod_correct; rewrite (proj1 onemod_correct); omega.
Qed.
(** First half of [from_montgomerymod_correct]: decoding multiplies by
    [r'^n] modulo [m]. *)
Lemma eval_from_montgomerymod (v : list Z) : valid v -> eval (from_montgomerymod v) mod m = (eval v * r'^n) mod m.
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  intros; apply from_montgomerymod_correct; assumption.
Qed.
(** Second half of [from_montgomerymod_correct]: decoding preserves [valid]. *)
Lemma valid_from_montgomerymod (v : list Z) : valid v -> valid (from_montgomerymod v).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  intros; apply from_montgomerymod_correct; assumption.
Qed.
(** On [valid] inputs, decoding [mulmod a b] gives the product of the decoded
    inputs modulo [m], and [mulmod] preserves [valid]. *)
Lemma mulmod_correct : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (mulmod a b)) mod m = (eval (from_montgomerymod a) * eval (from_montgomerymod b)) mod m) /\ (forall a (_ : valid a) b (_ : valid b), valid (mulmod a b)).
Proof using r'_correct r' n_nz m_small m_big m'_correct bitwidth_big.
  (* After this line exactly one goal remains (enforced by [[]]). *)
  repeat apply conj; intros; push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj; try apply mulmod_correct0; cbv [valid] in *; destruct_head'_and; auto; [].
  rewrite !Z.mul_assoc.
  apply Z.mul_mod_Proper; [ | reflexivity ].
  cbv [Z.equiv_modulo]; etransitivity; [ apply mulmod_correct0 | apply f_equal2; lia ]; auto.
Qed.
(** First projection of [mulmod_correct]. *)
Lemma eval_mulmod : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (mulmod a b)) mod m = (eval (from_montgomerymod a) * eval (from_montgomerymod b)) mod m).
Proof. apply mulmod_correct. Qed.
(** [squaremod] inherits [mulmod_correct] with both arguments equal. *)
Lemma squaremod_correct : (forall a (_ : valid a), eval (from_montgomerymod (squaremod a)) mod m = (eval (from_montgomerymod a) * eval (from_montgomerymod a)) mod m) /\ (forall a (_ : valid a), valid (squaremod a)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  split; intros; cbv [squaremod]; apply mulmod_correct; assumption.
Qed.
(** First projection of [squaremod_correct]. *)
Lemma eval_squaremod : (forall a (_ : valid a), eval (from_montgomerymod (squaremod a)) mod m = (eval (from_montgomerymod a) * eval (from_montgomerymod a)) mod m).
Proof. apply squaremod_correct. Qed.
(** Encode [v] into Montgomery form: partition [v] into [n] limbs, then apply
    [mulmod] with [R2mod]. *)
Definition encodemod (v : Z) : list Z := mulmod (partition weight n v) R2mod.
(* Side-goal solver for [encodemod_correct], parameterized by the encoded
   value [v].  It unfolds [valid], [small] and [eval], removes redundant
   [mod] operations using the bounds on [v] and [m], and finishes with
   [Z.mod_pos_bound], [lia] or [reflexivity]. *)
Local Ltac t_valid v :=
  cbv [valid]; repeat apply conj; auto; cbv [small eval]; autorewrite with push_eval; auto;
  rewrite ?UniformWeight.uweight_eq_alt by omega; Z.rewrite_mod_small;
  rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia);
  rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia);
  try apply Z.mod_pos_bound; subst r; try lia; try reflexivity.
(** For [0 <= v < m], decoding [encodemod v] gives back [v] modulo [m], and
    [encodemod v] is [valid]. *)
Lemma encodemod_correct : (forall v, 0 <= v < m -> eval (from_montgomerymod (encodemod v)) mod m = v mod m) /\ (forall v, 0 <= v < m -> valid (encodemod v)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  split; intros v ?; cbv [encodemod R2mod]; [ rewrite (proj1 mulmod_correct) | apply mulmod_correct ]; [ | now t_valid v.. ].
  push_Zmod; rewrite !eval_from_montgomerymod; [ | now t_valid v.. ].
  cbv [eval]; autorewrite with push_eval; auto.
  rewrite ?UniformWeight.uweight_eq_alt by omega.
  rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia).
  rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia).
  pull_Zmod.
  rewrite <- !Z.mul_assoc; autorewrite with pull_Zpow.
  (* Use [r'_correct] twice, the second time after swapping [r'] and [r],
     to cancel the [r'] and [r] factors. *)
  generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const.
  rewrite (Z.mul_comm r' r); generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const.
  Z.rewrite_mod_small.
  reflexivity.
Qed.
(** First projection of [encodemod_correct]. *)
Lemma eval_encodemod : (forall v, 0 <= v < m -> eval (from_montgomerymod (encodemod v)) mod m = v mod m).
Proof. apply encodemod_correct. Qed.
(** On [valid] inputs, decoding [addmod a b] gives the sum of the decoded
    inputs modulo [m], and [addmod] preserves [valid]. *)
Lemma addmod_correct : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (addmod a b)) mod m = (eval (from_montgomerymod a) + eval (from_montgomerymod b)) mod m) /\ (forall a (_ : valid a) b (_ : valid b), valid (addmod a b)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  (* [[ ]] asserts that a single goal remains after this line. *)
  repeat apply conj; intros; push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj; cbv [valid addmod] in *; destruct_head'_and; auto; try rewrite m_enc_correct_montgomery; try (eapply small_add || eapply add_bound); cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega; [ ].
  (* That goal is rewritten with [eval_add], then closed by a case split on
     its conditional. *)
  push_Zmod; erewrite eval_add by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
  break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
Qed.
(** First projection of [addmod_correct]. *)
Lemma eval_addmod : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (addmod a b)) mod m = (eval (from_montgomerymod a) + eval (from_montgomerymod b)) mod m).
Proof. apply addmod_correct. Qed.
(** On [valid] inputs, decoding [submod a b] gives the difference of the
    decoded inputs modulo [m], and [submod] preserves [valid]. *)
Lemma submod_correct : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (submod a b)) mod m = (eval (from_montgomerymod a) - eval (from_montgomerymod b)) mod m) /\ (forall a (_ : valid a) b (_ : valid b), valid (submod a b)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  repeat apply conj; intros; push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj; cbv [valid submod] in *; destruct_head'_and; auto; try rewrite m_enc_correct_montgomery; try (eapply small_sub || eapply sub_bound); cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega; [ ].
  push_Zmod; erewrite eval_sub by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
  break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
Qed.
(** First projection of [submod_correct]. *)
Lemma eval_submod : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (submod a b)) mod m = (eval (from_montgomerymod a) - eval (from_montgomerymod b)) mod m).
Proof. apply submod_correct. Qed.
(** On a [valid] input, decoding [oppmod a] gives the negation of the decoded
    input modulo [m], and [oppmod] preserves [valid]. *)
Lemma oppmod_correct : (forall a (_ : valid a), eval (from_montgomerymod (oppmod a)) mod m = (-eval (from_montgomerymod a)) mod m) /\ (forall a (_ : valid a), valid (oppmod a)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  repeat apply conj; intros; push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj; cbv [valid oppmod] in *; destruct_head'_and; auto; try rewrite m_enc_correct_montgomery; try (eapply small_opp || eapply opp_bound); cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega; [ ].
  push_Zmod; erewrite eval_opp by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
  break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
Qed.
(** First projection of [oppmod_correct]. *)
Lemma eval_oppmod : (forall a (_ : valid a), eval (from_montgomerymod (oppmod a)) mod m = (-eval (from_montgomerymod a)) mod m).
Proof. apply oppmod_correct. Qed.
(** For [valid a], [nonzeromod a = 0] holds exactly when the decoded value of
    [a] is [0] modulo [m]. *)
Lemma nonzeromod_correct : (forall a (_ : valid a), (nonzeromod a = 0) <-> ((eval (from_montgomerymod a)) mod m = 0)).
Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
  intros a Ha; rewrite eval_from_montgomerymod by assumption.
  cbv [nonzeromod valid] in *; destruct_head'_and.
  rewrite eval_nonzero; try eassumption; [ | subst r; apply conj; try eassumption; omega.. ].
  split; intro H'; [ rewrite H'; autorewrite with zsimplify_const; reflexivity | ].
  (* Reverse direction: multiply the hypothesis by [r^n], cancel [r'] against
     [r] with [r'_correct], then use [0 <= eval a < m] to conclude. *)
  assert (H'' : ((eval a * r'^n) * r^n) mod m = 0) by (revert H'; push_Zmod; intro H'; rewrite H'; autorewrite with zsimplify_const; reflexivity).
  rewrite <- Z.mul_assoc in H''.
  autorewrite with pull_Zpow push_Zmod in H''.
  rewrite (Z.mul_comm r' r), r'_correct in H''.
  autorewrite with zsimplify_const pull_Zmod in H''; [ | lia.. ].
  clear H'.
  generalize dependent (eval a); clear.
  intros z ???.
  assert (z / m = 0) by (Z.div_mod_to_quot_rem; nia).
  Z.div_mod_to_quot_rem; nia.
Qed.
(** For [valid a]:
    - [to_bytesmod a], evaluated with [UniformWeight.uweight 8] weights over
      [bytes_n bitwidth 1 n] digits, equals [eval a mod m];
    - [to_bytesmod a] equals the same partition of [eval a mod m]. *)
Lemma to_bytesmod_correct : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) = eval a mod m) /\ (forall a (_ : valid a), to_bytesmod a = partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)).
Proof using n_nz m_small bitwidth_big.
  clear -n_nz m_small bitwidth_big.
  (* Both conjuncts are reduced to the generic [eval_to_bytesmod] lemma. *)
  generalize (@length_small bitwidth n); cbv [valid small to_bytesmod eval]; split; intros; (etransitivity; [ apply eval_to_bytesmod | ]); fold weight in *; fold (UniformWeight.uweight 8) in *; subst r; try solve [ intuition eauto with omega ].
  all: repeat first [ rewrite UniformWeight.uweight_eq_alt by omega | omega | reflexivity | progress Z.rewrite_mod_small ].
Qed.
(** First projection of [to_bytesmod_correct]. *)
Lemma eval_to_bytesmod : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a) = eval a mod m).
Proof. apply to_bytesmod_correct. Qed.
End modops.
End WordByWordMontgomery.