From d71fe6f9e283ba7b297da63fa5240c9117acd570 Mon Sep 17 00:00:00 2001 From: jadep Date: Tue, 12 Mar 2019 14:39:46 -0400 Subject: fix up a messy proof --- src/Arithmetic.v | 53 +++++++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 24 deletions(-) (limited to 'src') diff --git a/src/Arithmetic.v b/src/Arithmetic.v index 7a9babf1a..e4df3c2f5 100644 --- a/src/Arithmetic.v +++ b/src/Arithmetic.v @@ -3663,6 +3663,15 @@ Module UniformWeight. intros. rewrite <-Znumtheory.Zmod_div_mod; auto using uwprops; [ ]. rewrite !uweight_eq_alt'. apply Divide.Z.divide_pow_le. nia. Qed. + + Lemma uweight_pull_mod lgr (Hr : 0 < lgr) x i j : + (j <= i)%nat -> + x mod (uweight lgr i) / uweight lgr j = (x / uweight lgr j) mod (uweight lgr (i - j)). + Proof. + intros. rewrite Z.mod_pull_div by auto using Z.lt_le_incl, uwprops. + rewrite <-uweight_sum_indices by lia. + repeat (f_equal; try lia). + Qed. End UniformWeight. Module WordByWordMontgomery. @@ -5095,6 +5104,10 @@ Module BarrettReduction. Section Proofs. + (* TODO: move these? *) + Local Hint Resolve Z.lt_gt : zarith. + Hint Rewrite Z.div_add' using solve [auto with zarith] : zsimplify. + Lemma shiftr'_correct m n : forall t tn, (m <= tn)%nat -> 0 <= t < w tn -> 0 <= n < width -> @@ -5103,31 +5116,27 @@ Module BarrettReduction. cbv [shiftr']. induction m; intros; [ reflexivity | ]. rewrite !partition_step, seq_snoc. autorewrite with distr_length natsimplify push_map push_nth_default. - rewrite IHm by auto with zarith. - rewrite Z.rshi_correct, UniformWeight.uweight_S by auto with zarith. + rewrite IHm, Z.rshi_correct, UniformWeight.uweight_S by auto with zarith. rewrite <-Z.mod_pull_div by auto with zarith. destruct (Nat.eq_dec (S m) tn); [subst tn | ]; rewrite !nth_default_partition by omega. { rewrite nth_default_out_of_bounds by distr_length. autorewrite with zsimplify. Z.rewrite_mod_small. rewrite Z.div_div_comm by auto with zarith; reflexivity. } - { (* oh god the worst *) - rewrite !UniformWeight.uweight_S by auto with zarith. - rewrite <-!Z.mod_pull_div by auto with zarith. - rewrite Z.mod_pull_div with (b:=2^n) by auto with zarith. - rewrite <-Z.div_div by auto with zarith. - rewrite (Z.div_div_comm _ (2^width) (w m)) by auto with zarith. - rewrite Z.mod_pull_div with (b:=2^width) by auto with zarith. - rewrite Z.mul_div_eq'; auto using Z.lt_gt with zarith. - rewrite Z.rem_mul_r with (b:=2^width) (c:=2^width) by auto with zarith. - autorewrite with zsimplify. - rewrite <-Z.rem_mul_r by auto with zarith. - rewrite !Z.mod_pull_div by auto with zarith. - rewrite <-Znumtheory.Zmod_div_mod by - (try solve [Z.zero_bounds]; - rewrite <-!Z.pow_add_r by auto with zarith; - auto using Z.mul_divide_mono_r, Divide.Z.divide_pow_le with zarith). - rewrite Z.div_div_comm by auto with zarith. - repeat (f_equal; try ring). } + { repeat match goal with + | _ => rewrite UniformWeight.uweight_pull_mod by auto with zarith + | _ => rewrite Z.mod_mod_small by auto with zarith + | _ => rewrite <-Znumtheory.Zmod_div_mod by (Z.zero_bounds; auto with zarith) + | _ => rewrite UniformWeight.uweight_eq_alt with (n:=1%nat) by auto with zarith + | |- context [(t / w (S m)) mod 2^width * 2^width] => + replace (t / w (S m)) with (t / w m / 2^width) by + (rewrite UniformWeight.uweight_S, Z.div_div by auto with zarith; f_equal; lia); + rewrite Z.mod_pull_div with (b:=2^width) by auto with zarith; + rewrite Z.mul_div_eq' by auto with zarith + | _ => progress autorewrite with natsimplify zsimplify_fast zsimplify + end. + replace (2^width*2^width) with (2^width*2^(width-n)*2^n) by (autorewrite with pull_Zpow; f_equal; lia). + rewrite <-Z.mod_pull_div, <-Znumtheory.Zmod_div_mod by (Z.zero_bounds; auto with zarith). + rewrite Z.div_div_comm by Z.zero_bounds. reflexivity. } Qed. Lemma shiftr_correct m n : forall t tn, @@ -5225,10 +5234,6 @@ Module BarrettReduction. ring_simplify. ring. Qed. - (* TODO: move these? *) - Local Hint Resolve Z.lt_gt : zarith. - Hint Rewrite Z.div_add' using solve [auto with zarith] : zsimplify. - Lemma mul_high_correct a b (Ha : a / w sz = 1) a0b1 (Ha0b1 : a0b1 = a mod w sz * (b / w sz)) : -- cgit v1.2.3