diff options
author | jadep <jade.philipoom@gmail.com> | 2016-07-06 13:48:40 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2016-07-06 13:48:40 -0400 |
commit | cc920cf2a3aa859a93e8e990a19a960f78cd3b1b (patch) | |
tree | ed93d93f82578239ef2e0a6843e52e1d6968a5ef /src/ModularArithmetic | |
parent | 260b20cab885deae59a305492567dc0f0d88b3a8 (diff) | |
parent | 0cea3e2f80408a25954f820faebf5cd79d2e13ae (diff) |
Merged changes, including new ZUtil conventions.
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r-- | src/ModularArithmetic/BarrettReduction/Z.v | 118 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemOpt.v | 8 | ||||
-rw-r--r-- | src/ModularArithmetic/ModularBaseSystemProofs.v | 28 | ||||
-rw-r--r-- | src/ModularArithmetic/Pow2BaseProofs.v | 4 | ||||
-rw-r--r-- | src/ModularArithmetic/PrimeFieldTheorems.v | 4 | ||||
-rw-r--r-- | src/ModularArithmetic/PseudoMersenneBaseParamProofs.v | 10 |
6 files changed, 145 insertions, 27 deletions
diff --git a/src/ModularArithmetic/BarrettReduction/Z.v b/src/ModularArithmetic/BarrettReduction/Z.v new file mode 100644 index 000000000..8b472d5d8 --- /dev/null +++ b/src/ModularArithmetic/BarrettReduction/Z.v @@ -0,0 +1,118 @@ +(*** Barrett Reduction *) +(** This file implements Barrett Reduction on [Z]. We follow Wikipedia. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. + +Local Open Scope Z_scope. + +Section barrett. + Context (n a : Z) + (n_reasonable : n <> 0). + (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *) + (** In modular arithmetic, Barrett reduction is a reduction + algorithm introduced in 1986 by P.D. Barrett. A naive way of + computing *) + (** [c = a mod n] *) + (** would be to use a fast division algorithm. Barrett reduction is + an algorithm designed to optimize this operation assuming [n] is + constant, and [a < n²], replacing divisions by + multiplications. *) + + (** * General idea *) + Section general_idea. + (** Let [m = 1 / n] be the inverse of [n] as a floating point + number. Then *) + (** [a mod n = a - ⌊a m⌋ n] *) + (** where [⌊ x ⌋] denotes the floor function. The result is exact, + as long as [m] is computed with sufficient accuracy. *) + + (* [/] is [Z.div], which means truncated division *) + Local Notation "⌊am⌋" := (a / n) (only parsing). + + Theorem naive_barrett_reduction_correct + : a mod n = a - ⌊am⌋ * n. + Proof. + apply Zmod_eq_full; assumption. + Qed. + End general_idea. + + (** * Barrett algorithm *) + Section barrett_algorithm. + (** Barrett algorithm is a fixed-point analog which expresses + everything in terms of integers. Let [k] be the smallest + integer such that [2ᵏ > n]. Think of [n] as representing the + fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m = + ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number + [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *) + (** N.B. We don't need [k] to be the smallest such integer. *) + Context (k : Z) + (k_good : n < 2 ^ k) + (m : Z) + (m_good : m = 4^k / n). (* [/] is [Z.div], which is truncated *) + (** Wikipedia neglects to mention non-negativity, but we need it. + It might be possible to do with a relaxed assumption, such as + the sign of [a] and the sign of [n] being the same; but I + figured it wasn't worth it. *) + Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *) + (a_nonneg : 0 <= a). + + Lemma k_nonnegative : 0 <= k. + Proof. + destruct (Z_lt_le_dec k 0); try assumption. + rewrite !Z.pow_neg_r in * by lia; lia. + Qed. + + (** Now *) + Let q := (m * a) / 4^k. + Let r := a - q * n. + (** Because of the floor function (in Coq, because [/] means + truncated division), [q] is an integer and [r ≡ a mod n]. *) + Theorem barrett_reduction_equivalent + : r mod n = a mod n. + Proof. + subst r q m. + rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption. + reflexivity. + Qed. + + Lemma qn_small + : q * n <= a. + Proof. + pose proof k_nonnegative; subst q r m. + assert (0 <= 2^(k-1)) by zero_bounds. + Z.simplify_fractions_le. + Qed. + + (** Also, if [a < n²] then [r < 2n]. *) + (** N.B. It turns out that it is sufficient to assume [a < 4ᵏ]. *) + Context (a_small : a < 4^k). + Lemma r_small : r < 2 * n. + Proof. + Hint Rewrite (Z.div_small a (4^k)) (Z.mod_small a (4^k)) using lia : zsimplify. + Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div. + cut (r + 1 <= 2 * n); [ lia | ]. + pose proof k_nonnegative; subst r q m. + assert (0 <= 2^(k-1)) by zero_bounds. + assert (4^k <> 0) by auto with zarith lia. + assert (a mod n < n) by auto with zarith lia. + pose proof (Z.mod_pos_bound (a * 4^k / n) (4^k)). + transitivity (a - (a * 4 ^ k / n - a) / 4 ^ k * n + 1). + { rewrite <- (Z.mul_comm a); auto 6 with zarith lia. } + rewrite (Z_div_mod_eq (_ * 4^k / n) (4^k)) by lia. + autorewrite with push_Zmul push_Zopp zsimplify zstrip_div. + break_match; auto with lia. + Qed. + + (** In that case, we have *) + Theorem barrett_reduction_small + : a mod n = if r <? n + then r + else r - n. + Proof. + pose proof r_small. pose proof qn_small. + destruct (r <? n) eqn:rlt; Z.ltb_to_lt. + { symmetry; apply (Zmod_unique a n q); subst r; lia. } + { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. } + Qed. + End barrett_algorithm. +End barrett. diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index bb9b1674e..ce11b157b 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -22,7 +22,7 @@ Definition Z_div_opt := Eval compute in Z.div. Definition Z_pow_opt := Eval compute in Z.pow. Definition Z_opp_opt := Eval compute in Z.opp. Definition Z_shiftl_opt := Eval compute in Z.shiftl. -Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by. +Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by. Definition nth_default_opt {A} := Eval compute in @nth_default A. Definition set_nth_opt {A} := Eval compute in @set_nth A. @@ -502,11 +502,11 @@ Section Multiplication. cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce]. rewrite <- mul'_opt_correct. change @base with base_opt. - rewrite map_shiftl by apply k_nonneg. + rewrite Z.map_shiftl by apply k_nonneg. rewrite c_subst. rewrite k_subst. change @map with @map_opt. - change @Z_shiftl_by with @Z_shiftl_by_opt. + change @Z.shiftl_by with @Z_shiftl_by_opt. reflexivity. Defined. @@ -671,4 +671,4 @@ Section Canonicalization. auto using freeze_opt_preserves_rep. Qed. -End Canonicalization.
\ No newline at end of file +End Canonicalization. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 75806f570..29612d900 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -170,7 +170,7 @@ Section PseudoMersenneProofs. rewrite Z.sub_sub_distr, Z.sub_diag. simpl. rewrite Z.mul_comm. - rewrite mod_mult_plus; auto using modulus_nonzero. + rewrite Z.mod_add_l; auto using modulus_nonzero. rewrite <- Zplus_mod; auto. Qed. @@ -390,8 +390,8 @@ Section CarryProofs. rewrite nth_default_base_succ by omega. rewrite Z.mul_assoc. rewrite (Z.mul_comm _ (2 ^ log_cap i)). - rewrite mul_div_eq; try ring. - apply gt_lt_symmetry. + rewrite Z.mul_div_eq; try ring. + apply Z.gt_lt_iff. apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg. Qed. @@ -423,7 +423,7 @@ Section CarryProofs. rewrite <- Z.add_opp_l, <- Z.opp_sub_distr. unfold pow2_mod. rewrite Z.land_ones by apply log_cap_nonneg. - rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). + rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. rewrite Zopp_mult_distr_r. rewrite Z.mul_comm. @@ -570,7 +570,7 @@ Section CanonicalizationProofs. Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i. Proof. - unfold max_bound, log_cap; intros; apply Z_ones_pos_pos. + unfold max_bound, log_cap; intros; apply Z.ones_pos_pos. apply limb_widths_pos. rewrite nth_default_eq. apply nth_In. @@ -580,7 +580,7 @@ Section CanonicalizationProofs. Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. - unfold max_bound; intros; auto using Z_ones_nonneg. + unfold max_bound; intros; auto using Z.ones_nonneg. Qed. Local Hint Resolve max_bound_nonneg. @@ -939,7 +939,7 @@ Section CanonicalizationProofs. apply Z.add_le_mono. + apply carry_bounds_0_upper; auto; omega. + apply Z.mul_le_mono_pos_l; auto using c_pos. - apply Z_shiftr_ones; auto; + apply Z.shiftr_ones; auto; [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. split. - apply carry_bounds_lower; auto; omega. @@ -978,7 +978,7 @@ Section CanonicalizationProofs. + rewrite <-max_bound_log_cap, <-Z.add_1_l. apply Z.add_le_mono. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. - apply Z_div_floor; auto. + apply Z.div_floor; auto. destruct i. * simpl. eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. @@ -1061,7 +1061,7 @@ Section CanonicalizationProofs. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. - - apply Z_div_floor; auto. + - apply Z.div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. replace (Z.succ 1) with (2 ^ 1) by ring. rewrite <-max_bound_log_cap. @@ -1267,7 +1267,7 @@ Section CanonicalizationProofs. Lemma max_ones_nonneg : 0 <= max_ones. Proof. unfold max_ones. - apply Z_ones_nonneg. + apply Z.ones_nonneg. pose proof limb_widths_nonneg. induction limb_widths. cbv; congruence. @@ -1282,19 +1282,19 @@ Section CanonicalizationProofs. unfold max_ones. intros ? ? x_range. rewrite Z.land_comm. - rewrite Z.land_ones by apply Z_le_fold_right_max_initial. + rewrite Z.land_ones by apply Z.le_fold_right_max_initial. apply Z.mod_small. split; try omega. eapply Z.lt_le_trans; try eapply x_range. apply Z.pow_le_mono_r; try omega. rewrite log_cap_eq. destruct (lt_dec i (length limb_widths)). - + apply Z_le_fold_right_max. + + apply Z.le_fold_right_max. - apply limb_widths_nonneg. - rewrite nth_default_eq. auto using nth_In. + rewrite nth_default_out_of_bounds by omega. - apply Z_le_fold_right_max_initial. + apply Z.le_fold_right_max_initial. Qed. Lemma full_isFull'_true : forall j us, (length us = length base) -> @@ -1802,7 +1802,7 @@ Section CanonicalizationProofs. + match goal with |- (?a ?= ?b) = (?c ?= ?d) => rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end. apply CompOpp_inj; rewrite !CompOpp_involutive. - apply gt_lt_symmetry in Hgt. + apply Z.gt_lt_iff in Hgt. etransitivity; try apply Z_compare_decode_step_lt; auto; omega. Qed. diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v index 23393d7ef..1504ca0df 100644 --- a/src/ModularArithmetic/Pow2BaseProofs.v +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -97,7 +97,7 @@ Section Pow2BaseProofs. Proof. intros. repeat rewrite nth_default_base by omega. - apply mod_same_pow. + apply Z.mod_same_pow. split; [apply sum_firstn_limb_widths_nonneg | ]. destruct (NPeano.Nat.eq_dec i 0); subst. + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq. @@ -199,7 +199,7 @@ Section BitwiseDecodeEncode. intros. simpl; f_equal. match goal with H : bounded _ _ |- _ => - rewrite Z_lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end. + rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end. rewrite Z.shiftl_mul_pow2 by auto. ring. Qed. diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v index 2021e8514..a2f606f30 100644 --- a/src/ModularArithmetic/PrimeFieldTheorems.v +++ b/src/ModularArithmetic/PrimeFieldTheorems.v @@ -460,8 +460,8 @@ Section SquareRootsPrime5Mod8. apply Z2N.inj_iff; try zero_bounds. rewrite <- Z.mul_cancel_l with (p := 2) by omega. ring_simplify. - rewrite mul_div_eq by omega. - rewrite mul_div_eq by omega. + rewrite Z.mul_div_eq by omega. + rewrite Z.mul_div_eq by omega. rewrite (Zmod_div_mod 2 8 q) by (try omega; apply Zmod_divide; omega || auto). rewrite q_5mod8. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v index 50c1f3ea6..c07da850f 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -42,7 +42,7 @@ Section PseudoMersenneBaseParamProofs. rewrite (Z.mul_comm r). subst r. assert (i + j - length base < length base)%nat by omega. - rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos; + rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos; [ | subst b; unfold base; rewrite nth_default_base; try assumption ]; zero_bounds; auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg). rewrite (Zminus_0_l_reverse (b i * b j)) at 1. @@ -51,7 +51,7 @@ Section PseudoMersenneBaseParamProofs. unfold base; repeat rewrite nth_default_base by auto. do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg. symmetry. - apply mod_same_pow. + apply Z.mod_same_pow. split. + apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg. + rewrite base_length, base_from_limb_widths_length in * by auto. @@ -65,7 +65,7 @@ Section PseudoMersenneBaseParamProofs. destruct In_b_base as [i nth_err_b]. apply nth_error_subst in nth_err_b; [ | auto ]. rewrite nth_err_b. - apply gt_lt_symmetry. + apply Z.gt_lt_iff. apply Z.pow_pos_nonneg; omega || auto using sum_firstn_limb_widths_nonneg. Qed. @@ -84,10 +84,10 @@ Section PseudoMersenneBaseParamProofs. unfold base in *. repeat rewrite nth_default_base by (omega || auto). rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))). - rewrite mul_div_eq by (apply gt_lt_symmetry; zero_bounds; + rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds; auto using sum_firstn_limb_widths_nonneg). rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg. - rewrite mod_same_pow; try ring. + rewrite Z.mod_same_pow; try ring. split; [ auto using sum_firstn_limb_widths_nonneg | ]. apply limb_widths_good. rewrite <- base_length; assumption. |