aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic
diff options
context:
space:
mode:
authorGravatar jadep <jade.philipoom@gmail.com>2016-07-06 13:48:40 -0400
committerGravatar jadep <jade.philipoom@gmail.com>2016-07-06 13:48:40 -0400
commitcc920cf2a3aa859a93e8e990a19a960f78cd3b1b (patch)
treeed93d93f82578239ef2e0a6843e52e1d6968a5ef /src/ModularArithmetic
parent260b20cab885deae59a305492567dc0f0d88b3a8 (diff)
parent0cea3e2f80408a25954f820faebf5cd79d2e13ae (diff)
Merged changes, including new ZUtil conventions.
Diffstat (limited to 'src/ModularArithmetic')
-rw-r--r--src/ModularArithmetic/BarrettReduction/Z.v118
-rw-r--r--src/ModularArithmetic/ModularBaseSystemOpt.v8
-rw-r--r--src/ModularArithmetic/ModularBaseSystemProofs.v28
-rw-r--r--src/ModularArithmetic/Pow2BaseProofs.v4
-rw-r--r--src/ModularArithmetic/PrimeFieldTheorems.v4
-rw-r--r--src/ModularArithmetic/PseudoMersenneBaseParamProofs.v10
6 files changed, 145 insertions, 27 deletions
diff --git a/src/ModularArithmetic/BarrettReduction/Z.v b/src/ModularArithmetic/BarrettReduction/Z.v
new file mode 100644
index 000000000..8b472d5d8
--- /dev/null
+++ b/src/ModularArithmetic/BarrettReduction/Z.v
@@ -0,0 +1,118 @@
+(*** Barrett Reduction *)
+(** This file implements Barrett Reduction on [Z]. We follow Wikipedia. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
+Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.
+
+Local Open Scope Z_scope.
+
+Section barrett.
+ Context (n a : Z)
+ (n_reasonable : n <> 0).
+ (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *)
+ (** In modular arithmetic, Barrett reduction is a reduction
+ algorithm introduced in 1986 by P.D. Barrett. A naive way of
+ computing *)
+ (** [c = a mod n] *)
+ (** would be to use a fast division algorithm. Barrett reduction is
+ an algorithm designed to optimize this operation assuming [n] is
+ constant, and [a < n²], replacing divisions by
+ multiplications. *)
+
+ (** * General idea *)
+ Section general_idea.
+ (** Let [m = 1 / n] be the inverse of [n] as a floating point
+ number. Then *)
+ (** [a mod n = a - ⌊a m⌋ n] *)
+ (** where [⌊ x ⌋] denotes the floor function. The result is exact,
+ as long as [m] is computed with sufficient accuracy. *)
+
+ (* [/] is [Z.div], which means truncated division *)
+ Local Notation "⌊am⌋" := (a / n) (only parsing).
+
+ Theorem naive_barrett_reduction_correct
+ : a mod n = a - ⌊am⌋ * n.
+ Proof.
+ apply Zmod_eq_full; assumption.
+ Qed.
+ End general_idea.
+
+ (** * Barrett algorithm *)
+ Section barrett_algorithm.
+ (** Barrett algorithm is a fixed-point analog which expresses
+ everything in terms of integers. Let [k] be the smallest
+ integer such that [2ᵏ > n]. Think of [n] as representing the
+ fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m =
+ ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number
+ [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *)
+ (** N.B. We don't need [k] to be the smallest such integer. *)
+ Context (k : Z)
+ (k_good : n < 2 ^ k)
+ (m : Z)
+ (m_good : m = 4^k / n). (* [/] is [Z.div], which is truncated *)
+ (** Wikipedia neglects to mention non-negativity, but we need it.
+ It might be possible to do with a relaxed assumption, such as
+ the sign of [a] and the sign of [n] being the same; but I
+ figured it wasn't worth it. *)
+ Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *)
+ (a_nonneg : 0 <= a).
+
+ Lemma k_nonnegative : 0 <= k.
+ Proof.
+ destruct (Z_lt_le_dec k 0); try assumption.
+ rewrite !Z.pow_neg_r in * by lia; lia.
+ Qed.
+
+ (** Now *)
+ Let q := (m * a) / 4^k.
+ Let r := a - q * n.
+ (** Because of the floor function (in Coq, because [/] means
+ truncated division), [q] is an integer and [r ≡ a mod n]. *)
+ Theorem barrett_reduction_equivalent
+ : r mod n = a mod n.
+ Proof.
+ subst r q m.
+ rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption.
+ reflexivity.
+ Qed.
+
+ Lemma qn_small
+ : q * n <= a.
+ Proof.
+ pose proof k_nonnegative; subst q r m.
+ assert (0 <= 2^(k-1)) by zero_bounds.
+ Z.simplify_fractions_le.
+ Qed.
+
+ (** Also, if [a < n²] then [r < 2n]. *)
+ (** N.B. It turns out that it is sufficient to assume [a < 4ᵏ]. *)
+ Context (a_small : a < 4^k).
+ Lemma r_small : r < 2 * n.
+ Proof.
+ Hint Rewrite (Z.div_small a (4^k)) (Z.mod_small a (4^k)) using lia : zsimplify.
+ Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div.
+ cut (r + 1 <= 2 * n); [ lia | ].
+ pose proof k_nonnegative; subst r q m.
+ assert (0 <= 2^(k-1)) by zero_bounds.
+ assert (4^k <> 0) by auto with zarith lia.
+ assert (a mod n < n) by auto with zarith lia.
+ pose proof (Z.mod_pos_bound (a * 4^k / n) (4^k)).
+ transitivity (a - (a * 4 ^ k / n - a) / 4 ^ k * n + 1).
+ { rewrite <- (Z.mul_comm a); auto 6 with zarith lia. }
+ rewrite (Z_div_mod_eq (_ * 4^k / n) (4^k)) by lia.
+ autorewrite with push_Zmul push_Zopp zsimplify zstrip_div.
+ break_match; auto with lia.
+ Qed.
+
+ (** In that case, we have *)
+ Theorem barrett_reduction_small
+ : a mod n = if r <? n
+ then r
+ else r - n.
+ Proof.
+ pose proof r_small. pose proof qn_small.
+ destruct (r <? n) eqn:rlt; Z.ltb_to_lt.
+ { symmetry; apply (Zmod_unique a n q); subst r; lia. }
+ { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. }
+ Qed.
+ End barrett_algorithm.
+End barrett.
diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v
index bb9b1674e..ce11b157b 100644
--- a/src/ModularArithmetic/ModularBaseSystemOpt.v
+++ b/src/ModularArithmetic/ModularBaseSystemOpt.v
@@ -22,7 +22,7 @@ Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
-Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by.
+Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.
Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
@@ -502,11 +502,11 @@ Section Multiplication.
cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce].
rewrite <- mul'_opt_correct.
change @base with base_opt.
- rewrite map_shiftl by apply k_nonneg.
+ rewrite Z.map_shiftl by apply k_nonneg.
rewrite c_subst.
rewrite k_subst.
change @map with @map_opt.
- change @Z_shiftl_by with @Z_shiftl_by_opt.
+ change @Z.shiftl_by with @Z_shiftl_by_opt.
reflexivity.
Defined.
@@ -671,4 +671,4 @@ Section Canonicalization.
auto using freeze_opt_preserves_rep.
Qed.
-End Canonicalization. \ No newline at end of file
+End Canonicalization.
diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v
index 75806f570..29612d900 100644
--- a/src/ModularArithmetic/ModularBaseSystemProofs.v
+++ b/src/ModularArithmetic/ModularBaseSystemProofs.v
@@ -170,7 +170,7 @@ Section PseudoMersenneProofs.
rewrite Z.sub_sub_distr, Z.sub_diag.
simpl.
rewrite Z.mul_comm.
- rewrite mod_mult_plus; auto using modulus_nonzero.
+ rewrite Z.mod_add_l; auto using modulus_nonzero.
rewrite <- Zplus_mod; auto.
Qed.
@@ -390,8 +390,8 @@ Section CarryProofs.
rewrite nth_default_base_succ by omega.
rewrite Z.mul_assoc.
rewrite (Z.mul_comm _ (2 ^ log_cap i)).
- rewrite mul_div_eq; try ring.
- apply gt_lt_symmetry.
+ rewrite Z.mul_div_eq; try ring.
+ apply Z.gt_lt_iff.
apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg.
Qed.
@@ -423,7 +423,7 @@ Section CarryProofs.
rewrite <- Z.add_opp_l, <- Z.opp_sub_distr.
unfold pow2_mod.
rewrite Z.land_ones by apply log_cap_nonneg.
- rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg).
+ rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg).
rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
rewrite Zopp_mult_distr_r.
rewrite Z.mul_comm.
@@ -570,7 +570,7 @@ Section CanonicalizationProofs.
Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i.
Proof.
- unfold max_bound, log_cap; intros; apply Z_ones_pos_pos.
+ unfold max_bound, log_cap; intros; apply Z.ones_pos_pos.
apply limb_widths_pos.
rewrite nth_default_eq.
apply nth_In.
@@ -580,7 +580,7 @@ Section CanonicalizationProofs.
Lemma max_bound_nonneg : forall i, 0 <= max_bound i.
Proof.
- unfold max_bound; intros; auto using Z_ones_nonneg.
+ unfold max_bound; intros; auto using Z.ones_nonneg.
Qed.
Local Hint Resolve max_bound_nonneg.
@@ -939,7 +939,7 @@ Section CanonicalizationProofs.
apply Z.add_le_mono.
+ apply carry_bounds_0_upper; auto; omega.
+ apply Z.mul_le_mono_pos_l; auto using c_pos.
- apply Z_shiftr_ones; auto;
+ apply Z.shiftr_ones; auto;
[ | pose proof (B_compat_log_cap (pred (length base))); omega ].
split.
- apply carry_bounds_lower; auto; omega.
@@ -978,7 +978,7 @@ Section CanonicalizationProofs.
+ rewrite <-max_bound_log_cap, <-Z.add_1_l.
apply Z.add_le_mono.
- rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
- apply Z_div_floor; auto.
+ apply Z.div_floor; auto.
destruct i.
* simpl.
eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ].
@@ -1061,7 +1061,7 @@ Section CanonicalizationProofs.
+ rewrite <-max_bound_log_cap, <-Z.add_1_l.
rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg.
apply Z.add_le_mono.
- - apply Z_div_floor; auto.
+ - apply Z.div_floor; auto.
eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ].
replace (Z.succ 1) with (2 ^ 1) by ring.
rewrite <-max_bound_log_cap.
@@ -1267,7 +1267,7 @@ Section CanonicalizationProofs.
Lemma max_ones_nonneg : 0 <= max_ones.
Proof.
unfold max_ones.
- apply Z_ones_nonneg.
+ apply Z.ones_nonneg.
pose proof limb_widths_nonneg.
induction limb_widths.
cbv; congruence.
@@ -1282,19 +1282,19 @@ Section CanonicalizationProofs.
unfold max_ones.
intros ? ? x_range.
rewrite Z.land_comm.
- rewrite Z.land_ones by apply Z_le_fold_right_max_initial.
+ rewrite Z.land_ones by apply Z.le_fold_right_max_initial.
apply Z.mod_small.
split; try omega.
eapply Z.lt_le_trans; try eapply x_range.
apply Z.pow_le_mono_r; try omega.
rewrite log_cap_eq.
destruct (lt_dec i (length limb_widths)).
- + apply Z_le_fold_right_max.
+ + apply Z.le_fold_right_max.
- apply limb_widths_nonneg.
- rewrite nth_default_eq.
auto using nth_In.
+ rewrite nth_default_out_of_bounds by omega.
- apply Z_le_fold_right_max_initial.
+ apply Z.le_fold_right_max_initial.
Qed.
Lemma full_isFull'_true : forall j us, (length us = length base) ->
@@ -1802,7 +1802,7 @@ Section CanonicalizationProofs.
+ match goal with |- (?a ?= ?b) = (?c ?= ?d) =>
rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end.
apply CompOpp_inj; rewrite !CompOpp_involutive.
- apply gt_lt_symmetry in Hgt.
+ apply Z.gt_lt_iff in Hgt.
etransitivity; try apply Z_compare_decode_step_lt; auto; omega.
Qed.
diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v
index 23393d7ef..1504ca0df 100644
--- a/src/ModularArithmetic/Pow2BaseProofs.v
+++ b/src/ModularArithmetic/Pow2BaseProofs.v
@@ -97,7 +97,7 @@ Section Pow2BaseProofs.
Proof.
intros.
repeat rewrite nth_default_base by omega.
- apply mod_same_pow.
+ apply Z.mod_same_pow.
split; [apply sum_firstn_limb_widths_nonneg | ].
destruct (NPeano.Nat.eq_dec i 0); subst.
+ case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq.
@@ -199,7 +199,7 @@ Section BitwiseDecodeEncode.
intros.
simpl; f_equal.
match goal with H : bounded _ _ |- _ =>
- rewrite Z_lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end.
+ rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end.
rewrite Z.shiftl_mul_pow2 by auto.
ring.
Qed.
diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v
index 2021e8514..a2f606f30 100644
--- a/src/ModularArithmetic/PrimeFieldTheorems.v
+++ b/src/ModularArithmetic/PrimeFieldTheorems.v
@@ -460,8 +460,8 @@ Section SquareRootsPrime5Mod8.
apply Z2N.inj_iff; try zero_bounds.
rewrite <- Z.mul_cancel_l with (p := 2) by omega.
ring_simplify.
- rewrite mul_div_eq by omega.
- rewrite mul_div_eq by omega.
+ rewrite Z.mul_div_eq by omega.
+ rewrite Z.mul_div_eq by omega.
rewrite (Zmod_div_mod 2 8 q) by
(try omega; apply Zmod_divide; omega || auto).
rewrite q_5mod8.
diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
index 50c1f3ea6..c07da850f 100644
--- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
+++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v
@@ -42,7 +42,7 @@ Section PseudoMersenneBaseParamProofs.
rewrite (Z.mul_comm r).
subst r.
assert (i + j - length base < length base)%nat by omega.
- rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos;
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos;
[ | subst b; unfold base; rewrite nth_default_base; try assumption ];
zero_bounds; auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg).
rewrite (Zminus_0_l_reverse (b i * b j)) at 1.
@@ -51,7 +51,7 @@ Section PseudoMersenneBaseParamProofs.
unfold base; repeat rewrite nth_default_base by auto.
do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
symmetry.
- apply mod_same_pow.
+ apply Z.mod_same_pow.
split.
+ apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg.
+ rewrite base_length, base_from_limb_widths_length in * by auto.
@@ -65,7 +65,7 @@ Section PseudoMersenneBaseParamProofs.
destruct In_b_base as [i nth_err_b].
apply nth_error_subst in nth_err_b; [ | auto ].
rewrite nth_err_b.
- apply gt_lt_symmetry.
+ apply Z.gt_lt_iff.
apply Z.pow_pos_nonneg; omega || auto using sum_firstn_limb_widths_nonneg.
Qed.
@@ -84,10 +84,10 @@ Section PseudoMersenneBaseParamProofs.
unfold base in *.
repeat rewrite nth_default_base by (omega || auto).
rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))).
- rewrite mul_div_eq by (apply gt_lt_symmetry; zero_bounds;
+ rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds;
auto using sum_firstn_limb_widths_nonneg).
rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg.
- rewrite mod_same_pow; try ring.
+ rewrite Z.mod_same_pow; try ring.
split; [ auto using sum_firstn_limb_widths_nonneg | ].
apply limb_widths_good.
rewrite <- base_length; assumption.