diff options
Diffstat (limited to 'src/Arithmetic/WordByWordMontgomery.v')
-rw-r--r-- | src/Arithmetic/WordByWordMontgomery.v | 1311 |
1 files changed, 1311 insertions, 0 deletions
diff --git a/src/Arithmetic/WordByWordMontgomery.v b/src/Arithmetic/WordByWordMontgomery.v new file mode 100644 index 000000000..f52dbdeb1 --- /dev/null +++ b/src/Arithmetic/WordByWordMontgomery.v @@ -0,0 +1,1311 @@ + +(* TODO: prune these *) +Require Import Crypto.Algebra.Nsatz. +Require Import Coq.ZArith.ZArith Coq.micromega.Lia Crypto.Algebra.Nsatz. +Require Import Coq.Sorting.Mergesort Coq.Structures.Orders. +Require Import Coq.Sorting.Permutation. +Require Import Coq.derive.Derive. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. (* For MontgomeryReduction *) +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. (* For MontgomeryReduction *) +Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable. +Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn. +Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Arithmetic.ModularArithmeticTheorems. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.RunTacticAsConstr. +Require Import Crypto.Util.Tactics.Head. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.OptionList. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Sum. +Require Import Crypto.Util.Bool. +Require Import Crypto.Util.Sigma. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Tactics.RewriteModSmall. +Require Import Crypto.Util.ZUtil.Tactics.PeelLe. +Require Import Crypto.Util.ZUtil.Tactics.LinearSubstitute. +Require Import Crypto.Util.ZUtil.Tactics.ZeroBounds. +Require Import Crypto.Util.ZUtil.Modulo.PullPush. +Require Import Crypto.Util.ZUtil.Opp. 
+Require Import Crypto.Util.ZUtil.Log2. +Require Import Crypto.Util.ZUtil.Le. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.ZUtil.Tactics.DivModToQuotRem. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Notations. +Require Import Crypto.Util.ZUtil.Definitions. +Require Import Crypto.Util.ZUtil.Sorting. +Require Import Crypto.Util.ZUtil.CC Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Require Import Crypto.Util.ZUtil.Hints.Core. +Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div. +Require Import Crypto.Util.ZUtil.Hints.PullPush. +Require Import Crypto.Util.ZUtil.EquivModulo. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.CPSNotations. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.SetEvars. +Import Coq.Lists.List ListNotations. Local Open Scope Z_scope. + +(* Word-by-word Montgomery arithmetic on lists of limbs: the limb radix is r = 2^lgr and all limbs share the uniform weight function [UniformWeight.uweight lgr]. *) Module WordByWordMontgomery. + Import Partition. + Local Hint Resolve Z.positive_is_nonzero Z.lt_gt Nat2Z.is_nonneg. + Section with_args. + (* [lgr] is the exponent of the radix (see [r] below). NOTE(review): [m] is not referenced by any definition visible in this part of the file. *) Context (lgr : Z) + (m : Z). + Local Notation weight := (UniformWeight.uweight lgr). + (* [T n] unfolds to [list Z]; the index [n] is only an annotation of the intended number of limbs and is not enforced by the type. *) Let T (n : nat) := list Z. + (* The limb radix. *) Let r := (2^lgr). + (* Value of a limb list: [Positional.eval] with the uniform weights and [n] limbs. *) Definition eval {n} : T n -> Z := Positional.eval weight n. + (* The all-zero limb list, [Positional.zeros n]. *) Let zero {n} : T n := Positional.zeros n. + (* Thin wrapper around [Rows.divmod]: splits an (S n)-limb list into an n-limb list and a single integer. *) Let divmod {n} : T (S n) -> T n * Z := Rows.divmod. + (* Multiplies the limb list [p] by the scalar [c]: [c] is first extended from a 1-limb list to [n] limbs, then [Rows.mul] is called with (S n) output limbs. Only the first component [v] of the result is kept; the second component (which rebinds, and so shadows, the name [c]) is discarded. *) Let scmul {n} (c : Z) (p : T n) : T (S n) (* uses double-output multiply *) + := let '(v, c) := Rows.mul weight r n (S n) (Positional.extend_to_length 1 n [c]) p in + v. 
+ (* Adds two n-limb lists with [Rows.add] and appends the returned carry [c] as an extra top limb, so the result has S n limbs. *) Let addT {n} (p q : T n) : T (S n) (* joins carry *) + := let '(v, c) := Rows.add weight n p q in + v ++ [c]. + (* Adds an n-limb list [q] (extended to S n limbs) to an (S n)-limb list [p]; only the first component of [Rows.add] is kept, so the carry out is discarded. *) Let drop_high_addT' {n} (p : T (S n)) (q : T n) : T (S n) (* drops carry *) + := fst (Rows.add weight (S n) p (Positional.extend_to_length n (S n) q)). + (* Runs [Rows.conditional_sub] on (S n)-limb inputs ([N] is extended to S n limbs) and then truncates the result to n limbs with [Positional.drop_high_to_length]. *) Let conditional_sub {n} (arg : T (S n)) (N : T n) : T n (* computes [arg - N] if [N <= arg], and drops high bit *) + := Positional.drop_high_to_length n (Rows.conditional_sub weight (S n) arg (Positional.extend_to_length n (S n) N)). + (* [R_numlimbs] is the limb count used for [N] and for every operand of add/sub/opp/redc below. *) Context (R_numlimbs : nat) + (N : T R_numlimbs). (* encoding of m *) + (* First component of [Rows.sub_then_maybe_add], called with mask [r-1] and addend [N]. *) Let sub_then_maybe_add (a b : T R_numlimbs) : T R_numlimbs (* computes [a - b + if (a - b) <? 0 then N else 0] *) + := fst (Rows.sub_then_maybe_add weight R_numlimbs (r-1) a b N). + Local Opaque T. + (* One iteration of the reduction loop, spelled out twice: as individual named steps ([A_a] ... [S3']) and as a single let-bound expression ([A'_S3]). *) Section Iteration. + Context (pred_A_numlimbs : nat) + (B : T R_numlimbs) (k : Z) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + (* [A_a] is [divmod A]; [A'] is its list component and [a] its integer component. *) Local Definition A_a := dlet p := @divmod _ A in p. Local Definition A' := fst A_a. Local Definition a := snd A_a. + (* S1 := S + a * B, with the carry kept as an extra limb. *) Local Definition S1 := @addT _ S (@scmul _ a B). + (* s := integer component of [divmod S1]. *) Local Definition s := snd (@divmod _ S1). + (* q := first component of [Z.mul_split r s k]. *) Local Definition q := fst (Z.mul_split r s k). + (* S2 := S1 + q * N, with the carry out dropped. *) Local Definition S2 := @drop_high_addT' _ S1 (@scmul _ q N). + (* S3' := list component of [divmod S2]. *) Local Definition S3' := fst (@divmod _ S2). + + (* The same computation as the named steps above, written with [dlet] so that every intermediate value is bound exactly once; returns the pair (A', S3). *) Local Definition A'_S3 + := dlet A_a := @divmod _ A in + dlet A' := fst A_a in + dlet a := snd A_a in + dlet S1 := @addT _ S (@scmul _ a B) in + dlet s := snd (@divmod _ S1) in + dlet q := fst (Z.mul_split r s k) in + dlet S2 := @drop_high_addT' _ S1 (@scmul _ q N) in + dlet S3 := fst (@divmod _ S2) in + (A', S3). + + (* The let-bound form agrees with the pair of named steps; proved by unfolding and [reflexivity]. *) Lemma A'_S3_alt : A'_S3 = (A', S3'). + Proof using Type. cbv [A'_S3 A' S3' Let_In S2 q s S1 A' a A_a]; reflexivity. Qed. + End Iteration. + + (* Inputs of the full loop: [A] with [A_numlimbs] limbs, [B] with [R_numlimbs] limbs, the constant [k], and an accumulator [S'] with S R_numlimbs limbs. *) Section loop. + Context (A_numlimbs : nat) + (A : T A_numlimbs) + (B : T R_numlimbs) + (k : Z) + (S' : T (S R_numlimbs)). 
+ + (* One loop iteration on a pair (A, S'): applies [A'_S3] with the section's [B] and [k]. The first component loses one limb; the second keeps S R_numlimbs limbs. *) Definition redc_body {pred_A_numlimbs} : T (S pred_A_numlimbs) * T (S R_numlimbs) + -> T pred_A_numlimbs * T (S R_numlimbs) + := fun '(A, S') => A'_S3 _ B k A S'. + + (* Iterates [redc_body] [count] times by [nat_rect]: the zero case is the identity, and the successor case applies [redc_body] first and then recurses on the result. *) Definition redc_loop (count : nat) : T count * T (S R_numlimbs) -> T O * T (S R_numlimbs) + := nat_rect + (fun count => T count * _ -> _) + (fun A_S => A_S) + (fun count' redc_loop_count' A_S + => redc_loop_count' (redc_body A_S)) + count. + + (* Accumulator after running the loop [A_numlimbs] times from (A, zero), where the initial accumulator has 1 + R_numlimbs limbs. *) Definition pre_redc : T (S R_numlimbs) + := snd (redc_loop A_numlimbs (A, @zero (1 + R_numlimbs)%nat)). + + (* [pre_redc] followed by one conditional subtraction of [N], which also brings the result down to R_numlimbs limbs. *) Definition redc : T R_numlimbs + := conditional_sub pre_redc N. + End loop. + + Create HintDb word_by_word_montgomery. + Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. + + (* add := conditional_sub (A + B, carry kept) N. *) Definition add (A B : T R_numlimbs) : T R_numlimbs + := conditional_sub (@addT _ A B) N. + (* sub := sub_then_maybe_add A B. *) Definition sub (A B : T R_numlimbs) : T R_numlimbs + := sub_then_maybe_add A B. + (* opp := sub zero A. *) Definition opp (A : T R_numlimbs) : T R_numlimbs + := sub (@zero _) A. + (* Bitwise OR of all limbs, starting from 0; see [eval_nonzero] below for its relation to [eval]. *) Definition nonzero (A : list Z) : Z + := fold_right Z.lor 0 A. + + Context (lgr_big : 0 < lgr) + (R_numlimbs_nz : R_numlimbs <> 0%nat). + Let R := (r^Z.of_nat R_numlimbs). + Transparent T. + (* [small v] states that [v] is exactly the [partition] of its own value into [n] limbs. *) Definition small {n} (v : T n) : Prop + := v = partition weight n (eval v). + (* Standing hypotheses for the proofs: [N] is small with 0 < eval N < R; [B] is small with 0 <= eval B < R; [ri] satisfies r*ri = 1 modulo eval N; [k] satisfies k * eval N = -1 modulo r. *) Context (small_N : small N) + (N_lt_R : eval N < R) + (N_nz : 0 < eval N) + (B : T R_numlimbs) + (B_bounds : 0 <= eval B < R) + (small_B : small B) + ri (ri_correct : r*ri mod (eval N) = 1 mod (eval N)) + (k : Z) (k_correct : k * eval N mod r = (-1) mod r). + + (* The radix exceeds 1; follows from [lgr_big]. *) Local Lemma r_big : r > 1. + Proof using lgr_big. clear -lgr_big; subst r. auto with zarith. Qed. + Local Notation wprops := (@UniformWeight.uwprops lgr lgr_big). + + Local Hint Immediate (wprops). + Local Hint Immediate (weight_0 wprops). + Local Hint Immediate (weight_positive wprops). + Local Hint Immediate (weight_multiples wprops). + Local Hint Immediate (weight_divides wprops). + Local Hint Immediate r_big. + + (* A small limb list with index [n] really has length [n]. *) Lemma length_small {n v} : @small n v -> length v = n. + Proof using Type. 
 clear; cbv [small]; intro H; rewrite H; autorewrite with distr_length; reflexivity. Qed. + (* A small limb list evaluates to a value in the range 0 <= eval v < weight n. *) Lemma small_bound {n v} : @small n v -> 0 <= eval v < weight n. + Proof using lgr_big. clear - lgr_big; cbv [small eval]; intro H; rewrite H; autorewrite with push_eval; auto with zarith. Qed. + + (* Twice R fits below the weight of limb index S R_numlimbs. *) Lemma R_plusR_le : R + R <= weight (S R_numlimbs). + Proof using lgr_big. + clear - lgr_big. + etransitivity; [ | apply UniformWeight.uweight_double_le; omega ]. + rewrite UniformWeight.uweight_eq_alt by omega. + subst r R; omega. + Qed. + + (* Masking every limb of a partition with r - 1 leaves the partition unchanged. *) Lemma mask_r_sub1 n x : + map (Z.land (r - 1)) (partition weight n x) = partition weight n x. + Proof using lgr_big. + clear - lgr_big. cbv [partition]. + rewrite map_map. apply map_ext; intros. + rewrite UniformWeight.uweight_S by omega. + rewrite <-Z.mod_pull_div by auto with zarith. + replace (r - 1) with (Z.ones lgr) by (rewrite Z.ones_equiv; subst r; reflexivity). + rewrite <-Z.land_comm, Z.land_ones by omega. + auto with zarith. + Qed. + + Let partition_Proper := (@partition_Proper _ wprops). + Local Existing Instance partition_Proper. + (* For a small limb list, the OR of all limbs is 0 exactly when the list evaluates to 0. Proved by induction on the number of limbs. *) Lemma eval_nonzero n A : @small n A -> nonzero A = 0 <-> @eval n A = 0. + Proof using lgr_big. + clear -lgr_big partition_Proper. + cbv [nonzero eval small]; intro Heq. + do 2 rewrite Heq. + rewrite !eval_partition, Z.mod_mod by auto. + generalize (Positional.eval weight n A); clear Heq A. + induction n as [|n IHn]. + { cbn; rewrite weight_0 by auto; intros; autorewrite with zsimplify_const; omega. } + { intro; rewrite partition_step. + rewrite fold_right_snoc, Z.lor_comm, <- fold_right_push, Z.lor_eq_0_iff by auto using Z.lor_assoc. + assert (Heq : Z.equiv_modulo (weight n) (z mod weight (S n)) (z mod (weight n))). + { cbv [Z.equiv_modulo]. + generalize (weight_multiples ltac:(auto) n). + generalize (weight_positive ltac:(auto) n). + generalize (weight_positive ltac:(auto) (S n)). + generalize (weight (S n)) (weight n); clear; intros wsn wn. + clear; intros. + Z.div_mod_to_quot_rem; subst. 
+ autorewrite with zsimplify_const in *. + Z.linear_substitute_all. + apply Zminus_eq; ring_simplify. + rewrite <- !Z.add_opp_r, !Z.mul_opp_comm, <- !Z.mul_opp_r, <- !Z.mul_assoc. + rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0. + nia. } + rewrite Heq at 1; rewrite IHn. + rewrite Z.mod_mod by auto. + generalize (weight_multiples ltac:(auto) n). + generalize (weight_positive ltac:(auto) n). + generalize (weight_positive ltac:(auto) (S n)). + generalize (weight (S n)) (weight n); clear; intros wsn wn; intros. + Z.div_mod_to_quot_rem. + repeat (intro || apply conj); destruct_head'_or; try omega; destruct_head'_and; subst; autorewrite with zsimplify_const in *; try nia; + Z.linear_substitute_all. + all: apply Zminus_eq; ring_simplify. + all: rewrite <- ?Z.add_opp_r, ?Z.mul_opp_comm, <- ?Z.mul_opp_r, <- ?Z.mul_assoc. + all: rewrite <- !Z.mul_add_distr_l, Z.mul_eq_0. + all: nia. } + Qed. + + (* A single simplification step: tries, in order, eta-expansion, the Rows.mul / Rows.add partition and div rewrites, the push_eval and distr_length rewrite databases, [length_small], and a few weight facts. *) Local Ltac push_step := + first [ progress eta_expand + | rewrite Rows.mul_partitions + | rewrite Rows.mul_div + | rewrite Rows.add_partitions + | rewrite Rows.add_div + | progress autorewrite with push_eval distr_length + | match goal with + | [ H : ?v = _ |- context[length ?v] ] => erewrite length_small by eassumption + | [ H : small ?v |- context[length ?v] ] => erewrite length_small by eassumption + end + | rewrite Positional.eval_cons by distr_length + | progress rewrite ?weight_0, ?UniformWeight.uweight_1 by auto; + autorewrite with zsimplify_fast + | rewrite (weight_0 wprops) + | rewrite <- Z.div_mod'' by auto with omega + | solve [ trivial ] ]. + Local Ltac push := repeat push_step. + + Local Ltac t_step := + match goal with + | [ H := _ |- _ ] => progress cbv [H] in * + | _ => progress push_step + | _ => progress autorewrite with zsimplify_const + | _ => solve [ auto with omega ] + end. + + Local Hint Unfold eval zero small divmod scmul drop_high_addT' addT R : loc. + (* The zero limb list evaluates to 0. *) Local Lemma eval_zero : forall n, eval (@zero n) = 0. + Proof using Type. 
+ clear; autounfold with loc; intros; autorewrite with push_eval; auto. + Qed. + (* The zero limb list is small. *) Local Lemma small_zero : forall n, small (@zero n). + Proof using Type. + etransitivity; [ eapply Positional.zeros_ext_map | rewrite eval_zero ]; cbv [partition]; cbn; try reflexivity; autorewrite with distr_length; reflexivity. + Qed. + Local Hint Immediate small_zero. + + (* Replaces small hypotheses by their partition form and rewrites partitions into [recursive_partition] so that head and tail can be exposed by [cbn]. *) Ltac push_recursive_partition := + repeat match goal with + | _ => progress cbn [recursive_partition] + | H : small _ |- _ => rewrite H; clear H + | _ => rewrite recursive_partition_equiv by auto using wprops + | _ => rewrite UniformWeight.uweight_eval_shift by distr_length + | _ => progress push + end. + + (* On small input, the list component of [divmod] evaluates to eval v / r. *) Lemma eval_div : forall n v, small v -> eval (fst (@divmod n v)) = eval v / r. + Proof using lgr_big. + pose proof r_big as r_big. + clear - r_big lgr_big; intros; autounfold with loc. + push_recursive_partition; cbn [Rows.divmod fst tl]. + autorewrite with zsimplify; reflexivity. + Qed. + (* On small input, the integer component of [divmod] is eval v mod r. *) Lemma eval_mod : forall n v, small v -> snd (@divmod n v) = eval v mod r. + Proof using lgr_big. + clear - lgr_big; intros; autounfold with loc. + push_recursive_partition; cbn [Rows.divmod snd hd]. + autorewrite with zsimplify; reflexivity. + Qed. + (* The list component of [divmod] of a small list is again small. *) Lemma small_div : forall n v, small v -> small (fst (@divmod n v)). + Proof using lgr_big. + pose proof r_big as r_big. + clear - r_big lgr_big. intros; autounfold with loc. + push_recursive_partition. cbn [Rows.divmod fst tl]. + rewrite <-recursive_partition_equiv by auto. + rewrite <-UniformWeight.uweight_recursive_partition_equiv with (i:=1%nat) by omega. + push. + apply Partition.partition_Proper; [ solve [auto] | ]. + cbv [Z.equiv_modulo]. autorewrite with zsimplify. + reflexivity. + Qed. + + (* [canon_rep x v]: [v] is the n-limb partition of [x], and [x] lies in the range 0 <= x < weight n. It packages both the value and the smallness of [v] (see the two lemmas that follow). *) Definition canon_rep {n} x (v : T n) : Prop := + (v = partition weight n x) /\ (0 <= x < weight n). + (* A canonical representation of [x] evaluates to [x]. *) Lemma eval_canon_rep n x v : @canon_rep n x v -> eval v = x. + Proof using lgr_big. + clear - lgr_big. + cbv [canon_rep eval]; intros [Hv Hx]. + rewrite Hv. autorewrite with push_eval. 
+ auto using Z.mod_small. + Qed. + (* A canonical representation is small. *) Lemma small_canon_rep n x v : @canon_rep n x v -> small v. + Proof using lgr_big. + clear - lgr_big. + cbv [canon_rep eval small]; intros [Hv Hx]. + rewrite Hv. autorewrite with push_eval. + apply partition_eq_mod; auto; [ ]. + Z.rewrite_mod_small; reflexivity. + Qed. + + (* [scmul a v] canonically represents a * eval v, given small [v], a digit [a] with 0 <= a < r, and eval v below r^n. The case n = 0 is handled separately. *) Local Lemma scmul_correct: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> canon_rep (a * eval v) (@scmul n a v). + Proof using lgr_big. + pose proof r_big as r_big. + clear - lgr_big r_big. + autounfold with loc; intro n; destruct (zerop n); intros until 0; intro Hsmall; intros. + { intros; subst; cbn; rewrite Z.add_with_get_carry_full_mod. + split; cbn; autorewrite with zsimplify_fast; auto with zarith. } + { rewrite (surjective_pairing (Rows.mul _ _ _ _ _ _)). + rewrite Rows.mul_partitions by (try rewrite Hsmall; auto using length_partition, Positional.length_extend_to_length with omega). + autorewrite with push_eval. + rewrite Positional.eval_cons by reflexivity. + rewrite weight_0 by auto. + autorewrite with push_eval zsimplify_fast. + split; [reflexivity | ]. + rewrite UniformWeight.uweight_S, UniformWeight.uweight_eq_alt by omega. + subst r; nia. } + Qed. + + (* [addT a b] canonically represents eval a + eval b (no reduction: the carry is kept as the extra limb). *) Local Lemma addT_correct : forall n a b, small a -> small b -> canon_rep (eval a + eval b) (@addT n a b). + Proof using lgr_big. + intros n a b Ha Hb. + generalize (length_small Ha); generalize (length_small Hb). + generalize (small_bound Ha); generalize (small_bound Hb). + clear -lgr_big Ha Hb. + autounfold with loc; destruct (zerop n); subst. + { destruct a, b; cbn; try omega; split; auto with zarith. } + { pose proof (UniformWeight.uweight_double_le lgr ltac:(omega) n). + eta_expand; split; [ | lia ]. + rewrite Rows.add_partitions, Rows.add_div by auto. + rewrite partition_step. + Z.rewrite_mod_small; reflexivity. } + Qed. + + (* [drop_high_addT' a b] canonically represents (eval a + eval b) mod r^(S n). *) Local Lemma drop_high_addT'_correct : forall n a b, small a -> small b -> canon_rep ((eval a + eval b) mod (r^Z.of_nat (S n))) (@drop_high_addT' n a b). 
+ Proof using lgr_big. + intros n a b Ha Hb; generalize (length_small Ha); generalize (length_small Hb). + clear -lgr_big Ha Hb. + autounfold with loc in *; subst; intros. + rewrite Rows.add_partitions by auto using Positional.length_extend_to_length. + autorewrite with push_eval. + split; try apply partition_eq_mod; auto; rewrite UniformWeight.uweight_eq_alt by omega; subst r; Z.rewrite_mod_small; auto with zarith. + Qed. + + (* For small [v] with 0 <= eval v < eval N + R, [conditional_sub v N] canonically represents eval v - eval N when eval N <= eval v, and eval v otherwise. *) Local Lemma conditional_sub_correct : forall v, small v -> 0 <= eval v < eval N + R -> canon_rep (eval v + if eval N <=? eval v then -eval N else 0) (conditional_sub v N). + Proof using small_N lgr_big N_nz N_lt_R. + pose proof R_plusR_le as R_plusR_le. + clear - small_N lgr_big N_nz N_lt_R R_plusR_le. + intros; autounfold with loc; cbv [conditional_sub]. + repeat match goal with H : small _ |- _ => + rewrite H; clear H end. + autorewrite with push_eval. + assert (weight R_numlimbs < weight (S R_numlimbs)) by (rewrite !UniformWeight.uweight_eq_alt by omega; autorewrite with push_Zof_nat; auto with zarith). + assert (eval N mod weight R_numlimbs < weight (S R_numlimbs)) by (pose proof (Z.mod_pos_bound (eval N) (weight R_numlimbs) ltac:(auto)); omega). + rewrite Rows.conditional_sub_partitions by (repeat (autorewrite with distr_length push_eval; auto using partition_eq_mod with zarith)). + rewrite drop_high_to_length_partition by omega. + autorewrite with push_eval. + assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst R; reflexivity). + Z.rewrite_mod_small. + break_match; autorewrite with zsimplify_fast; Z.ltb_to_lt. + { split; [ reflexivity | ]. + rewrite Z.add_opp_r. fold (eval N). + auto using Z.mod_small with lia. } + { split; auto using Z.mod_small with lia. } + Qed. + + (* For small [a], [b] that both evaluate into the range 0 <= _ < eval N, [sub_then_maybe_add a b] canonically represents eval a - eval b, plus eval N when that difference is negative. *) Local Lemma sub_then_maybe_add_correct : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> canon_rep (eval a - eval b + if eval a - eval b <? 0 then eval N else 0) (sub_then_maybe_add a b). 
+ Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. + pose proof mask_r_sub1 as mask_r_sub1. + clear - small_N lgr_big R_numlimbs_nz N_nz N_lt_R mask_r_sub1. + intros; autounfold with loc; cbv [sub_then_maybe_add]. + repeat match goal with H : small _ |- _ => + rewrite H; clear H end. + rewrite Rows.sub_then_maybe_add_partitions by (autorewrite with push_eval distr_length; auto with zarith). + autorewrite with push_eval. + assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst r R; reflexivity). + Z.rewrite_mod_small. + split; [ reflexivity | ]. + break_match; Z.ltb_to_lt; lia. + Qed. + + (* The eval_* and small_* lemmas that follow are the two projections of the matching *_correct lemma, obtained through [eval_canon_rep] and [small_canon_rep] respectively. *) Local Lemma eval_scmul: forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> eval (@scmul n a v) = a * eval v. + Proof using lgr_big. eauto using scmul_correct, eval_canon_rep. Qed. + Local Lemma small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> small (@scmul n a v). + Proof using lgr_big. eauto using scmul_correct, small_canon_rep. Qed. + Local Lemma eval_addT : forall n a b, small a -> small b -> eval (@addT n a b) = eval a + eval b. + Proof using lgr_big. eauto using addT_correct, eval_canon_rep. Qed. + Local Lemma small_addT : forall n a b, small a -> small b -> small (@addT n a b). + Proof using lgr_big. eauto using addT_correct, small_canon_rep. Qed. + Local Lemma eval_drop_high_addT' : forall n a b, small a -> small b -> eval (@drop_high_addT' n a b) = (eval a + eval b) mod (r^Z.of_nat (S n)). + Proof using lgr_big. eauto using drop_high_addT'_correct, eval_canon_rep. Qed. + Local Lemma small_drop_high_addT' : forall n a b, small a -> small b -> small (@drop_high_addT' n a b). + Proof using lgr_big. eauto using drop_high_addT'_correct, small_canon_rep. Qed. + Local Lemma eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0. + Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. 
 eauto using conditional_sub_correct, eval_canon_rep. Qed. + Local Lemma small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v N). + Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using conditional_sub_correct, small_canon_rep. Qed. + Local Lemma eval_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b <? 0 then eval N else 0. + Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using sub_then_maybe_add_correct, eval_canon_rep. Qed. + Local Lemma small_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> small (sub_then_maybe_add a b). + Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. eauto using sub_then_maybe_add_correct, small_canon_rep. Qed. + + (* From here on the primitive operations are opaque; later proofs reason about them only through the eval_* / small_* lemmas above. *) Local Opaque T addT drop_high_addT' divmod zero scmul conditional_sub sub_then_maybe_add. + Create HintDb push_mont_eval discriminated. + Create HintDb word_by_word_montgomery. + Hint Unfold A'_S3 S3' S2 q s S1 a A' A_a Let_In : word_by_word_montgomery. + Let r_big' := r_big. (* to put it in the context *) + (* Side-condition solver: repeatedly applies the small_* lemmas, [Z_mod_lt], [Z.mul_split_mod], [lia], the push_mont_eval rewrites, and unfolding of the iteration steps. *) Local Ltac t_small := + repeat first [ assumption + | apply small_addT + | apply small_drop_high_addT' + | apply small_div + | apply small_zero + | apply small_scmul + | apply small_conditional_sub + | apply small_sub_then_maybe_add + | apply Z_mod_lt + | rewrite Z.mul_split_mod + | solve [ auto with zarith ] + | lia + | progress autorewrite with push_mont_eval + | progress autounfold with word_by_word_montgomery + | match goal with + | [ H : and _ _ |- _ ] => destruct H + end ]. + (* Registers the eval_* lemmas as rewrite rules in push_mont_eval; their side conditions are discharged by [t_small]. *) Hint Rewrite + eval_zero + eval_div + eval_mod + eval_addT + eval_drop_high_addT' + eval_scmul + eval_conditional_sub + eval_sub_then_maybe_add + using (repeat autounfold with word_by_word_montgomery; t_small) + : push_mont_eval. + + Local Arguments eval {_} _. + Local Arguments small {_} _. 
+ Local Arguments divmod {_} _. + + (* Recurse for a as many iterations as A has limbs, varying A := A, S := 0, r, bounds *) + (* Correctness of a single iteration, for a small [A] and a small, nonnegative accumulator [S]. *) Section Iteration_proofs. + Context (pred_A_numlimbs : nat) + (A : T (S pred_A_numlimbs)) + (S : T (S R_numlimbs)) + (small_A : small A) + (small_S : small S) + (S_nonneg : 0 <= eval S). + (* Given A, B < R, we want to compute A * B / R mod N. R = bound 0 * ... * bound (n-1) *) + + (* Limb lists are silently coerced to their value, so [S + a * B] below means [eval S + a * eval B]. *) Local Coercion eval : T >-> Z. + + Local Notation a := (@a pred_A_numlimbs A). + Local Notation A' := (@A' pred_A_numlimbs A). + Local Notation S1 := (@S1 pred_A_numlimbs B A S). + Local Notation s := (@s pred_A_numlimbs B A S). + Local Notation q := (@q pred_A_numlimbs B k A S). + Local Notation S2 := (@S2 pred_A_numlimbs B k A S). + Local Notation S3 := (@S3' pred_A_numlimbs B k A S). + + (* The integer that S3 is meant to hold before the truncation introduced by dropping the carry. *) Local Notation eval_pre_S3 := ((S + a * B + q * N) / r). + + (* S3 evaluates to eval_pre_S3 reduced modulo r * r^R_numlimbs. *) Lemma eval_S3_eq : eval S3 = eval_pre_S3 mod (r * r ^ Z.of_nat R_numlimbs). + Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big. + clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big. + unfold S3, S2, S1. + autorewrite with push_mont_eval push_Zof_nat; []. + rewrite !Z.pow_succ_r, <- ?Z.mul_assoc by omega. + rewrite Z.mod_pull_div by Z.zero_bounds. + do 2 f_equal; nia. + Qed. + + (* The bound S < N + B is preserved by one iteration (stated for eval_pre_S3). The proof bounds the numerator by N+B-1 + (r-1)*B + (r-1)*N and divides by r. *) Lemma pre_S3_bound + : eval S < eval N + eval B + -> eval_pre_S3 < eval N + eval B. + Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big. + clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big. + assert (Hmod : forall a b, 0 < b -> a mod b <= b - 1) + by (intros x y; pose proof (Z_mod_lt x y); omega). + intro HS. + eapply Z.le_lt_trans. + { transitivity ((N+B-1 + (r-1)*B + (r-1)*N) / r); + [ | set_evars; ring_simplify_subterms; subst_evars; reflexivity ]. 
+ Z.peel_le; repeat apply Z.add_le_mono; repeat apply Z.mul_le_mono_nonneg; try lia; + repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; + autorewrite with push_mont_eval; + try Z.zero_bounds; + auto with lia. } + rewrite (Z.mul_comm _ r), <- Z.add_sub_assoc, <- Z.add_opp_r, !Z.div_add_l' by lia. + autorewrite with zsimplify. + simpl; omega. + Qed. + + (* eval_pre_S3 is nonnegative. *) Lemma pre_S3_nonneg : 0 <= eval_pre_S3. + Proof using N_nz B_bounds small_B small_A small_S S_nonneg lgr_big. + clear -N_nz B_bounds small_B partition_Proper r_big' small_A small_S S_nonneg. + repeat autounfold with word_by_word_montgomery; rewrite ?Z.mul_split_mod; + autorewrite with push_mont_eval; []. + rewrite ?Npos_correct; Z.zero_bounds; lia. + Qed. + + (* The shortened A' is small. *) Lemma small_A' + : small A'. + Proof using small_A lgr_big. repeat autounfold with word_by_word_montgomery; t_small. Qed. + + (* The new accumulator S3 is small. *) Lemma small_S3 + : small S3. + Proof using small_A small_S small_N N_lt_R N_nz B_bounds small_B lgr_big. + clear -small_A small_S small_N N_lt_R N_nz B_bounds small_B partition_Proper r_big'. + repeat autounfold with word_by_word_montgomery; t_small. + Qed. + + (* The new accumulator is nonnegative. *) Lemma S3_nonneg : 0 <= eval S3. + Proof using small_A small_S small_B B_bounds N_nz N_lt_R small_N lgr_big. + clear -small_A small_S r_big' partition_Proper small_B B_bounds N_nz N_lt_R small_N lgr_big sub_then_maybe_add. + rewrite eval_S3_eq; Z.zero_bounds. + Qed. + + (* The bound S < N + B is preserved by one iteration (stated for S3 itself): under the bound, the modular truncation in [eval_S3_eq] is a no-op. *) Lemma S3_bound + : eval S < eval N + eval B + -> eval S3 < eval N + eval B. + Proof using N_nz B_bounds small_B small_A small_S S_nonneg B_bounds N_nz N_lt_R small_N lgr_big. + clear -N_nz B_bounds small_B small_A small_S S_nonneg B_bounds N_nz N_lt_R small_N lgr_big partition_Proper r_big' sub_then_maybe_add. + rewrite eval_S3_eq. + intro H; pose proof (pre_S3_bound H); pose proof pre_S3_nonneg. + subst R. + rewrite Z.mod_small by nia. + assumption. + Qed. + + (* S1 evaluates to S + a*B exactly. *) Lemma S1_eq : eval S1 = S + a*B. + Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S. 
+ clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper. + cbv [S1 a A']. + repeat autorewrite with push_mont_eval. + reflexivity. + Qed. + + (* The numerator S + a*B + q*N is divisible by r; this is where the hypothesis [k_correct] is used. The proof moves the goal into the field-like type [F] of integers modulo r. *) Lemma S2_mod_r_helper : (S + a*B + q * N) mod r = 0. + Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct. + clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct. + cbv [S2 q s]; autorewrite with push_mont_eval; rewrite S1_eq. + assert (r > 0) by lia. + assert (Hr : (-(1 mod r)) mod r = r - 1 /\ (-(1)) mod r = r - 1). + { destruct (Z.eq_dec r 1) as [H'|H']. + { rewrite H'; split; reflexivity. } + { rewrite !Z_mod_nz_opp_full; rewrite ?Z.mod_mod; Z.rewrite_mod_small; [ split; reflexivity | omega.. ]. } } + autorewrite with pull_Zmod. + replace 0 with (0 mod r) by apply Zmod_0_l. + pose (Z.to_pos r) as r'. + replace r with (Z.pos r') by (subst r'; rewrite Z2Pos.id; lia). + eapply F.eq_of_Z_iff. + rewrite Z.mul_split_mod. + repeat rewrite ?F.of_Z_add, ?F.of_Z_mul, <-?F.of_Z_mod. + rewrite <-!Algebra.Hierarchy.associative. + replace ((F.of_Z r' k * F.of_Z r' (eval N))%F) with (F.opp (m:=r') F.one). + { cbv [F.of_Z F.add]; simpl. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]. + simpl. + subst r'; rewrite Z2Pos.id by lia. + rewrite (proj1 Hr), Z.mul_sub_distr_l. + push_Zmod; pull_Zmod. + apply (f_equal2 Z.modulo); omega. } + { rewrite <- F.of_Z_mul. + rewrite F.of_Z_mod. + subst r'; rewrite Z2Pos.id by lia. + rewrite k_correct. + cbv [F.of_Z F.add F.opp F.one]; simpl. + change (-(1)) with (-1) in *. + apply path_sig_hprop; [ intro; exact HProp.allpath_hprop | ]; simpl. + rewrite Z2Pos.id by lia. + rewrite (proj1 Hr), (proj2 Hr); Z.rewrite_mod_small; reflexivity. } + Qed. + + (* Modulo N, eval_pre_S3 equals (S + a*B) * ri: the division by r becomes multiplication by the inverse [ri], and the q*N term vanishes. *) Lemma pre_S3_mod_N + : eval_pre_S3 mod N = (S + a*B)*ri mod N. + Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct ri_correct. 
+ clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct ri_correct sub_then_maybe_add. + pose proof fun a => Z.div_to_inv_modulo N a r ri ltac:(lia) ri_correct as HH; + cbv [Z.equiv_modulo] in HH; rewrite HH; clear HH. + etransitivity; [rewrite (fun a => Z.mul_mod_l a ri N)| + rewrite (fun a => Z.mul_mod_l a ri N); reflexivity]. + rewrite S2_mod_r_helper. + push_Zmod; pull_Zmod; autorewrite with zsimplify_const. + reflexivity. + Qed. + + (* Same congruence for S3 itself, under the bound S < N + B. *) Lemma S3_mod_N + (Hbound : eval S < eval N + eval B) + : S3 mod N = (S + a*B)*ri mod N. + Proof using B_bounds R_numlimbs_nz lgr_big small_A small_B small_S k_correct ri_correct small_N N_lt_R N_nz S_nonneg. + clear -B_bounds R_numlimbs_nz lgr_big small_A small_B small_S r_big' partition_Proper k_correct ri_correct N_nz N_lt_R small_N sub_then_maybe_add Hbound S_nonneg. + rewrite eval_S3_eq. + pose proof (pre_S3_bound Hbound); pose proof pre_S3_nonneg. + rewrite (Z.mod_small _ (r * _)) by (subst R; nia). + apply pre_S3_mod_N. + Qed. + End Iteration_proofs. + + (* Lifts the single-iteration facts to [redc_body] and [redc_loop], with [B] and [k] fixed to the section's values. *) Section redc_proofs. + Local Notation redc_body := (@redc_body B k). + Local Notation redc_loop := (@redc_loop B k). + Local Notation pre_redc A := (@pre_redc _ A B k). + Local Notation redc A := (@redc _ A B k). + + (* Facts about one [redc_body] step on a pair A_S = (A, S) whose components are small and whose accumulator satisfies 0 <= S < N + B. Each lemma destructs the pair and appeals to the matching Iteration_proofs lemma. *) Section body. + Context (pred_A_numlimbs : nat) + (A_S : T (S pred_A_numlimbs) * T (S R_numlimbs)). + Let A:=fst A_S. + Let S:=snd A_S. + Let A_a:=divmod A. + Let a:=snd A_a. + Context (small_A : small A) + (small_S : small S) + (S_bound : 0 <= eval S < eval N + eval B). + + Lemma small_fst_redc_body : small (fst (redc_body A_S)). + Proof using S_bound small_A small_S lgr_big. destruct A_S; apply small_A'; assumption. Qed. + Lemma small_snd_redc_body : small (snd (redc_body A_S)). + Proof using small_S small_N small_B small_A lgr_big S_bound B_bounds N_nz N_lt_R. + destruct A_S; unfold redc_body; apply small_S3; assumption. + Qed. + Lemma snd_redc_body_nonneg : 0 <= eval (snd (redc_body A_S)). 
+ Proof using small_S small_N small_B small_A lgr_big S_bound N_nz N_lt_R B_bounds. + destruct A_S; apply S3_nonneg; assumption. + Qed. + + (* Modulo N, the new accumulator equals (S + a*B) * ri. *) Lemma snd_redc_body_mod_N + : (eval (snd (redc_body A_S))) mod (eval N) = (eval S + a*eval B)*ri mod (eval N). + Proof using small_S small_N small_B small_A ri_correct lgr_big k_correct S_bound R_numlimbs_nz N_nz N_lt_R B_bounds. + clear -small_S small_N small_B small_A ri_correct k_correct S_bound R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add r_big' partition_Proper. + destruct A_S; apply S3_mod_N; auto; omega. + Qed. + + (* The first component is divided by r. *) Lemma fst_redc_body + : (eval (fst (redc_body A_S))) = eval (fst A_S) / r. + Proof using small_S small_A S_bound lgr_big. + destruct A_S; simpl; repeat autounfold with word_by_word_montgomery; simpl. + autorewrite with push_mont_eval. + reflexivity. + Qed. + + (* Modulo N, that division by r is (A - a) * ri, where [a] is the low digit of A. *) Lemma fst_redc_body_mod_N + : (eval (fst (redc_body A_S))) mod (eval N) = ((eval (fst A_S) - a)*ri) mod (eval N). + Proof using small_S small_A ri_correct lgr_big S_bound. + rewrite fst_redc_body. + etransitivity; [ eapply Z.div_to_inv_modulo; try eassumption; lia | ]. + unfold a, A_a, A. + autorewrite with push_mont_eval. + reflexivity. + Qed. + + (* The upper bound on the accumulator is preserved by one step. *) Lemma redc_body_bound + : eval S < eval N + eval B + -> eval (snd (redc_body A_S)) < eval N + eval B. + Proof using small_S small_N small_B small_A lgr_big S_bound N_nz N_lt_R B_bounds. + clear -small_S small_N small_B small_A S_bound N_nz N_lt_R B_bounds r_big' partition_Proper sub_then_maybe_add. + destruct A_S; apply S3_bound; unfold S in *; cbn [snd] in *; try assumption; try omega. + Qed. + End body. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.of_nat !_. + (* Induction on the iteration count that also unfolds one layer of [redc_loop] / [nat_rect] in both resulting cases. *) Local Ltac induction_loop count IHcount + := induction count as [|count IHcount]; intros; cbn [redc_loop nat_rect] in *; [ | (*rewrite redc_loop_comm_body in * *) ]. 
+      Lemma redc_loop_good count A_S (* loop invariant: smallness of both components and the bound [0 <= eval (snd _) < eval N + eval B] survive [redc_loop count] *)
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : (small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)))
+          /\ 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
+        induction_loop count IHcount; auto; [].
+        change (id (0 <= eval B < R)) in B_bounds (* don't let [destruct_head'_and] loop *).
+        destruct_head'_and.
+        repeat first [ apply conj
+                     | apply small_fst_redc_body
+                     | apply small_snd_redc_body
+                     | apply redc_body_bound
+                     | apply snd_redc_body_nonneg
+                     | apply IHcount
+                     | solve [ auto ] ].
+      Qed.
+      (* Projection of [redc_loop_good]: the smallness part only. *)
+      Lemma small_redc_loop count A_S
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : small (fst (redc_loop count A_S)) /\ small (snd (redc_loop count A_S)).
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds. apply redc_loop_good; assumption. Qed.
+      (* Projection of [redc_loop_good]: the bound part only. *)
+      Lemma redc_loop_bound count A_S
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : 0 <= eval (snd (redc_loop count A_S)) < eval N + eval B.
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds. apply redc_loop_good; assumption. Qed.
+      (* Side-condition solver for the loop lemmas below: repeatedly applies the smallness and bound lemmas above, splits conjunctions (in goal and hypotheses), simplifies [fst] / [snd], and closes what is left with [auto]. *)
+      Local Ltac handle_IH_small :=
+        repeat first [ apply redc_loop_good
+                     | apply small_fst_redc_body
+                     | apply small_snd_redc_body
+                     | apply redc_body_bound
+                     | apply snd_redc_body_nonneg
+                     | apply conj
+                     | progress cbn [fst snd]
+                     | progress destruct_head' and
+                     | solve [ auto ] ].
+      (* After [count] iterations the value of the first component has been divided by [r ^ count]; proved by induction, merging one division by [r] per step. *)
+      Lemma fst_redc_loop count A_S
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : eval (fst (redc_loop count A_S)) = eval (fst A_S) / r^(Z.of_nat count).
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add Hsmall Hbound.
+        cbv [redc_loop]; induction_loop count IHcount.
+        { simpl; autorewrite with zsimplify; reflexivity. }
+        { rewrite IHcount, fst_redc_body by handle_IH_small.
+          change (1 + R_numlimbs)%nat with (S R_numlimbs) in *.
+          rewrite Zdiv_Zdiv by Z.zero_bounds.
+          rewrite <- (Z.pow_1_r r) at 1.
+          rewrite <- Z.pow_add_r by lia.
+          replace (1 + Z.of_nat count) with (Z.of_nat (S count)) by lia.
+          reflexivity. }
+      Qed.
+      (* Modular form of [fst_redc_loop]: the first component is congruent to [(eval (fst A_S) - eval (fst A_S) mod r ^ count) * ri ^ count] modulo [eval N]. *)
+      Lemma fst_redc_loop_mod_N count A_S
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : eval (fst (redc_loop count A_S)) mod (eval N)
+          = (eval (fst A_S) - eval (fst A_S) mod r^Z.of_nat count)
+            * ri^(Z.of_nat count) mod (eval N).
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds ri_correct.
+        clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add Hsmall Hbound ri_correct.
+        rewrite fst_redc_loop by assumption.
+        destruct count.
+        { simpl; autorewrite with zsimplify; reflexivity. }
+        { etransitivity;
+            [ eapply Z.div_to_inv_modulo;
+              try solve [ eassumption
+                        | apply Z.lt_gt, Z.pow_pos_nonneg; lia ]
+            | ].
+          { erewrite <- Z.pow_mul_l, <- Z.pow_1_l.
+            { apply Z.pow_mod_Proper; [ eassumption | reflexivity ]. }
+            { lia. } }
+          reflexivity. }
+      Qed.
+      (* From here on [simpl] never unfolds [Z.pow]. Main loop lemma: modulo [eval N], the second component after [count] iterations is congruent to [(eval (snd A_S) + (eval (fst A_S) mod r ^ count) * eval B) * ri ^ count]. *)
+      Local Arguments Z.pow : simpl never.
+      Lemma snd_redc_loop_mod_N count A_S
+            (Hsmall : small (fst A_S) /\ small (snd A_S))
+            (Hbound : 0 <= eval (snd A_S) < eval N + eval B)
+        : (eval (snd (redc_loop count A_S))) mod (eval N)
+          = ((eval (snd A_S) + (eval (fst A_S) mod r^(Z.of_nat count))*eval B)*ri^(Z.of_nat count)) mod (eval N).
+      Proof using small_N small_B ri_correct lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds k_correct.
+        clear -small_N small_B ri_correct r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds sub_then_maybe_add k_correct Hsmall Hbound.
+        cbv [redc_loop].
+        induction_loop count IHcount.
+        { simpl; autorewrite with zsimplify; reflexivity. }
+        { rewrite IHcount by handle_IH_small.
+          push_Zmod; rewrite snd_redc_body_mod_N, fst_redc_body by handle_IH_small; pull_Zmod.
+          autorewrite with push_mont_eval; [].
+          match goal with
+          | [ |- ?x mod ?N = ?y mod ?N ]
+            => change (Z.equiv_modulo N x y)
+          end.
+          destruct A_S as [A S].
+          cbn [fst snd].
+          change (Z.pos (Pos.of_succ_nat ?n)) with (Z.of_nat (Datatypes.S n)).
+          rewrite !Z.mul_add_distr_r.
+          rewrite <- !Z.mul_assoc.
+          replace (ri * ri^(Z.of_nat count)) with (ri^(Z.of_nat (Datatypes.S count)))
+            by (change (Datatypes.S count) with (1 + count)%nat;
+                autorewrite with push_Zof_nat; rewrite Z.pow_add_r by lia; simpl Z.succ; rewrite Z.pow_1_r; nia).
+          rewrite <- !Z.add_assoc.
+          apply Z.add_mod_Proper; [ reflexivity | ].
+          unfold Z.equiv_modulo; push_Zmod; rewrite (Z.mul_mod_l (_ mod r) _ (eval N)).
+          rewrite Z.mod_pull_div by auto with zarith lia.
+          push_Zmod.
+          erewrite Z.div_to_inv_modulo;
+            [
+            | apply Z.lt_gt; lia
+            | eassumption ].
+          pull_Zmod.
+          match goal with
+          | [ |- ?x mod ?N = ?y mod ?N ]
+            => change (Z.equiv_modulo N x y)
+          end.
+          repeat first [ rewrite <- !Z.pow_succ_r, <- !Nat2Z.inj_succ by lia
+                       | rewrite (Z.mul_comm _ ri)
+                       | rewrite (Z.mul_assoc _ ri _)
+                       | rewrite (Z.mul_comm _ (ri^_))
+                       | rewrite (Z.mul_assoc _ (ri^_) _) ].
+          repeat first [ rewrite <- Z.mul_assoc
+                       | rewrite <- Z.mul_add_distr_l
+                       | rewrite (Z.mul_comm _ (eval B))
+                       | rewrite !Nat2Z.inj_succ, !Z.pow_succ_r by lia;
+                         rewrite <- Znumtheory.Zmod_div_mod by (apply Z.divide_factor_r || Z.zero_bounds)
+                       | rewrite Zplus_minus
+                       | rewrite (Z.mul_comm r (r^_))
+                       | reflexivity ]. }
+      Qed.
+      (* [pre_redc A] evaluates to a value that is at least [0] and strictly below [eval N + eval B]; obtained from [redc_loop_good] after unfolding [pre_redc]. *)
+      Lemma pre_redc_bound A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+        : 0 <= eval (pre_redc A) < eval N + eval B.
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds sub_then_maybe_add small_A.
+        unfold pre_redc.
+        apply redc_loop_good; simpl; autorewrite with push_mont_eval;
+          rewrite ?Npos_correct; auto; lia.
+      Qed.
+      (* [pre_redc A] is small; the proof script is the same as for [pre_redc_bound]. *)
+      Lemma small_pre_redc A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+        : small (pre_redc A).
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds sub_then_maybe_add small_A.
+        unfold pre_redc.
+        apply redc_loop_good; simpl; autorewrite with push_mont_eval;
+          rewrite ?Npos_correct; auto; lia.
+      Qed.
+      (* For [A] with [0 <= eval A < r ^ A_numlimbs], [pre_redc A] is congruent to [eval A * eval B * ri ^ A_numlimbs] modulo [eval N]; a specialisation of [snd_redc_loop_mod_N]. *)
+      Lemma pre_redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
+        : (eval (pre_redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
+      Proof using small_N small_B lgr_big N_nz N_lt_R B_bounds R_numlimbs_nz ri_correct k_correct.
+        clear -small_N small_B r_big' partition_Proper lgr_big N_nz N_lt_R B_bounds R_numlimbs_nz ri_correct k_correct sub_then_maybe_add small_A A_bound.
+        unfold pre_redc.
+        rewrite snd_redc_loop_mod_N; cbn [fst snd];
+          autorewrite with push_mont_eval zsimplify;
+          [ | rewrite ?Npos_correct; auto; lia.. ].
+        Z.rewrite_mod_small.
+        reflexivity.
+      Qed.
+      (* The same congruence for [redc A]: after unfolding [redc], every branch exposed by [break_innermost_match] is reduced to [pre_redc_mod_N] (a subtracted [eval N] vanishes modulo [eval N]). *)
+      Lemma redc_mod_N A_numlimbs (A : T A_numlimbs) (small_A : small A) (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
+        : (eval (redc A)) mod (eval N) = (eval A * eval B * ri^(Z.of_nat A_numlimbs)) mod (eval N).
+      Proof using small_N small_B ri_correct lgr_big k_correct R_numlimbs_nz N_nz N_lt_R B_bounds.
+        pose proof (@small_pre_redc _ A small_A).
+        pose proof (@pre_redc_bound _ A small_A).
+        unfold redc.
+        autorewrite with push_mont_eval; [].
+        break_innermost_match;
+          try rewrite Z.add_opp_r, Zminus_mod, Z_mod_same_full;
+          autorewrite with zsimplify_fast;
+          apply pre_redc_mod_N; auto.
+      Qed.
+      (* Bound on [redc A] that records the conditional subtraction: the upper bound [eval N + eval B] is lowered by [eval N] exactly when [eval N <=? eval (pre_redc A)] is true. *)
+      Lemma redc_bound_tight A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+        : 0 <= eval (redc A) < eval N + eval B + if eval N <=? eval (pre_redc A) then -eval N else 0.
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
+        clear -small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds r_big' partition_Proper small_A sub_then_maybe_add.
+        pose proof (@small_pre_redc _ A small_A).
+        pose proof (@pre_redc_bound _ A small_A).
+        unfold redc.
+        rewrite eval_conditional_sub by t_small.
+        break_innermost_match; Z.ltb_to_lt; omega.
+      Qed.
+      (* If [eval B < eval N] then [redc A] is fully reduced: at least [0] and strictly below [eval N]. *)
+      Lemma redc_bound_N A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+        : eval B < eval N -> 0 <= eval (redc A) < eval N.
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add.
+        pose proof (@small_pre_redc _ A small_A).
+        pose proof (@pre_redc_bound _ A small_A).
+        unfold redc.
+        rewrite eval_conditional_sub by t_small.
+        break_innermost_match; Z.ltb_to_lt; omega.
+      Qed.
+      (* Under the stated hypotheses [redc A] is at least [0] and strictly below [R]. *)
+      Lemma redc_bound A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+            (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
+        : 0 <= eval (redc A) < R.
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add A_bound.
+        pose proof (@small_pre_redc _ A small_A).
+        pose proof (@pre_redc_bound _ A small_A).
+        unfold redc.
+        rewrite eval_conditional_sub by t_small.
+        break_innermost_match; Z.ltb_to_lt; try omega.
+      Qed.
+      (* [redc A] is small; follows from [small_conditional_sub] applied to [pre_redc A]. *)
+      Lemma small_redc A_numlimbs (A : T A_numlimbs)
+            (small_A : small A)
+            (A_bound : 0 <= eval A < r ^ Z.of_nat A_numlimbs)
+        : small (redc A).
+      Proof using small_N small_B lgr_big R_numlimbs_nz N_nz N_lt_R B_bounds.
+        clear -small_N small_B r_big' partition_Proper R_numlimbs_nz N_nz N_lt_R B_bounds small_A sub_then_maybe_add A_bound.
+        pose proof (@small_pre_redc _ A small_A).
+        pose proof (@pre_redc_bound _ A small_A).
+        unfold redc.
+        apply small_conditional_sub; [ apply small_pre_redc | .. ]; auto; omega.
+      Qed.
+    End redc_proofs.
+
+    (* Modular addition, subtraction and negation on two operands [Av] and [Bv] that are small and already reduced below [eval N]. *)
+    Section add_sub.
+      Context (Av Bv : T R_numlimbs)
+              (small_Av : small Av)
+              (small_Bv : small Bv)
+              (Av_bound : 0 <= eval Av < eval N)
+              (Bv_bound : 0 <= eval Bv < eval N).
+
+      (* The results of [add], [sub] and [opp] are small. *)
+      (* NOTE(review): [small_opp] lists [small_Bv] in [Proof using] although its statement mentions only [Av]; kept as is, because dropping it would change the arguments of the lemma once the section is closed. *)
+      Lemma small_add : small (add Av Bv).
+      Proof using small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound small_N ri k R_numlimbs_nz N_nz B_bounds B.
+        clear -small_Bv small_Av N_lt_R Bv_bound Av_bound partition_Proper r_big' small_N ri k R_numlimbs_nz N_nz B_bounds B sub_then_maybe_add.
+        unfold add; t_small.
+      Qed.
+      Lemma small_sub : small (sub Av Bv).
+      Proof using small_N small_Bv small_Av partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R Bv_bound Av_bound. unfold sub; t_small. Qed.
+      Lemma small_opp : small (opp Av).
+      Proof using small_N small_Bv small_Av partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R Av_bound. unfold opp, sub; t_small. Qed.
+
+      (* Closed forms of the evaluated results: [add] subtracts [eval N] when the sum reaches it, [sub] adds [eval N] back when the difference is negative, and [opp] gives [eval N - eval Av] except that zero is mapped to zero. *)
+      Lemma eval_add : eval (add Av Bv) = eval Av + eval Bv + if (eval N <=? eval Av + eval Bv) then -eval N else 0.
+      Proof using small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound small_N ri k R_numlimbs_nz N_nz B_bounds B.
+        clear -small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound partition_Proper r_big' small_N ri k R_numlimbs_nz N_nz B_bounds B sub_then_maybe_add.
+        unfold add; autorewrite with push_mont_eval; reflexivity.
+      Qed.
+      Lemma eval_sub : eval (sub Av Bv) = eval Av - eval Bv + if (eval Av - eval Bv <? 0) then eval N else 0.
+      Proof using small_Bv small_Av Bv_bound Av_bound small_N partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R. unfold sub; autorewrite with push_mont_eval; reflexivity. Qed.
+      Lemma eval_opp : eval (opp Av) = (if (eval Av =? 0) then 0 else eval N) - eval Av.
+      Proof using small_Av Av_bound small_N partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R.
+        clear -Av_bound N_nz small_Av partition_Proper r_big' small_N lgr_big R_numlimbs_nz N_nz N_lt_R.
+        unfold opp, sub; autorewrite with push_mont_eval.
+        break_innermost_match; Z.ltb_to_lt; lia.
+      Qed.
+
+      (* Derives the modular corollaries from the closed forms above: introduce the generalized equation and rewrite with it, split on the conditional, and normalise by pushing and pulling [mod]. *)
+      Local Ltac t_mod_N :=
+        repeat first [ progress break_innermost_match
+                     | reflexivity
+                     | let H := fresh in intro H; rewrite H; clear H
+                     | progress autorewrite with zsimplify_const
+                     | rewrite Z.add_opp_r
+                     | progress (push_Zmod; pull_Zmod) ].
+
+      (* Modulo [eval N] the three operations agree with integer addition, subtraction and unary minus. *)
+      Lemma eval_add_mod_N : eval (add Av Bv) mod eval N = (eval Av + eval Bv) mod eval N.
+      Proof using small_Bv small_Av lgr_big N_lt_R Bv_bound Av_bound small_N ri k R_numlimbs_nz N_nz B_bounds B.
+        generalize eval_add; clear. t_mod_N.
+      Qed.
+      Lemma eval_sub_mod_N : eval (sub Av Bv) mod eval N = (eval Av - eval Bv) mod eval N.
+      Proof using small_Bv small_Av Bv_bound Av_bound small_N r_big' partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R. generalize eval_sub; clear. t_mod_N. Qed.
+      Lemma eval_opp_mod_N : eval (opp Av) mod eval N = (-eval Av) mod eval N.
+      Proof using small_Av Av_bound small_N r_big' partition_Proper lgr_big R_numlimbs_nz N_nz N_lt_R. generalize eval_opp; clear. t_mod_N. Qed.
+
+      (* All three results stay reduced: at least [0] and strictly below [eval N]. *)
+      Lemma add_bound : 0 <= eval (add Av Bv) < eval N.
+      Proof using small_Bv small_Av lgr_big R_numlimbs_nz N_lt_R Bv_bound Av_bound small_N ri k N_nz B_bounds B.
+        generalize eval_add; break_innermost_match; Z.ltb_to_lt; lia.
+      Qed.
+      Lemma sub_bound : 0 <= eval (sub Av Bv) < eval N.
+      Proof using small_Bv small_Av R_numlimbs_nz Bv_bound Av_bound small_N r_big' partition_Proper lgr_big N_nz N_lt_R.
+        generalize eval_sub; break_innermost_match; Z.ltb_to_lt; lia.
+      Qed.
+      Lemma opp_bound : 0 <= eval (opp Av) < eval N.
+      Proof using small_Av R_numlimbs_nz Av_bound small_N r_big' partition_Proper lgr_big N_nz N_lt_R.
+        clear Bv small_Bv Bv_bound.
+        generalize eval_opp; break_innermost_match; Z.ltb_to_lt; lia.
+      Qed.
+    End add_sub.
+  End with_args.
+
+  (* Instantiation for a concrete modulus [m] over [n] positions of the uniform weight at [bitwidth]: [r] is [2^bitwidth] and [m_enc] is [partition weight n m]. *)
+  Section modops.
+    Context (bitwidth : Z)
+            (n : nat)
+            (m : Z).
+    Let r := 2^bitwidth.
+    Local Notation weight := (UniformWeight.uweight bitwidth).
+    Local Notation eval := (@eval bitwidth n).
+    Let m_enc := partition weight n m.
+    Local Coercion Z.of_nat : nat >-> Z. (* lets the [nat] [n] be used where a [Z] is expected, as in [r^n] below *)
+    Context (r' : Z)
+            (m' : Z)
+            (r'_correct : (r * r') mod m = 1) (* [r'] inverts [r] modulo [m] *)
+            (m'_correct : (m * m') mod r = (-1) mod r) (* [m * m'] is congruent to [-1] modulo [r] *)
+            (bitwidth_big : 0 < bitwidth)
+            (m_big : 1 < m)
+            (n_nz : n <> 0%nat)
+            (m_small : m < r^n). (* [m] lies strictly below [r^n] *)
+    (* Shorthands: the weight properties of the uniform weight at [bitwidth], and [small] specialised to [bitwidth] and [n]. *)
+    Local Notation wprops := (@UniformWeight.uwprops bitwidth bitwidth_big).
+    Local Notation small := (@small bitwidth n).
+    (* Register the weight facts as section-local [Hint Immediate]s so that [auto] / [eauto] can use them. *)
+    Local Hint Immediate (wprops).
+    Local Hint Immediate (weight_0 wprops).
+    Local Hint Immediate (weight_positive wprops).
+    Local Hint Immediate (weight_multiples wprops).
+    Local Hint Immediate (weight_divides wprops).
+    (* [m_enc] evaluates back to [m]. *)
+    Local Lemma m_enc_correct_montgomery : m = eval m_enc.
+    Proof using m_small m_big bitwidth_big.
+      clear -m_small m_big bitwidth_big.
+      cbv [eval m_enc]; autorewrite with push_eval; auto.
+      rewrite UniformWeight.uweight_eq_alt by omega.
+      Z.rewrite_mod_small; reflexivity.
+    Qed.
+    Local Lemma r'_pow_correct : (r'^n * r^n) mod (eval m_enc) = 1. (* [r'^n] inverts [r^n] modulo [eval m_enc] *)
+    Proof using r'_correct m_small m_big bitwidth_big.
+      clear -r'_correct m_small m_big bitwidth_big.
+      rewrite <- Z.pow_mul_l, Z.mod_pow_full, ?(Z.mul_comm r'), <- m_enc_correct_montgomery, r'_correct.
+      autorewrite with zsimplify_const; auto with omega.
+      Z.rewrite_mod_small; omega.
+    Qed.
+    Local Lemma small_m_enc : small m_enc. (* [m_enc] satisfies [small]; same script as [m_enc_correct_montgomery] with [small] unfolded as well *)
+    Proof using m_small m_big bitwidth_big.
+      clear -m_small m_big bitwidth_big.
+      cbv [m_enc small eval]; autorewrite with push_eval; auto.
+      rewrite UniformWeight.uweight_eq_alt by omega.
+      Z.rewrite_mod_small; reflexivity.
+    Qed.
+    (* Finishing tactic for this section: tries assumptions, reflexivity and conjunction splitting; rewrites [eval m_enc] back to [m]; uses [r'_correct] and [small_m_enc]; rewrites with smallness hypotheses and evaluates the resulting partition through [eval_partition]; normalises weights and small [mod]s; and discharges arithmetic with [lia]. *)
+    Local Ltac t_fin :=
+      repeat match goal with
+             | _ => assumption
+             | [ |- ?x = ?x ] => reflexivity
+             | [ |- and _ _ ] => split
+             | _ => rewrite <- !m_enc_correct_montgomery
+             | _ => rewrite !r'_correct
+             | _ => rewrite !Z.mod_1_l by assumption; reflexivity
+             | _ => rewrite !(Z.mul_comm m' m)
+             | _ => lia
+             | _ => exact small_m_enc
+             | [ H : small ?x |- context[eval ?x] ]
+               => rewrite H; cbv [eval]; rewrite eval_partition by auto
+             | [ |- context[weight _] ] => rewrite UniformWeight.uweight_eq_alt by auto with omega
+             | _=> progress Z.rewrite_mod_small
+             | _ => progress Z.zero_bounds
+             | [ |- _ mod ?x < ?x ] => apply Z.mod_pos_bound
+             end.
+    (* The operations specialised to this section: each is a direct instantiation of the generic [redc], [add], [sub], [opp], [nonzero] or [to_bytesmod], with [m_enc] as modulus encoding where one is taken; [squaremod a] is [mulmod a a]. *)
+    Definition mulmod (a b : list Z) : list Z := @redc bitwidth n m_enc n a b m'.
+    Definition squaremod (a : list Z) : list Z := mulmod a a.
+    Definition addmod (a b : list Z) : list Z := @add bitwidth n m_enc a b.
+    Definition submod (a b : list Z) : list Z := @sub bitwidth n m_enc a b.
+    Definition oppmod (a : list Z) : list Z := @opp bitwidth n m_enc a.
+    Definition nonzeromod (a : list Z) : Z := @nonzero a.
+    Definition to_bytesmod (a : list Z) : list Z := @to_bytesmod bitwidth 1 n a.
+    (* A list is [valid] when it is small and evaluates to a value that is at least [0] and strictly below [m]. *)
+    Definition valid (a : list Z) := small a /\ 0 <= eval a < m.
+    (* Raw specification of [mulmod] on small inputs: the result is small; it is reduced below [m] whenever [eval b < m]; and it is congruent to [eval a * eval b * r'^n] modulo [m]. The three parts come from [small_redc], [redc_bound_N] and [redc_mod_N] respectively. *)
+    Lemma mulmod_correct0
+      : forall a b : list Z,
+        small a -> small b
+        -> small (mulmod a b)
+           /\ (eval b < m -> 0 <= eval (mulmod a b) < m)
+           /\ (eval (mulmod a b) mod m = (eval a * eval b * r'^n) mod m).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      intros a b Ha Hb; repeat apply conj; cbv [small mulmod eval];
+        [ eapply small_redc
+        | rewrite m_enc_correct_montgomery; eapply redc_bound_N
+        | rewrite !m_enc_correct_montgomery; eapply redc_mod_N ];
+        t_fin.
+    Qed.
+    (* The integer [1] partitioned with [weight] over [n] positions. *)
+    Definition onemod : list Z := partition weight n 1.
+    (* [onemod] evaluates to [1] and is [valid]. *)
+    Definition onemod_correct : eval onemod = 1 /\ valid onemod.
+    Proof using n_nz m_big bitwidth_big.
+      clear -n_nz m_big bitwidth_big.
+      cbv [valid small onemod eval]; autorewrite with push_eval; t_fin.
+    Qed.
+    (* First conjunct of [onemod_correct] as a standalone lemma. *)
+    Lemma eval_onemod : eval onemod = 1.
+    Proof. apply onemod_correct. Qed.
+    (* The value [(r^n * r^n) mod m] partitioned with [weight] over [n] positions. *)
+    Definition R2mod : list Z := partition weight n ((r^n * r^n) mod m).
+    (* [R2mod] evaluates to [r^n * r^n] modulo [m] and is [valid]. *)
+    Definition R2mod_correct : eval R2mod mod m = (r^n*r^n) mod m /\ valid R2mod.
+    Proof using n_nz m_small m_big m'_correct bitwidth_big.
+      clear -n_nz m_small m_big m'_correct bitwidth_big.
+      cbv [valid small R2mod eval]; autorewrite with push_eval; t_fin;
+        rewrite !(Z.mod_small (_ mod m)) by (Z.div_mod_to_quot_rem; subst r; lia);
+        t_fin.
+    Qed.
+    (* [mulmod] with [onemod] as second operand; by [from_montgomerymod_correct] this multiplies the value by [r'^n] modulo [m]. *)
+    Definition from_montgomerymod (v : list Z) : list Z
+      := mulmod v onemod.
+    (* For [valid] [v]: [from_montgomerymod v] is congruent to [eval v * r'^n] modulo [m], and is again [valid]. Both parts are instances of [mulmod_correct0] with [eval onemod] rewritten to [1]. *)
+    Lemma from_montgomerymod_correct (v : list Z)
+      : valid v -> eval (from_montgomerymod v) mod m = (eval v * r'^n) mod m
+                   /\ valid (from_montgomerymod v).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      clear -r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      intro Hv; cbv [from_montgomerymod valid] in *; destruct_head'_and.
+      replace (eval v * r'^n) with (eval v * eval onemod * r'^n) by (rewrite (proj1 onemod_correct); lia).
+      repeat apply conj; apply mulmod_correct0; auto; try apply onemod_correct; rewrite (proj1 onemod_correct); omega.
+    Qed.
+    (* The two conjuncts of [from_montgomerymod_correct] as separate lemmas. *)
+    Lemma eval_from_montgomerymod (v : list Z) : valid v -> eval (from_montgomerymod v) mod m = (eval v * r'^n) mod m.
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      intros; apply from_montgomerymod_correct; assumption.
+    Qed.
+    Lemma valid_from_montgomerymod (v : list Z)
+      : valid v -> valid (from_montgomerymod v).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      intros; apply from_montgomerymod_correct; assumption.
+    Qed.
+    (* On [valid] inputs, [mulmod] corresponds to integer multiplication once [from_montgomerymod] is applied on both sides (modulo [m]), and it preserves [valid]. *)
+    Lemma mulmod_correct
+      : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (mulmod a b)) mod m
+                                            = (eval (from_montgomerymod a) * eval (from_montgomerymod b)) mod m)
+        /\ (forall a (_ : valid a) b (_ : valid b), valid (mulmod a b)).
+    Proof using r'_correct r' n_nz m_small m_big m'_correct bitwidth_big.
+      repeat apply conj; intros;
+        push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj;
+          try apply mulmod_correct0; cbv [valid] in *; destruct_head'_and; auto; [].
+      rewrite !Z.mul_assoc.
+      apply Z.mul_mod_Proper; [ | reflexivity ].
+      cbv [Z.equiv_modulo]; etransitivity; [ apply mulmod_correct0 | apply f_equal2; lia ]; auto.
+    Qed.
+    (* First conjunct of [mulmod_correct]. *)
+    Lemma eval_mulmod
+      : (forall a (_ : valid a) b (_ : valid b),
+            eval (from_montgomerymod (mulmod a b)) mod m
+            = (eval (from_montgomerymod a) * eval (from_montgomerymod b)) mod m).
+    Proof. apply mulmod_correct. Qed.
+    (* The same two facts for [squaremod], obtained from [mulmod_correct] after unfolding [squaremod]. *)
+    Lemma squaremod_correct
+      : (forall a (_ : valid a), eval (from_montgomerymod (squaremod a)) mod m
+                                            = (eval (from_montgomerymod a) * eval (from_montgomerymod a)) mod m)
+        /\ (forall a (_ : valid a), valid (squaremod a)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      split; intros; cbv [squaremod]; apply mulmod_correct; assumption.
+    Qed.
+    (* First conjunct of [squaremod_correct]. *)
+    Lemma eval_squaremod
+      : (forall a (_ : valid a),
+            eval (from_montgomerymod (squaremod a)) mod m
+            = (eval (from_montgomerymod a) * eval (from_montgomerymod a)) mod m).
+    Proof. apply squaremod_correct. Qed.
+    (* Partition [v] with [weight] over [n] positions and multiply by [R2mod] with [mulmod]; by [encodemod_correct], applying [from_montgomerymod] afterwards gives back [v] modulo [m]. *)
+    Definition encodemod (v : Z) : list Z
+      := mulmod (partition weight n v) R2mod.
+    (* [t_valid v] proves [valid] goals about partitioned values: unfold [valid], [small] and [eval], evaluate the partition, and bound the remaining [mod]s (including those of [v]) with [lia]. *)
+    Local Ltac t_valid v :=
+      cbv [valid]; repeat apply conj;
+      auto; cbv [small eval]; autorewrite with push_eval; auto;
+      rewrite ?UniformWeight.uweight_eq_alt by omega;
+      Z.rewrite_mod_small;
+      rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia);
+      rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia);
+      try apply Z.mod_pos_bound; subst r; try lia; try reflexivity.
+    Lemma encodemod_correct (* round trip: for [0 <= v < m], [from_montgomerymod (encodemod v)] equals [v] modulo [m], and [encodemod v] is [valid] *)
+      : (forall v, 0 <= v < m -> eval (from_montgomerymod (encodemod v)) mod m = v mod m)
+        /\ (forall v, 0 <= v < m -> valid (encodemod v)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      split; intros v ?; cbv [encodemod R2mod]; [ rewrite (proj1 mulmod_correct) | apply mulmod_correct ];
+        [ | now t_valid v.. ].
+      push_Zmod; rewrite !eval_from_montgomerymod; [ | now t_valid v.. ].
+      cbv [eval]; autorewrite with push_eval; auto.
+      rewrite ?UniformWeight.uweight_eq_alt by omega.
+      rewrite ?(Z.mod_small v) by (subst r; Z.div_mod_to_quot_rem; lia).
+      rewrite ?(Z.mod_small (_ mod m)) by (subst r; Z.div_mod_to_quot_rem; lia).
+      pull_Zmod.
+      rewrite <- !Z.mul_assoc; autorewrite with pull_Zpow.
+      generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const.
+      rewrite (Z.mul_comm r' r); generalize r'_correct; push_Zmod; intro Heq; rewrite Heq; clear Heq; pull_Zmod; autorewrite with zsimplify_const.
+      Z.rewrite_mod_small.
+      reflexivity.
+    Qed.
+    (* First conjunct of [encodemod_correct]. *)
+    Lemma eval_encodemod
+      : (forall v, 0 <= v < m
+                   -> eval (from_montgomerymod (encodemod v)) mod m = v mod m).
+    Proof. apply encodemod_correct. Qed.
+    (* On [valid] inputs, [addmod] corresponds to integer addition once [from_montgomerymod] is applied on both sides (modulo [m]), and it preserves [valid]; built on [small_add], [add_bound] and [eval_add]. *)
+    Lemma addmod_correct
+      : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (addmod a b)) mod m
+                                            = (eval (from_montgomerymod a) + eval (from_montgomerymod b)) mod m)
+        /\ (forall a (_ : valid a) b (_ : valid b), valid (addmod a b)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      repeat apply conj; intros;
+        push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj;
+          cbv [valid addmod] in *; destruct_head'_and; auto;
+            try rewrite m_enc_correct_montgomery;
+            try (eapply small_add || eapply add_bound);
+            cbv [small]; rewrite <- ?m_enc_correct_montgomery;
+              eauto with omega; [ ].
+      push_Zmod; erewrite eval_add by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
+      break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
+    Qed.
+    (* First conjunct of [addmod_correct]. *)
+    Lemma eval_addmod
+      : (forall a (_ : valid a) b (_ : valid b),
+            eval (from_montgomerymod (addmod a b)) mod m
+            = (eval (from_montgomerymod a) + eval (from_montgomerymod b)) mod m).
+    Proof. apply addmod_correct. Qed.
+    (* On [valid] inputs, [submod] corresponds to integer subtraction once [from_montgomerymod] is applied on both sides (modulo [m]), and it preserves [valid]; built on [small_sub], [sub_bound] and [eval_sub]. *)
+    Lemma submod_correct
+      : (forall a (_ : valid a) b (_ : valid b), eval (from_montgomerymod (submod a b)) mod m
+                                            = (eval (from_montgomerymod a) - eval (from_montgomerymod b)) mod m)
+        /\ (forall a (_ : valid a) b (_ : valid b), valid (submod a b)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      repeat apply conj; intros;
+        push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj;
+          cbv [valid submod] in *; destruct_head'_and; auto;
+            try rewrite m_enc_correct_montgomery;
+            try (eapply small_sub || eapply sub_bound);
+            cbv [small]; rewrite <- ?m_enc_correct_montgomery;
+              eauto with omega; [ ].
+      push_Zmod; erewrite eval_sub by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
+      break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
+    Qed.
+    (* First conjunct of [submod_correct]. *)
+    Lemma eval_submod
+      : (forall a (_ : valid a) b (_ : valid b),
+            eval (from_montgomerymod (submod a b)) mod m
+            = (eval (from_montgomerymod a) - eval (from_montgomerymod b)) mod m).
+    Proof. apply submod_correct. Qed.
+    (* On [valid] inputs, [oppmod] corresponds to integer negation once [from_montgomerymod] is applied on both sides (modulo [m]), and it preserves [valid]; built on [small_opp], [opp_bound] and [eval_opp]. *)
+    Lemma oppmod_correct
+      : (forall a (_ : valid a), eval (from_montgomerymod (oppmod a)) mod m
+                            = (-eval (from_montgomerymod a)) mod m)
+        /\ (forall a (_ : valid a), valid (oppmod a)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      repeat apply conj; intros;
+        push_Zmod; rewrite ?eval_from_montgomerymod; pull_Zmod; repeat apply conj;
+          cbv [valid oppmod] in *; destruct_head'_and; auto;
+            try rewrite m_enc_correct_montgomery;
+            try (eapply small_opp || eapply opp_bound);
+            cbv [small]; rewrite <- ?m_enc_correct_montgomery;
+              eauto with omega; [ ].
+      push_Zmod; erewrite eval_opp by (cbv [small]; rewrite <- ?m_enc_correct_montgomery; eauto with omega); pull_Zmod; rewrite <- ?m_enc_correct_montgomery.
+      break_innermost_match; push_Zmod; pull_Zmod; autorewrite with zsimplify_const; apply f_equal2; nia.
+    Qed.
+    (* First conjunct of [oppmod_correct]. *)
+    Lemma eval_oppmod
+      : (forall a (_ : valid a),
+            eval (from_montgomerymod (oppmod a)) mod m
+            = (-eval (from_montgomerymod a)) mod m).
+    Proof. apply oppmod_correct. Qed.
+    (* For [valid] [a], [nonzeromod a] is [0] exactly when [from_montgomerymod a] is [0] modulo [m]. The right-to-left direction multiplies by [r^n] to cancel the [r'^n] factor (using [r'_correct]) and then concludes from the range bound on [eval a]. *)
+    Lemma nonzeromod_correct
+      : (forall a (_ : valid a), (nonzeromod a = 0) <-> ((eval (from_montgomerymod a)) mod m = 0)).
+    Proof using r'_correct n_nz m_small m_big m'_correct bitwidth_big.
+      intros a Ha; rewrite eval_from_montgomerymod by assumption.
+      cbv [nonzeromod valid] in *; destruct_head'_and.
+      rewrite eval_nonzero; try eassumption; [ | subst r; apply conj; try eassumption; omega.. ].
+      split; intro H'; [ rewrite H'; autorewrite with zsimplify_const; reflexivity | ].
+      assert (H'' : ((eval a * r'^n) * r^n) mod m = 0)
+        by (revert H'; push_Zmod; intro H'; rewrite H'; autorewrite with zsimplify_const; reflexivity).
+      rewrite <- Z.mul_assoc in H''.
+      autorewrite with pull_Zpow push_Zmod in H''.
+      rewrite (Z.mul_comm r' r), r'_correct in H''.
+      autorewrite with zsimplify_const pull_Zmod in H''; [ | lia.. ].
+      clear H'.
+      generalize dependent (eval a); clear.
+      intros z ???.
+      assert (z / m = 0) by (Z.div_mod_to_quot_rem; nia).
+      Z.div_mod_to_quot_rem; nia.
+    Qed.
+    (* For [valid] [a], [to_bytesmod a] evaluates, at weight [UniformWeight.uweight 8] over [bytes_n bitwidth 1 n] positions, to [eval a mod m], and it equals the partition of that value at the same weight and length. *)
+    Lemma to_bytesmod_correct
+      : (forall a (_ : valid a), Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a)
+                                 = eval a mod m)
+        /\ (forall a (_ : valid a), to_bytesmod a = partition (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (eval a mod m)).
+    Proof using n_nz m_small bitwidth_big.
+      clear -n_nz m_small bitwidth_big.
+      generalize (@length_small bitwidth n);
+        cbv [valid small to_bytesmod eval]; split; intros; (etransitivity; [ apply eval_to_bytesmod | ]); (* [eval_to_bytesmod] here is whatever is already in scope under that name, not the lemma of the same name stated below *)
+          fold weight in *; fold (UniformWeight.uweight 8) in *; subst r;
+          try solve [ intuition eauto with omega ].
+      all: repeat first [ rewrite UniformWeight.uweight_eq_alt by omega
+                        | omega
+                        | reflexivity
+                        | progress Z.rewrite_mod_small ].
+    Qed.
+    (* First conjunct of [to_bytesmod_correct]. *)
+    Lemma eval_to_bytesmod
+      : (forall a (_ : valid a),
+            Positional.eval (UniformWeight.uweight 8) (bytes_n bitwidth 1 n) (to_bytesmod a)
+            = eval a mod m).
+    Proof. apply to_bytesmod_correct. Qed.
+  End modops.
+End WordByWordMontgomery.
\ No newline at end of file |