From 943d76f97fcbb02ca5d417266ad9dcd7a9561a73 Mon Sep 17 00:00:00 2001 From: jadep Date: Tue, 18 Sep 2018 12:00:37 -0400 Subject: prove [eval_conditional_sub] --- src/Experiments/NewPipeline/Arithmetic.v | 175 ++++++++++++++++++++++--------- src/Util/ListUtil.v | 1 + 2 files changed, 124 insertions(+), 52 deletions(-) diff --git a/src/Experiments/NewPipeline/Arithmetic.v b/src/Experiments/NewPipeline/Arithmetic.v index 7ad7bb1a3..06461a63c 100644 --- a/src/Experiments/NewPipeline/Arithmetic.v +++ b/src/Experiments/NewPipeline/Arithmetic.v @@ -689,40 +689,42 @@ Module Positional. Definition drop_high_to_length (n : nat) (p:list Z) : list Z := firstn n p. - (* - Lemma eval_drop_high_to_length n m p : - (forall i, weight (S i) mod weight i = 0) -> length p = m -> (n <= m)%nat -> + + (* Helper for eval_drop_high_to_length *) + Lemma eval_drop_high_to_length' q + (weight_multiples: forall i, weight (S i) mod weight i = 0) + (weight_positive: forall i, 0 < weight i) : + forall n p m, + length p = n -> length q = m -> + eval (n+m) (p ++ q) mod weight n + = eval n p mod weight n. + Proof using Type. + induction q using rev_ind; intros; distr_length. + { subst m; rewrite app_nil_r. + autorewrite with natsimplify; reflexivity. } + { rewrite app_assoc. + rewrite eval_snoc with (n:=(n+pred m)%nat) by distr_length. + rewrite weight_div_mod with (j:=(n+pred m)%nat) (i:=n) by auto with omega. + push_Zmod. rewrite IHq by lia. + autorewrite with zsimplify_fast. + reflexivity. } + Qed. + Lemma eval_drop_high_to_length n m p + (weight_multiples: forall i, weight (S i) mod weight i = 0) + (weight_positive: forall i, 0 < weight i) : + length p = m -> (n <= m)%nat -> eval n (drop_high_to_length n p) mod weight n = eval m p mod weight n. Proof using Type. - cbv [eval drop_high_to_length to_associational]; intros. - replace m with (n + (m - n))%nat in * by (f_equal; omega). - generalize dependent (m - n)%nat; clear m; intro m; intros H' H''. - rewrite seq_add, map_app, <- (firstn_skipn n p), combine_app_samelength, firstn_skipn, Associational.eval_app; - push; try omega **. - rewrite <- (Z.add_0_r (Associational.eval _)) at 1. - apply Z.add_mod_Proper; [ reflexivity | cbv [Z.equiv_modulo] ]. - generalize (skipn_length n p); rewrite H', minus_plus. - generalize (skipn n p); clear dependent p; clear H''; intros p Hp. - rewrite Zmod_0_l. - subst. - cbv [Associational.eval]. - revert n; induction p as [|p ps IHps]; intro; [ reflexivity | ]. - cbn in *. - push_Zmod; pull_Zmod; autorewrite with zsimplify_const. - rewrite <- IHps. - { cbn; reflexivity - Search (0 mod _). - rewrite Z.mod_0_l - Search (?x + ?y - ?x)%nat. - Search Z.equiv_modulo Proper. - pose proof H as H''. - rewrite <- (firstn_skipn n p) in H''. - distr_length. - - rewrite Nat.min_l in H'' by omega. - Qed. - Hint Rewrite eval_drop_high_to_length : push_eval.*) + cbv [drop_high_to_length]; intros. + rewrite <-(firstn_skipn n p). + rewrite firstn_app_inleft by (distr_length; lia). + rewrite firstn_firstn. autorewrite with natsimplify. + replace m with (n + (m-n))%nat by omega. + rewrite eval_drop_high_to_length' by (auto; distr_length; lia). + reflexivity. + Qed. + Hint Rewrite eval_drop_high_to_length : push_eval. Lemma length_drop_high_to_length n p : length (drop_high_to_length n p) = Nat.min n (length p). Proof using Type. clear; cbv [drop_high_to_length]; intros; distr_length. Qed. @@ -972,20 +974,30 @@ Module Positional. Lemma length_zselect mask cond p : length (zselect mask cond p) = length p. Proof using Type. clear dependent weight. cbv [zselect Let_In]; break_match; intros; distr_length. Qed. + + (* We need an explicit equality proof here, because sometimes it + matters that we retain the same bounds when selecting. *) + Lemma select_eq cond n : forall p q, + length p = n -> length q = n -> + select cond p q = if dec (cond = 0) then p else q. + Proof using weight. + cbv [select]; induction n; intros; + destruct p; distr_length; + destruct q; distr_length; + repeat match goal with + | _ => progress autorewrite with push_combine push_map + | _ => rewrite IHn by distr_length + | _ => rewrite Z.zselect_correct + | _ => break_match; reflexivity + end. + Qed. Lemma eval_select n cond p q : length p = n -> length q = n -> eval n (select cond p q) = if dec (cond = 0) then eval n p else eval n q. - Proof using Type. - cbv [select Let_In]; intro; subst. - rewrite <- (List.rev_involutive q), <- (List.rev_involutive p). - generalize (rev p) (rev q); clear p q; intros p q; revert q. - induction p as [|p ps IHps], q as [|q qs]; cbn [length map combine rev]; distr_length; rewrite ?Nat.add_1_r; try omega. - { break_match; reflexivity. } - { intro; rewrite !combine_snoc, !map_app by (distr_length; omega). - cbn [map]. - rewrite !eval_snoc with (n:=length ps), IHps by (distr_length; omega* ). - rewrite !Z.zselect_correct; break_match; reflexivity. } + Proof using weight. + intros; erewrite select_eq by eauto. + break_match; reflexivity. Qed. Lemma length_select_min cond p q : length (select cond p q) = Nat.min (length p) (length q). @@ -999,7 +1011,7 @@ Module Positional. End Positional. (* Hint Rewrite disappears after the end of a section *) Hint Rewrite length_zeros length_add_to_nth length_from_associational @length_add @length_carry_reduce @length_chained_carries @length_encode @length_sub @length_opp @length_select @length_zselect @length_select_min @length_extend_to_length @length_drop_high_to_length : distr_length. -Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length (*@eval_drop_high_to_length*) using (solve [auto; distr_length]): push_eval. +Hint Rewrite @eval_zeros @eval_nil @eval_snoc_S @eval_select @eval_zselect @eval_extend_to_length @eval_drop_high_to_length using (solve [auto; distr_length]): push_eval. Section Positional_nonuniform. Context (weight weight' : nat -> Z). @@ -1402,7 +1414,16 @@ Module Partition. induction n; cbn [recursive_partition]; [reflexivity | ]. intros; distr_length; auto. Qed. - Hint Rewrite length_recursive_partition : distr_length. + + Lemma drop_high_to_length_partition n m x : + (n <= m)%nat -> + Positional.drop_high_to_length n (partition m x) = partition n x. + Proof. + cbv [Positional.drop_high_to_length partition]; intros. + autorewrite with push_firstn. + rewrite Nat.min_l by omega. + reflexivity. + Qed. End Partition. Hint Rewrite length_partition length_recursive_partition : distr_length. Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval. @@ -2233,6 +2254,25 @@ Module Rows. snd (sub n p q) = (Positional.eval weight n p - Positional.eval weight n q) / weight n. Proof using wprops. solver. Qed. + Lemma conditional_sub_partitions n p q + (Hp : p = partition weight n (Positional.eval weight n p)) : + length q = n -> + 0 <= Positional.eval weight n q < weight n -> + conditional_sub n p q = partition weight n (if Positional.eval weight n q <=? Positional.eval weight n p then Positional.eval weight n p - Positional.eval weight n q else Positional.eval weight n p). + Proof. + cbv [conditional_sub]; intros. + rewrite (surjective_pairing (sub _ _ _)). + assert (length p = n) by (rewrite Hp; distr_length). + assert (0 <= Positional.eval weight n p < weight n) by + (rewrite Hp; autorewrite with push_eval; auto using Z.mod_pos_bound). + rewrite sub_partitions, sub_div; distr_length. + erewrite Positional.select_eq by (distr_length; eauto). + rewrite Z.div_sub_small, Z.ltb_antisym by omega. + destruct (Positional.eval weight n q <=? Positional.eval weight n p); + cbn [negb]; autorewrite with zsimplify_fast; + break_match; congruence. + Qed. + Lemma mul_partitions base n m p q : base <> 0 -> m <> 0%nat -> length p = n -> length q = n -> fst (mul base n m p q) = partition weight m (Positional.eval weight n p * Positional.eval weight n q). @@ -2881,6 +2921,13 @@ Module UniformWeight. rewrite !Z.pow_succ_r by auto using Nat2Z.is_nonneg. ring. Qed. + Lemma uweight_S lgr (Hr : 0 <= lgr) n : uweight lgr (S n) = 2 ^ lgr * uweight lgr n. + Proof. + rewrite !uweight_eq_alt by auto. + autorewrite with push_Zof_nat. + rewrite Z.pow_succ_r by auto using Nat2Z.is_nonneg. + reflexivity. + Qed. Lemma uweight_1 lgr : uweight lgr 1 = 2^lgr. Proof using Type. @@ -3186,11 +3233,9 @@ Module WordByWordMontgomery. autounfold with loc; intro n; destruct (zerop n). { cbn; intros; subst; cbn; rewrite Z.add_with_get_carry_full_mod; cbn; omega. } intros; repeat t_step. - repeat first [ reflexivity - | rewrite UniformWeight.uweight_eq_alt by omega - | progress autorewrite with push_Zof_nat - | rewrite Z.pow_succ_r by lia - | progress Z.rewrite_mod_small ]. + rewrite UniformWeight.uweight_S by omega. + rewrite UniformWeight.uweight_eq_alt by omega. + Z.rewrite_mod_small; reflexivity. Qed. Local Lemma small_scmul : forall n a v, small v -> 0 <= a < r -> 0 <= eval v < r^Z.of_nat n -> small (@scmul n a v). Proof using lgr_big. @@ -3232,9 +3277,7 @@ Module WordByWordMontgomery. apply Z.lt_le_trans with (m:=2 * weight n). { rewrite <-Z.add_diag. auto using Z.add_lt_mono with zarith. } - { rewrite !UniformWeight.uweight_eq_alt by omega. - autorewrite with push_Zof_nat push_Zpow. - rewrite Z.pow_succ_r by auto. + { rewrite UniformWeight.uweight_S by omega. (* In versions newer than 8.7, auto with zarith is sufficient to solve this from here. *) apply Z.mul_le_mono_nonneg_r. @@ -3273,7 +3316,35 @@ Module WordByWordMontgomery. autorewrite with push_eval. auto using partition_eq_mod with zarith. Qed. - Local Axiom eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0. + Local Lemma eval_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> eval (conditional_sub v N) = eval v + if eval N <=? eval v then -eval N else 0. + Proof using small_N lgr_big R_numlimbs_nz N_nz N_lt_R. + clear - small_N ri lgr_big R_numlimbs_nz N_nz N_lt_R. + intros; autounfold with loc; cbv [conditional_sub]. + repeat match goal with H : small _ |- _ => + rewrite H; clear H end. + autorewrite with push_eval. + assert (eval N mod weight R_numlimbs < weight (S R_numlimbs)) + by (rewrite UniformWeight.uweight_S, !UniformWeight.uweight_eq_alt by omega; subst r R; rewrite Z.mod_small by omega; assert (0 < 2 ^ lgr) by auto with zarith; nia). + rewrite Rows.conditional_sub_partitions + by (repeat (autorewrite with distr_length push_eval; auto using partition_eq_mod with zarith)). + rewrite drop_high_to_length_partition by omega. + (* TODO : do we need eval_drop_high_to_length? *) + autorewrite with push_eval. + assert (eval N < weight R_numlimbs) by + (subst r R; rewrite UniformWeight.uweight_eq_alt; omega). + assert (weight R_numlimbs = R) by (rewrite UniformWeight.uweight_eq_alt by omega; subst R; reflexivity). + assert (eval v < weight (S R_numlimbs)). + { apply Z.lt_le_trans with (m:=2 * R); [ lia | ]. + subst r R; rewrite UniformWeight.uweight_S, UniformWeight.uweight_eq_alt by omega. + apply Z.mul_le_mono_nonneg_r; [auto with zarith | ]. + transitivity (2 ^ 1); [ reflexivity | ]; + apply Z.pow_le_mono_r; omega. } + Z.rewrite_mod_small. + break_match; autorewrite with zsimplify_fast; Z.ltb_to_lt. + { rewrite Z.add_opp_r. fold (eval N). + auto using Z.mod_small with lia. } + { auto using Z.mod_small with lia. } + Qed. Local Axiom small_conditional_sub : forall v, small v -> 0 <= eval v < eval N + R -> small (conditional_sub v N). Local Axiom eval_sub_then_maybe_add : forall a b, small a -> small b -> 0 <= eval a < eval N -> 0 <= eval b < eval N -> eval (sub_then_maybe_add a b) = eval a - eval b + if eval a - eval b