diff options
author | Andres Erbsen <andreser@mit.edu> | 2017-04-06 22:53:07 -0400 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2017-04-06 22:53:07 -0400 |
commit | c9fc5a3cdf1f5ea2d104c150c30d1b1a6ac64239 (patch) | |
tree | db7187f6984acff324ca468e7b33d9285806a1eb /src/Arithmetic | |
parent | 21198245dab432d3c0ba2bb8a02254e7d0594382 (diff) |
rename-everything
Diffstat (limited to 'src/Arithmetic')
-rw-r--r-- | src/Arithmetic/BarrettReduction/Generalized.v | 140 | ||||
-rw-r--r-- | src/Arithmetic/BarrettReduction/HAC.v | 158 | ||||
-rw-r--r-- | src/Arithmetic/BarrettReduction/Wikipedia.v | 122 | ||||
-rw-r--r-- | src/Arithmetic/Core.v | 980 | ||||
-rw-r--r-- | src/Arithmetic/Karatsuba.v | 49 | ||||
-rw-r--r-- | src/Arithmetic/ModularArithmeticPre.v | 139 | ||||
-rw-r--r-- | src/Arithmetic/ModularArithmeticTheorems.v | 347 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/Definition.v | 179 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/Proofs.v | 296 | ||||
-rw-r--r-- | src/Arithmetic/PrimeFieldTheorems.v | 294 | ||||
-rw-r--r-- | src/Arithmetic/Saturated.v | 285 |
11 files changed, 2989 insertions, 0 deletions
diff --git a/src/Arithmetic/BarrettReduction/Generalized.v b/src/Arithmetic/BarrettReduction/Generalized.v new file mode 100644 index 000000000..76058463c --- /dev/null +++ b/src/Arithmetic/BarrettReduction/Generalized.v @@ -0,0 +1,140 @@ +(*** Barrett Reduction *) +(** This file implements a slightly-generalized version of Barrett + Reduction on [Z]. This version follows a middle path between the + Handbook of Applied Cryptography (Algorithm 14.42) and Wikipedia. + We split up the shifting and the multiplication so that we don't + need to store numbers that are quite so large, but we don't do + early reduction modulo [b^(k+offset)] (we generalize from HAC's [k + ± 1] to [k ± offset]). This leads to weaker conditions on the + base ([b]), exponent ([k]), and the [offset] than those given in + the HAC. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.BreakMatch. + +Local Open Scope Z_scope. + +Section barrett. + Context (n a : Z) + (n_reasonable : n <> 0). + (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *) + (** In modular arithmetic, Barrett reduction is a reduction + algorithm introduced in 1986 by P.D. Barrett. A naive way of + computing *) + (** [c = a mod n] *) + (** would be to use a fast division algorithm. Barrett reduction is + an algorithm designed to optimize this operation assuming [n] is + constant, and [a < n²], replacing divisions by + multiplications. *) + + (** * General idea *) + Section general_idea. + (** Let [m = 1 / n] be the inverse of [n] as a floating point + number. Then *) + (** [a mod n = a - ⌊a m⌋ n] *) + (** where [⌊ x ⌋] denotes the floor function. The result is exact, + as long as [m] is computed with sufficient accuracy. *) + + (* [/] is [Z.div], which means truncated division *) + Local Notation "⌊am⌋" := (a / n) (only parsing). + + Theorem naive_barrett_reduction_correct + : a mod n = a - ⌊am⌋ * n. + Proof using n_reasonable. 
+ apply Zmod_eq_full; assumption. + Qed. + End general_idea. + + (** * Barrett algorithm *) + Section barrett_algorithm. + (** Barrett algorithm is a fixed-point analog which expresses + everything in terms of integers. Let [k] be the smallest + integer such that [2ᵏ > n]. Think of [n] as representing the + fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m = + ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number + [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *) + (** N.B. We don't need [k] to be the smallest such integer. *) + (** N.B. We generalize to an arbitrary base. *) + (** N.B. We generalize from [k ± 1] to [k ± offset]. *) + Context (b : Z) + (base_good : 0 < b) + (k : Z) + (k_good : n < b ^ k) + (m : Z) + (m_good : m = b^(2*k) / n) (* [/] is [Z.div], which is truncated *) + (offset : Z) + (offset_nonneg : 0 <= offset). + (** Wikipedia neglects to mention non-negativity, but we need it. + It might be possible to do with a relaxed assumption, such as + the sign of [a] and the sign of [n] being the same; but I + figured it wasn't worth it. *) + Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *) + (a_nonneg : 0 <= a). + + Context (k_big_enough : offset <= k) + (a_small : a < b^(2*k)) + (** We also need that [n] is large enough; [n] larger than + [bᵏ⁻¹] works, but we ask for something more precise. *) + (n_large : a mod b^(k-offset) <= n). + + (** Now *) + + Let q := (m * (a / b^(k-offset))) / b^(k+offset). + Let r := a - q * n. + (** Because of the floor function (in Coq, because [/] means + truncated division), [q] is an integer and [r ≡ a mod n]. *) + Theorem barrett_reduction_equivalent + : r mod n = a mod n. + Proof using m_good offset. + subst r q m. + rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption. + reflexivity. + Qed. + + Lemma qn_small + : q * n <= a. + Proof using a_nonneg a_small base_good k_big_enough m_good n_pos n_reasonable offset_nonneg. + subst q r m. + assert (0 < b^(k-offset)). zero_bounds. 
+ assert (0 < b^(k+offset)) by zero_bounds. + assert (0 < b^(2 * k)) by zero_bounds. + Z.simplify_fractions_le. + autorewrite with pull_Zpow pull_Zdiv zsimplify; reflexivity. + Qed. + + Lemma q_nice : { b : bool * bool | q = a / n + (if fst b then -1 else 0) + (if snd b then -1 else 0) }. + Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg. + assert (0 < b^(k+offset)) by zero_bounds. + assert (0 < b^(k-offset)) by zero_bounds. + assert (a / b^(k-offset) <= b^(2*k) / b^(k-offset)) by auto with zarith lia. + assert (a / b^(k-offset) <= b^(k+offset)) by (autorewrite with pull_Zpow zsimplify in *; assumption). + subst q r m. + rewrite (Z.div_mul_diff_exact''' (b^(2*k)) n (a/b^(k-offset))) by auto with lia zero_bounds. + rewrite (Z_div_mod_eq (b^(2*k) * _ / n) (b^(k+offset))) by lia. + autorewrite with push_Zmul push_Zopp zsimplify zstrip_div zdiv_to_mod. + rewrite Z.div_sub_mod_cond, !Z.div_sub_small by auto with zero_bounds zarith. + eexists (_, _); reflexivity. + Qed. + + Lemma r_small : r < 3 * n. + Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg q. + Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div. + assert (a mod n < n) by auto with zarith lia. + unfold r; rewrite (proj2_sig q_nice); generalize (proj1_sig q_nice); intro; subst q m. + autorewrite with push_Zmul zsimplify zstrip_div. + break_match; auto with lia. + Qed. + + (** In that case, we have *) + Theorem barrett_reduction_small + : a mod n = let r := if r <? n then r else r-n in + let r := if r <? n then r else r-n in + r. + Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg q. + pose proof r_small. pose proof qn_small. cbv zeta. + destruct (r <? n) eqn:Hr, (r-n <? n) eqn:?; try rewrite Hr; Z.ltb_to_lt; try lia. + { symmetry; apply (Zmod_unique a n q); subst r; lia. } + { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. 
} + { symmetry; apply (Zmod_unique a n (q + 2)); subst r; lia. } + Qed. + End barrett_algorithm. +End barrett. diff --git a/src/Arithmetic/BarrettReduction/HAC.v b/src/Arithmetic/BarrettReduction/HAC.v new file mode 100644 index 000000000..70661ee96 --- /dev/null +++ b/src/Arithmetic/BarrettReduction/HAC.v @@ -0,0 +1,158 @@ +(*** Barrett Reduction *) +(** This file implements a slightly-generalized version of Barrett + Reduction on [Z]. This version follows the Handbook of Applied + Cryptography (Algorithm 14.42) rather closely; the only deviations + are that we generalize from [k ± 1] to [k ± offset] for an + arbitrary offset, and we weaken the conditions on the base [b] in + [bᵏ] slightly. Contrasted with some other versions, this version + does reduction modulo [b^(k+offset)] early (ensuring that we don't + have to carry around extra precision), but requires more stringent + conditions on the base ([b]), exponent ([k]), and the [offset]. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.BreakMatch. + +Local Open Scope Z_scope. + +Section barrett. + (** Quoting the Handbook of Applied Cryptography <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>: *) + (** Barrett reduction (Algorithm 14.42) computes [r = x mod m] given + [x] and [m]. The algorithm requires the precomputation of the + quantity [µ = ⌊b²ᵏ/m⌋]; it is advantageous if many reductions + are performed with a single modulus. For example, each RSA + encryption for one entity requires reduction modulo that + entity’s public key modulus. The precomputation takes a fixed + amount of work, which is negligible in comparison to modular + exponentiation cost. Typically, the radix [b] is chosen to be + close to the word-size of the processor. Hence, assume [b > 3] in + Algorithm 14.42 (see Note 14.44 (ii)). *) + + (** * Barrett modular reduction *) + Section barrett_modular_reduction. 
+ Context (m b x k μ offset : Z) + (m_pos : 0 < m) + (base_pos : 0 < b) + (k_good : m < b^k) + (μ_good : μ = b^(2*k) / m) (* [/] is [Z.div], which is truncated *) + (x_nonneg : 0 <= x) + (offset_nonneg : 0 <= offset) + (k_big_enough : offset <= k) + (x_small : x < b^(2*k)) + (m_small : 3 * m <= b^(k+offset)) + (** We also need that [m] is large enough; [m] larger than + [bᵏ⁻¹] works, but we ask for something more precise. *) + (m_large : x mod b^(k-offset) <= m). + + Let q1 := x / b^(k-offset). Let q2 := q1 * μ. Let q3 := q2 / b^(k+offset). + Let r1 := x mod b^(k+offset). Let r2 := (q3 * m) mod b^(k+offset). + (** At this point, the HAC says "If [r < 0] then [r ← r + bᵏ⁺¹]". + This is equivalent to reduction modulo [b^(k+offset)], as we + prove below. The version involving modular reduction has the + benefit of being cheaper to implement, and making the proofs + simpler, so we primarily use that version. *) + Let r_mod_3m := (r1 - r2) mod b^(k+offset). + Let r_mod_3m_orig := let r := r1 - r2 in + if r <? 0 then r + b^(k+offset) else r. + + Lemma r_mod_3m_eq_orig : r_mod_3m = r_mod_3m_orig. + Proof using base_pos k_big_enough m_pos m_small offset_nonneg r1 r2. + assert (0 <= r1 < b^(k+offset)) by (subst r1; auto with zarith). + assert (0 <= r2 < b^(k+offset)) by (subst r2; auto with zarith). + subst r_mod_3m r_mod_3m_orig; cbv zeta. + break_match; Z.ltb_to_lt. + { symmetry; apply (Zmod_unique (r1 - r2) _ (-1)); lia. } + { symmetry; apply (Zmod_unique (r1 - r2) _ 0); lia. } + Qed. + + (** 14.43 Fact By the division algorithm (Definition 2.82), there + exist integers [Q] and [R] such that [x = Qm + R] and [0 ≤ R < + m]. In step 1 of Algorithm 14.42 (Barrett modular reduction), + the following inequality is satisfied: [Q - 2 ≤ q₃ ≤ Q]. *) + (** We prove this by providing a more useful form for [q₃]. *) + Let Q := x / m. + Let R := x mod m. + Lemma q3_nice : { b : bool * bool | q3 = Q + (if fst b then -1 else 0) + (if snd b then -1 else 0) }. 
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg x_nonneg x_small μ_good. + assert (0 < b^(k+offset)) by zero_bounds. + assert (0 < b^(k-offset)) by zero_bounds. + assert (x / b^(k-offset) <= b^(2*k) / b^(k-offset)) by auto with zarith lia. + assert (x / b^(k-offset) <= b^(k+offset)) by (autorewrite with pull_Zpow zsimplify in *; assumption). + subst q1 q2 q3 Q r_mod_3m r_mod_3m_orig r1 r2 R μ. + rewrite (Z.div_mul_diff_exact' (b^(2*k)) m (x/b^(k-offset))) by auto with lia zero_bounds. + rewrite (Z_div_mod_eq (_ * b^(2*k) / m) (b^(k+offset))) by lia. + autorewrite with push_Zmul push_Zopp zsimplify zstrip_div zdiv_to_mod. + rewrite Z.div_sub_mod_cond, !Z.div_sub_small; auto with zero_bounds zarith. + eexists (_, _); reflexivity. + Qed. + + Fact q3_in_range : Q - 2 <= q3 <= Q. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good. + rewrite (proj2_sig q3_nice). + break_match; lia. + Qed. + + (** 14.44 Note (partial justification of correctness of Barrett reduction) *) + (** (i) Algorithm 14.42 is based on the observation that [⌊x/m⌋] + can be written as [Q = + ⌊(x/bᵏ⁻¹)(b²ᵏ/m)(1/bᵏ⁺¹)⌋]. Moreover, [Q] can be + approximated by the quantity [q₃ = ⌊⌊x/bᵏ⁻¹⌋µ/bᵏ⁺¹⌋]. + Fact 14.43 guarantees that [q₃] is never larger than the + true quotient [Q], and is at most 2 smaller. *) + Lemma x_minus_q3_m_in_range : 0 <= x - q3 * m < 3 * m. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good. + pose proof q3_in_range. + assert (0 <= R < m) by (subst R; auto with zarith). + assert (0 <= (Q - q3) * m + R < 3 * m) by nia. + subst Q R; autorewrite with push_Zmul zdiv_to_mod in *; lia. + Qed. + + Lemma r_mod_3m_eq_alt : r_mod_3m = x - q3 * m. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good. + pose proof x_minus_q3_m_in_range. + subst r_mod_3m r_mod_3m_orig r1 r2. + autorewrite with pull_Zmod zsimplify; reflexivity. + Qed. 
+ + (** This version uses reduction modulo [b^(k+offset)]. *) + Theorem barrett_reduction_equivalent + : r_mod_3m mod m = x mod m. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r1 r2 x_nonneg x_small μ_good. + rewrite r_mod_3m_eq_alt. + autorewrite with zsimplify push_Zmod; reflexivity. + Qed. + + (** This version, which matches the original in the HAC, uses + conditional addition of [b^(k+offset)]. *) + Theorem barrett_reduction_orig_equivalent + : r_mod_3m_orig mod m = x mod m. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r_mod_3m x_nonneg x_small μ_good. rewrite <- r_mod_3m_eq_orig; apply barrett_reduction_equivalent. Qed. + + Lemma r_small : 0 <= r_mod_3m < 3 * m. + Proof using Q R base_pos k_big_enough m_large m_pos m_small offset_nonneg q3 x_nonneg x_small μ_good. + pose proof x_minus_q3_m_in_range. + subst Q R r_mod_3m r_mod_3m_orig r1 r2. + autorewrite with pull_Zmod zsimplify; lia. + Qed. + + + (** This version uses reduction modulo [b^(k+offset)]. *) + Theorem barrett_reduction_small (r := r_mod_3m) + : x mod m = let r := if r <? m then r else r-m in + let r := if r <? m then r else r-m in + r. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r1 r2 x_nonneg x_small μ_good. + pose proof r_small. cbv zeta. + destruct (r <? m) eqn:Hr, (r-m <? m) eqn:?; subst r; rewrite !r_mod_3m_eq_alt, ?Hr in *; Z.ltb_to_lt; try lia. + { symmetry; eapply (Zmod_unique x m q3); lia. } + { symmetry; eapply (Zmod_unique x m (q3 + 1)); lia. } + { symmetry; eapply (Zmod_unique x m (q3 + 2)); lia. } + Qed. + + (** This version, which matches the original in the HAC, uses + conditional addition of [b^(k+offset)]. *) + Theorem barrett_reduction_small_orig (r := r_mod_3m_orig) + : x mod m = let r := if r <? m then r else r-m in + let r := if r <? m then r else r-m in + r. + Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r_mod_3m x_nonneg x_small μ_good. 
subst r; rewrite <- r_mod_3m_eq_orig; apply barrett_reduction_small. Qed. + End barrett_modular_reduction. +End barrett. diff --git a/src/Arithmetic/BarrettReduction/Wikipedia.v b/src/Arithmetic/BarrettReduction/Wikipedia.v new file mode 100644 index 000000000..69ce10c4b --- /dev/null +++ b/src/Arithmetic/BarrettReduction/Wikipedia.v @@ -0,0 +1,122 @@ +(*** Barrett Reduction *) +(** This file implements Barrett Reduction on [Z]. We follow Wikipedia. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. + +Local Open Scope Z_scope. + +Section barrett. + Context (n a : Z) + (n_reasonable : n <> 0). + (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *) + (** In modular arithmetic, Barrett reduction is a reduction + algorithm introduced in 1986 by P.D. Barrett. A naive way of + computing *) + (** [c = a mod n] *) + (** would be to use a fast division algorithm. Barrett reduction is + an algorithm designed to optimize this operation assuming [n] is + constant, and [a < n²], replacing divisions by + multiplications. *) + + (** * General idea *) + Section general_idea. + (** Let [m = 1 / n] be the inverse of [n] as a floating point + number. Then *) + (** [a mod n = a - ⌊a m⌋ n] *) + (** where [⌊ x ⌋] denotes the floor function. The result is exact, + as long as [m] is computed with sufficient accuracy. *) + + (* [/] is [Z.div], which means truncated division *) + Local Notation "⌊am⌋" := (a / n) (only parsing). + + Theorem naive_barrett_reduction_correct + : a mod n = a - ⌊am⌋ * n. + Proof using n_reasonable. + apply Zmod_eq_full; assumption. + Qed. + End general_idea. + + (** * Barrett algorithm *) + Section barrett_algorithm. + (** Barrett algorithm is a fixed-point analog which expresses + everything in terms of integers. Let [k] be the smallest + integer such that [2ᵏ > n]. Think of [n] as representing the + fixed-point number [n 2⁻ᵏ]. 
We precompute [m] such that [m = + ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number + [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *) + (** N.B. We don't need [k] to be the smallest such integer. *) + Context (k : Z) + (k_good : n < 2 ^ k) + (m : Z) + (m_good : m = 4^k / n). (* [/] is [Z.div], which is truncated *) + (** Wikipedia neglects to mention non-negativity, but we need it. + It might be possible to do with a relaxed assumption, such as + the sign of [a] and the sign of [n] being the same; but I + figured it wasn't worth it. *) + Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *) + (a_nonneg : 0 <= a). + + Lemma k_nonnegative : 0 <= k. + Proof using Type*. + destruct (Z_lt_le_dec k 0); try assumption. + rewrite !Z.pow_neg_r in * by lia; lia. + Qed. + + (** Now *) + Let q := (m * a) / 4^k. + Let r := a - q * n. + (** Because of the floor function (in Coq, because [/] means + truncated division), [q] is an integer and [r ≡ a mod n]. *) + Theorem barrett_reduction_equivalent + : r mod n = a mod n. + Proof using m_good. + subst r q m. + rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption. + reflexivity. + Qed. + + Lemma qn_small + : q * n <= a. + Proof using a_nonneg k_good m_good n_pos n_reasonable. + pose proof k_nonnegative; subst q r m. + assert (0 <= 2^(k-1)) by zero_bounds. + Z.simplify_fractions_le. + Qed. + + (** Also, if [a < n²] then [r < 2n]. *) + (** N.B. It turns out that it is sufficient to assume [a < 4ᵏ]. *) + Context (a_small : a < 4^k). + Lemma q_nice : { b : bool | q = a / n + if b then -1 else 0 }. + Proof using a_nonneg a_small k_good m_good n_pos n_reasonable. + assert (0 <= (4 ^ k * a / n) mod 4 ^ k < 4 ^ k) by auto with zarith lia. + assert (0 <= a * (4 ^ k mod n) / n < 4 ^ k) by (auto with zero_bounds zarith lia). + subst q r m. + rewrite (Z.div_mul_diff_exact''' (4^k) n a) by lia. + rewrite (Z_div_mod_eq (4^k * _ / n) (4^k)) by lia. + autorewrite with push_Zmul push_Zopp zsimplify zstrip_div. 
+ eexists; reflexivity. + Qed. + + Lemma r_small : r < 2 * n. + Proof using a_nonneg a_small k_good m_good n_pos n_reasonable q. + Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div. + assert (a mod n < n) by auto with zarith lia. + unfold r; rewrite (proj2_sig q_nice); generalize (proj1_sig q_nice); intro; subst q m. + autorewrite with push_Zmul zsimplify zstrip_div. + break_match; auto with lia. + Qed. + + (** In that case, we have *) + Theorem barrett_reduction_small + : a mod n = if r <? n + then r + else r - n. + Proof using a_nonneg a_small k_good m_good n_pos n_reasonable q. + pose proof r_small. pose proof qn_small. + destruct (r <? n) eqn:rlt; Z.ltb_to_lt. + { symmetry; apply (Zmod_unique a n q); subst r; lia. } + { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. } + Qed. + End barrett_algorithm. +End barrett. diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v new file mode 100644 index 000000000..2613765d0 --- /dev/null +++ b/src/Arithmetic/Core.v @@ -0,0 +1,980 @@ +(***** + +This file provides a generalized version of arithmetic with "mixed +radix" numerical systems. Later, parameters are entered into the +general functions, and they are partially evaluated until only runtime +basic arithmetic operations remain. + +CPS +--- + +Functions are written in continuation passing style (CPS). This means +that each operation is passed a "continuation" function, which it is +expected to call on its own output (like a callback). See the end of +this comment for a motivating example explaining why we do CPS, +despite a fair amount of resulting boilerplate code for each +operation. The code block for an operation called A would look like +this: + +``` +Definition A_cps x y {T} f : T := ... + +Definition A x y := A_cps x y id. +Lemma A_cps_id x y : forall {T} f, @A_cps x y T f = f (A x y). +Hint Opaque A : uncps. +Hint Rewrite A_cps_id : uncps. + +Lemma eval_A x y : eval (A x y) = ... +Hint Rewrite eval_A : push_basesystem_eval. 
+``` + +`A_cps` is the main, CPS-style definition of the operation (`f` is the +continuation function). `A` is the non-CPS version of `A_cps`, simply +defined by passing an identity function to `A_cps`. `A_cps_id` states +that we can replace the CPS version with the non-cps version. `eval_A` +is the actual correctness lemma for the operation, stating that it has +the correct arithmetic properties. In general, the middle block +containing `A` and `A_cps_id` is boring boilerplate and can be safely +ignored. + +HintDbs +------- + ++ `uncps` : Converts CPS operations to their non-CPS versions. ++ `push_basesystem_eval` : Contains all the correctness lemmas for + operations in this file, which are in terms of the `eval` function. + +Positional/Associational +------------------------ + +We represent mixed-radix numbers in a few different ways: + ++ "Positional" : a tuple of numbers and a weight function (nat->Z), +which is evaluated by multiplying the `i`th element of the tuple by +`weight i`, and then summing the products. ++ "Associational" : a list of pairs of numbers--the first is the +weight, the second is the runtime value. Evaluated by multiplying each +pair and summing the products. + +The associational representation is good for basic operations like +addition and multiplication; for addition, one can simply just append +two associational lists. But the end-result code should use the +positional representation (with each digit representing a machine +word). Since converting to and fro can be easily compiled away once +the weight function is known, we use associational to write most of +the operations and liberally convert back and forth to ensure correct +output. In particular, it is important to convert before carrying. + +Runtime Operations +------------------ + +Since some instances of e.g. Z.add or Z.mul operate on (compile-time) +weights, and some operate on runtime values, we need a way to +differentiate these cases before partial evaluation. 
We define a +runtime_scope to mark certain additions/multiplications as runtime +values, so they will not be unfolded during partial evaluation. For +instance, if we have: + +``` +Definition f (x y : Z * Z) := (fst x + fst y, (snd x + snd y)%RT). +``` + +then when we are partially evaluating `f`, we can easily exclude the +runtime operations (`cbv - [runtime_add]`) and prevent Coq from trying +to simplify the second addition. + + +Why CPS? +-------- + +Let's suppose we want to add corresponding elements of two `list Z`s +(so on inputs `[1,2,3]` and `[2,3,1]`, we get `[3,5,4]`). We might +write our function like this : + +``` +Fixpoint add_lists (p q : list Z) := + match p, q with + | p0 :: p', q0 :: q' => + dlet sum := p0 + q0 in + sum :: add_lists p' q' + | _, _ => nil + end. +``` + +(Note : `dlet` is a notation for `Let_In`, which is just a dumb +wrapper for `let`. This allows us to `cbv - [Let_In]` if we want to +not simplify certain `let`s.) + +A CPS equivalent of `add_lists` would look like this: + +``` +Fixpoint add_lists_cps (p q : list Z) {T} (f:list Z->T) := + match p, q with + | p0 :: p', q0 :: q' => + dlet sum := p0 + q0 in + add_lists_cps p' q' (fun r => f (sum :: r)) + | _, _ => f nil + end. +``` + +Now let's try some partial evaluation. The expression we'll evaluate is: + +``` +Definition x := + (fun a0 a1 a2 b0 b1 b2 => + let r := add_lists [a0;a1;a2] [b0;b1;b2] in + let rr := add_lists r r in + add_lists rr rr). +``` + +Or, using `add_lists_cps`: + +``` +Definition y := + (fun a0 a1 a2 b0 b1 b2 => + add_lists_cps [a0;a1;a2] [b0;b1;b2] + (fun r => add_lists_cps r r + (fun rr => add_lists_cps rr rr id))). 
+``` + +If we run `Eval cbv -[Z.add] in x` and `Eval cbv -[Z.add] in y`, we get +identical output: + +``` +fun a0 a1 a2 b0 b1 b2 : Z => + [a0 + b0 + (a0 + b0) + (a0 + b0 + (a0 + b0)); + a1 + b1 + (a1 + b1) + (a1 + b1 + (a1 + b1)); + a2 + b2 + (a2 + b2) + (a2 + b2 + (a2 + b2))] +``` + +However, there are a lot of common subexpressions here--this is what +the `dlet` we put into the functions should help us avoid. Let's try +`Eval cbv -[Let_In Z.add] in x`: + +``` +fun a0 a1 a2 b0 b1 b2 : Z => + (fix add_lists (p q : list Z) {struct p} : + list Z := + match p with + | [] => [] + | p0 :: p' => + match q with + | [] => [] + | q0 :: q' => + dlet sum := p0 + q0 in + sum :: add_lists p' q' + end + end) + ((fix add_lists (p q : list Z) {struct p} : + list Z := + match p with + | [] => [] + | p0 :: p' => + match q with + | [] => [] + | q0 :: q' => + dlet sum := p0 + q0 in + sum :: add_lists p' q' + end + end) + (dlet sum := a0 + b0 in + sum + :: (dlet sum0 := a1 + b1 in + sum0 :: (dlet sum1 := a2 + b2 in + [sum1]))) + (dlet sum := a0 + b0 in + sum + :: (dlet sum0 := a1 + b1 in + sum0 :: (dlet sum1 := a2 + b2 in + [sum1])))) + ((fix add_lists (p q : list Z) {struct p} : + list Z := + match p with + | [] => [] + | p0 :: p' => + match q with + | [] => [] + | q0 :: q' => + dlet sum := p0 + q0 in + sum :: add_lists p' q' + end + end) + (dlet sum := a0 + b0 in + sum + :: (dlet sum0 := a1 + b1 in + sum0 :: (dlet sum1 := a2 + b2 in + [sum1]))) + (dlet sum := a0 + b0 in + sum + :: (dlet sum0 := a1 + b1 in + sum0 :: (dlet sum1 := a2 + b2 in + [sum1])))) +``` + +Not so great. Because the `dlet`s are stuck in the inner terms, we +can't simplify the expression very nicely. 
Let's try that on the CPS +version (`Eval cbv -[Let_In Z.add] in y`): + +``` +fun a0 a1 a2 b0 b1 b2 : Z => + dlet sum := a0 + b0 in + dlet sum0 := a1 + b1 in + dlet sum1 := a2 + b2 in + dlet sum2 := sum + sum in + dlet sum3 := sum0 + sum0 in + dlet sum4 := sum1 + sum1 in + dlet sum5 := sum2 + sum2 in + dlet sum6 := sum3 + sum3 in + dlet sum7 := sum4 + sum4 in + [sum5; sum6; sum7] +``` + +Isn't that lovely? Since we can push continuation functions "under" +the `dlet`s, we can end up with a nice, concise, simplified +expression. + +One might suggest that we could just inline the `dlet`s and do common +subexpression elimination. But some of our terms have so many `dlet`s +that inlining them all would make a term too huge to process in +reasonable time, so this is not really an option. + +*****) + +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega. +Require Import Coq.ZArith.BinIntDef. +Local Open Scope Z_scope. + +Require Import Crypto.Algebra.Nsatz. +Require Import Crypto.Util.Decidable Crypto.Util.LetIn. +Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma. +Require Import Crypto.Util.CPSUtil Crypto.Util.Prod. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.Tactics.VM. + +Require Import Coq.Lists.List. Import ListNotations. +Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple. + +Local Ltac prove_id := + repeat match goal with + | _ => progress intros + | _ => progress simpl + | _ => progress cbv [Let_In] + | _ => progress (autorewrite with uncps push_id in * ) + | _ => break_innermost_match_step + | _ => contradiction + | _ => reflexivity + | _ => nsatz + | _ => solve [auto] + end. + +Create HintDb push_basesystem_eval discriminated. 
+Local Ltac prove_eval := + repeat match goal with + | _ => progress intros + | _ => progress simpl + | _ => progress cbv [Let_In] + | _ => progress (autorewrite with push_basesystem_eval uncps push_id cancel_pair in * ) + | _ => break_innermost_match_step + | _ => split + | H : _ /\ _ |- _ => destruct H + | H : Some _ = Some _ |- _ => progress (inversion H; subst) + | _ => discriminate + | _ => reflexivity + | _ => nsatz + end. + +Definition mod_eq (m:positive) a b := a mod m = b mod m. +Global Instance mod_eq_equiv m : RelationClasses.Equivalence (mod_eq m). +Proof. constructor; congruence. Qed. +Definition mod_eq_dec m a b : {mod_eq m a b} + {~ mod_eq m a b} + := Z.eq_dec _ _. +Lemma mod_eq_Z2F_iff m a b : + mod_eq m a b <-> Logic.eq (F.of_Z m a) (F.of_Z m b). +Proof. rewrite <-F.eq_of_Z_iff; reflexivity. Qed. + +Delimit Scope runtime_scope with RT. + +Definition runtime_mul := Z.mul. +Global Notation "a * b" := (runtime_mul a%RT b%RT) : runtime_scope. +Definition runtime_add := Z.add. +Global Notation "a + b" := (runtime_add a%RT b%RT) : runtime_scope. +Definition runtime_opp := Z.opp. +Global Notation "- a" := (runtime_opp a%RT) : runtime_scope. +Definition runtime_and := Z.land. +Global Notation "a &' b" := (runtime_and a%RT b%RT) : runtime_scope. +Definition runtime_shr := Z.shiftr. +Global Notation "a >> b" := (runtime_shr a%RT b%RT) : runtime_scope. + +Module B. + Definition limb := (Z*Z)%type. (* position coefficient and run-time value *) + Module Associational. + Definition eval (p:list limb) : Z := + List.fold_right Z.add 0%Z (List.map (fun t => fst t * snd t) p). + + Lemma eval_nil : eval nil = 0. Proof. reflexivity. Qed. + Lemma eval_cons p q : eval (p::q) = (fst p) * (snd p) + eval q. Proof. reflexivity. Qed. + Lemma eval_app p q: eval (p++q) = eval p + eval q. + Proof. induction p; simpl eval; rewrite ?eval_nil, ?eval_cons; nsatz. Qed. + Hint Rewrite eval_nil eval_cons eval_app : push_basesystem_eval. 
+ + Definition multerm (t t' : limb) : limb := + (fst t * fst t', (snd t * snd t')%RT). + Lemma eval_map_multerm (a:limb) (q:list limb) + : eval (List.map (multerm a) q) = fst a * snd a * eval q. + Proof. + induction q; cbv [multerm]; simpl List.map; + autorewrite with push_basesystem_eval cancel_pair; nsatz. + Qed. Hint Rewrite eval_map_multerm : push_basesystem_eval. + + Definition mul_cps (p q:list limb) {T} (f : list limb->T) := + flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f. + + Definition mul (p q:list limb) := mul_cps p q id. + Lemma mul_cps_id p q: forall {T} f, @mul_cps p q T f = f (mul p q). + Proof. cbv [mul_cps mul]; prove_id. Qed. + Hint Opaque mul : uncps. + Hint Rewrite mul_cps_id : uncps. + + Lemma eval_mul p q: eval (mul p q) = eval p * eval q. + Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed. + Hint Rewrite eval_mul : push_basesystem_eval. + + Fixpoint split_cps (s:Z) (xs:list limb) + {T} (f :list limb*list limb->T) := + match xs with + | nil => f (nil, nil) + | cons x xs' => + split_cps s xs' + (fun sxs' => + if dec (fst x mod s = 0) + then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) + else f (cons x (fst sxs'), snd sxs')) + end. + + Definition split s xs := split_cps s xs id. + Lemma split_cps_id s p: forall {T} f, + @split_cps s p T f = f (split s p). + Proof. + induction p; + repeat match goal with + | _ => rewrite IHp + | _ => progress (cbv [split]; prove_id) + end. + Qed. + Hint Opaque split : uncps. + Hint Rewrite split_cps_id : uncps. + + Lemma eval_split s p (s_nonzero:s<>0): + eval (fst (split s p)) + s*eval (snd (split s p)) = eval p. + Proof. + cbv [split]; induction p; prove_eval. + match goal with + H:_ |- _ => + unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H) + end; nsatz. + Qed. Hint Rewrite @eval_split using auto : push_basesystem_eval. + + Definition reduce_cps (s:Z) (c:list limb) (p:list limb) + {T} (f : list limb->T) := + split_cps s p + (fun ab => mul_cps c (snd ab) + (fun rr =>f (fst ab ++ rr))). 
+ + Definition reduce s c p := reduce_cps s c p id. + Lemma reduce_cps_id s c p {T} f: + @reduce_cps s c p T f = f (reduce s c p). + Proof. cbv [reduce_cps reduce]; prove_id. Qed. + Hint Opaque reduce : uncps. + Hint Rewrite reduce_cps_id : uncps. + + Lemma reduction_rule a b s c m (m_eq:Z.pos m = s - c): + (a + s * b) mod m = (a + c * b) mod m. + Proof. + rewrite m_eq. pose proof (Pos2Z.is_pos m). + replace (a + s * b) with ((a + c*b) + b*(s-c)) by ring. + rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod by omega. + trivial. + Qed. + Lemma eval_reduce s c p (s_nonzero:s<>0) m (m_eq : Z.pos m = s - eval c) : + mod_eq m (eval (reduce s c p)) (eval p). + Proof. + cbv [reduce reduce_cps mod_eq]; prove_eval. + erewrite <-reduction_rule by eauto; prove_eval. + Qed. + Hint Rewrite eval_reduce using (omega || assumption) : push_basesystem_eval. + (* TODO: find out why this hint gets picked up outside the section (while other eval_ hints do not). *) + + + Definition negate_snd_cps (p:list limb) {T} (f:list limb ->T) := + map_cps (fun cx => (fst cx, (-snd cx)%RT)) p f. + + Definition negate_snd p := negate_snd_cps p id. + Lemma negate_snd_id p {T} f : @negate_snd_cps p T f = f (negate_snd p). + Proof. cbv [negate_snd_cps negate_snd]; prove_id. Qed. + Hint Opaque negate_snd : uncps. + Hint Rewrite negate_snd_id : uncps. + + Lemma eval_negate_snd p : eval (negate_snd p) = - eval p. + Proof. + cbv [negate_snd_cps negate_snd]; induction p; prove_eval. + Qed. Hint Rewrite eval_negate_snd : push_basesystem_eval. + + Section Carries. + Context {modulo div:Z->Z->Z}. + Context {div_mod : forall a b:Z, b <> 0 -> + a = b * (div a b) + modulo a b}. + + Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) := + if dec (fst t = w) + then dlet t2 := snd t in + f ((w*fw, div t2 fw) :: (w, modulo t2 fw) :: @nil limb) + else f [t]. + + Definition carryterm w fw t := carryterm_cps w fw t id. + Lemma carryterm_cps_id w fw t {T} f : + @carryterm_cps w fw t T f + = f (@carryterm w fw t).
+ Proof using Type. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed. + Hint Opaque carryterm : uncps. + Hint Rewrite carryterm_cps_id : uncps. + + + Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0): + eval (carryterm w fw t) = eval [t]. + Proof using Type*. + cbv [carryterm_cps carryterm Let_In]; prove_eval. + specialize (div_mod (snd t) fw fw_nonzero). + nsatz. + Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval. + + Definition carry_cps (w fw:Z) (p:list limb) {T} (f:list limb->T) := + flat_map_cps (carryterm_cps w fw) p f. + + Definition carry w fw p := carry_cps w fw p id. + Lemma carry_cps_id w fw p {T} f: + @carry_cps w fw p T f = f (carry w fw p). + Proof using Type. cbv [carry_cps carry]; prove_id. Qed. + Hint Opaque carry : uncps. + Hint Rewrite carry_cps_id : uncps. + + Lemma eval_carry w fw p (fw_nonzero:fw<>0): + eval (carry w fw p) = eval p. + Proof using Type*. cbv [carry_cps carry]; induction p; prove_eval. Qed. + Hint Rewrite eval_carry using auto : push_basesystem_eval. + End Carries. + + End Associational. + Hint Rewrite + @Associational.carry_cps_id + @Associational.carryterm_cps_id + @Associational.reduce_cps_id + @Associational.split_cps_id + @Associational.mul_cps_id : uncps. + + Module Positional. + Section Positional. + Import Associational. + Context (weight : nat -> Z) (* [weight i] is the weight of position [i] *) + (weight_0 : weight 0%nat = 1%Z) + (weight_nonzero : forall i, weight i <> 0). + + (** Converting from positional to associational *) + Definition to_associational_cps {n:nat} (xs:tuple Z n) + {T} (f:list limb->T) := + map_cps weight (seq 0 n) + (fun r => + to_list_cps n xs (fun rr => combine_cps r rr f)). + + Definition to_associational {n} xs := + @to_associational_cps n xs _ id. + Lemma to_associational_cps_id {n} x {T} f: + @to_associational_cps n x T f = f (to_associational x). + Proof using Type. cbv [to_associational_cps to_associational]; prove_id. Qed. + Hint Opaque to_associational : uncps. 
+ Hint Rewrite @to_associational_cps_id : uncps. + + Definition eval {n} x := + @to_associational_cps n x _ Associational.eval. + + Lemma eval_to_associational {n} x : + Associational.eval (@to_associational n x) = eval x. + Proof using Type. + cbv [to_associational_cps eval to_associational]; prove_eval. + Qed. Hint Rewrite @eval_to_associational : push_basesystem_eval. + + (** (modular) equality that tolerates redundancy **) + Definition eq {sz} m (a b : tuple Z sz) : Prop := + mod_eq m (eval a) (eval b). + + (** Converting from associational to positional *) + + Definition zeros n : tuple Z n := Tuple.repeat 0 n. + Lemma eval_zeros n : eval (zeros n) = 0. + Proof using Type. + cbv [eval Associational.eval to_associational_cps zeros]. + pose proof (seq_length n 0). generalize dependent (seq 0 n). + intro xs; revert n; induction xs; intros; + [autorewrite with uncps; reflexivity|]. + intros; destruct n; [distr_length|]. + specialize (IHxs n). autorewrite with uncps in *. + rewrite !@Tuple.to_list_repeat in *. + simpl List.repeat. rewrite map_cons, combine_cons, map_cons. + simpl fold_right. rewrite IHxs by distr_length. ring. + Qed. Hint Rewrite eval_zeros : push_basesystem_eval. + + Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) := + @on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f. + + Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id. + Lemma add_to_nth_cps_id {n} i x xs {T} f: + @add_to_nth_cps n i x xs T f = f (add_to_nth i x xs). + Proof using weight. + cbv [add_to_nth_cps add_to_nth]; erewrite !on_tuple_cps_correct + by (intros; autorewrite with uncps; reflexivity); prove_id. + Unshelve. + intros; subst. autorewrite with uncps push_id. distr_length. + Qed. + Hint Opaque add_to_nth : uncps. + Hint Rewrite @add_to_nth_cps_id : uncps. + + Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i<n)%nat) (xs:tuple Z n): + eval (@add_to_nth n i x xs) = weight i * x + eval xs. + Proof using Type. 
+ cbv [eval to_associational_cps add_to_nth add_to_nth_cps runtime_add]. + erewrite on_tuple_cps_correct by (intros; autorewrite with uncps; reflexivity). + prove_eval. + cbv [Tuple.on_tuple]. + rewrite !Tuple.to_list_from_list. + autorewrite with uncps push_id. + rewrite ListUtil.combine_update_nth_r at 1. + rewrite <-(update_nth_id i (List.combine _ _)) at 2. + rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _ (weight 0, 0)); cbv [ListUtil.splice_nth id]; + repeat match goal with + | _ => progress (apply Zminus_eq; ring_simplify) + | _ => progress autorewrite with push_basesystem_eval cancel_pair distr_length + | _ => progress rewrite <-?ListUtil.map_nth_default_always, ?map_fst_combine, ?List.firstn_all2, ?ListUtil.map_nth_default_always, ?nth_default_seq_inbouns, ?plus_O_n + end; trivial; lia. + Unshelve. + intros; subst. autorewrite with uncps push_id. distr_length. + Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval. + + Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) := + if dec (fst t mod weight i = 0) + then f (i, let c := fst t / weight i in (c * snd t)%RT) + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + + Definition place t i := place_cps t i id. + Lemma place_cps_id t i {T} f : + @place_cps t i T f = f (place t i). + Proof using Type. cbv [place]; induction i; prove_id. Qed. + Hint Opaque place : uncps. + Hint Rewrite place_cps_id : uncps. + + Lemma place_cps_in_range (t:limb) (n:nat) + : (fst (place_cps t n id) < S n)%nat. + Proof using Type. induction n; simpl; break_match; simpl; omega. Qed. + Lemma weight_place_cps t i + : weight (fst (place_cps t i id)) * snd (place_cps t i id) + = fst t * snd t. + Proof using Type*. + induction i; cbv [id]; simpl place_cps; break_match; + autorewrite with cancel_pair; + try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end; + nsatz || auto. + Qed. 
+ + Definition from_associational_cps n (p:list limb) + {T} (f:tuple Z n->T):= + fold_right_cps + (fun t st => + place_cps t (pred n) + (fun p=> add_to_nth_cps (fst p) (snd p) st id)) + (zeros n) p f. + + Definition from_associational n p := from_associational_cps n p id. + Lemma from_associational_cps_id {n} p {T} f: + @from_associational_cps n p T f = f (from_associational n p). + Proof using Type. + cbv [from_associational_cps from_associational]; prove_id. + Qed. + Hint Opaque from_associational : uncps. + Hint Rewrite @from_associational_cps_id : uncps. + + Lemma eval_from_associational {n} p (n_nonzero:n<>O): + eval (from_associational n p) = Associational.eval p. + Proof using Type*. + cbv [from_associational_cps from_associational]; induction p; + [|pose proof (place_cps_in_range a (pred n))]; prove_eval. + cbv [place]; rewrite weight_place_cps. nsatz. + Qed. + Hint Rewrite @eval_from_associational using omega + : push_basesystem_eval. + + Section Carries. + Context {modulo div : Z->Z->Z}. + Context {div_mod : forall a b:Z, b <> 0 -> + a = b * (div a b) + modulo a b}. + Definition carry_cps {n m} (index:nat) (p:tuple Z n) + {T} (f:tuple Z m->T) := + to_associational_cps p + (fun P => @Associational.carry_cps + modulo div + (weight index) + (weight (S index) / weight index) + P T + (fun R => from_associational_cps m R f)). + + Definition carry {n m} i p := @carry_cps n m i p _ id. + Lemma carry_cps_id {n m} i p {T} f: + @carry_cps n m i p T f = f (carry i p). + Proof. + cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity. + Qed. + Hint Opaque carry : uncps. Hint Rewrite @carry_cps_id : uncps. + + Lemma eval_carry {n m} i p: (n <> 0%nat) -> (m <> 0%nat) -> + weight (S i) / weight i <> 0 -> + eval (carry (n:=n) (m:=m) i p) = eval p. + Proof. + cbv [carry_cps carry]; intros. prove_eval. + rewrite @eval_carry by eauto. + apply eval_to_associational. + Qed. + Hint Rewrite @eval_carry : push_basesystem_eval. + + (* N.B. 
It is important to reverse [idxs] here. Like + [fold_right], [fold_right_cps2] is written such that the first + terms in the list are actually used last in the computation. For + example, running: + + `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).` + + will produce [fun a b c d => (a + (b + (c + d)))].*) + Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat) + {T} (f:tuple Z n->T) := + fold_right_cps2 carry_cps p (rev idxs) f. + + Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id. + Lemma chained_carries_id {n} p idxs : forall {T} f, + @chained_carries_cps n p idxs T f = f (chained_carries p idxs). + Proof using Type. cbv [chained_carries_cps chained_carries]; prove_id. Qed. + Hint Opaque chained_carries : uncps. + Hint Rewrite @chained_carries_id : uncps. + + Lemma eval_chained_carries {n} (p:tuple Z n) idxs : + (forall i, In i idxs -> weight (S i) / weight i <> 0) -> + eval (chained_carries p idxs) = eval p. + Proof using Type*. + cbv [chained_carries chained_carries_cps]; intros; + autorewrite with uncps push_id. + apply fold_right_invariant; [|intro; rewrite <-in_rev]; + destruct n; prove_eval; auto. + Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval. + + (* Reverse of [eval]; translate from Z to basesystem by putting + everything in first digit and then carrying. This function, like + [eval], is not defined using CPS. *) + Definition encode {n} (x : Z) : tuple Z n := + chained_carries (from_associational n [(1,x)]) (seq 0 n). + Lemma eval_encode {n} x : (n <> 0%nat) -> + (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) -> + eval (@encode n x) = x. + Proof using Type*. cbv [encode]; intros; prove_eval; auto. Qed. + Hint Rewrite @eval_encode : push_basesystem_eval. + + End Carries. + + Section Wrappers. + (* Simple wrappers for Associational definitions; convert to + associational, do the operation, convert back.
*) + + Definition add_cps {n} (p q : tuple Z n) {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => to_associational_cps q + (fun Q => from_associational_cps n (P++Q) f)). + + Definition mul_cps {n m} (p q : tuple Z n) {T} (f:tuple Z m->T) := + to_associational_cps p + (fun P => to_associational_cps q + (fun Q => Associational.mul_cps P Q + (fun PQ => from_associational_cps m PQ f))). + + Definition reduce_cps {m n} (s:Z) (c:list B.limb) (p : tuple Z m) + {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => Associational.reduce_cps s c P + (fun R => from_associational_cps n R f)). + + Definition carry_reduce_cps {n div modulo} + (s:Z) (c:list limb) (p : tuple Z n) + {T} (f: tuple Z n ->T) := + carry_cps (div:=div) (modulo:=modulo) (n:=n) (m:=S n) (pred n) p + (fun r => reduce_cps (m:=S n) (n:=n) s c r f). + + Definition negate_snd_cps {n} (p : tuple Z n) + {T} (f:tuple Z n->T) := + to_associational_cps p + (fun P => Associational.negate_snd_cps P + (fun R => from_associational_cps n R f)). + + End Wrappers. + Hint Unfold + Positional.add_cps + Positional.mul_cps + Positional.reduce_cps + Positional.carry_reduce_cps + Positional.negate_snd_cps + . + + Section Subtraction. + Context {m n} {coef : tuple Z n} + {coef_mod : mod_eq m (eval coef) 0}. + + Definition sub_cps (p q : tuple Z n) {T} (f:tuple Z n->T):= + add_cps coef p + (fun cp => negate_snd_cps q + (fun _q => add_cps cp _q f)). + + Definition sub p q := sub_cps p q id. + Lemma sub_id p q {T} f : @sub_cps p q T f = f (sub p q). + Proof using Type. cbv [sub_cps sub]; autounfold; prove_id. Qed. + Hint Opaque sub : uncps. + Hint Rewrite sub_id : uncps. + + Lemma eval_sub p q : mod_eq m (eval (sub p q)) (eval p - eval q). + Proof using Type*. + cbv [sub sub_cps]; autounfold; destruct n; prove_eval. + transitivity (eval coef + (eval p - eval q)). + { apply f_equal2; ring. } + { cbv [mod_eq] in *; rewrite Z.add_mod_full, coef_mod, Z.add_0_l, Zmod_mod. reflexivity. } + Qed. 
+ + Definition opp_cps (p : tuple Z n) {T} (f:tuple Z n->T):= + sub_cps (zeros n) p f. + End Subtraction. + + (* Lemmas about converting to/from F. Will be useful in proving + that basesystem is isomorphic to F.commutative_ring_modulo.*) + Section F. + Context {sz:nat} {sz_nonzero : sz<>0%nat} {m :positive}. + Context (weight_divides : forall i : nat, weight (S i) / weight i <> 0). + Context {modulo div:Z->Z->Z} + {div_mod : forall a b:Z, b <> 0 -> + a = b * (div a b) + modulo a b}. + + Definition Fencode (x : F m) : tuple Z sz := + encode (div:=div) (modulo:=modulo) (F.to_Z x). + + Definition Fdecode (x : tuple Z sz) : F m := F.of_Z m (eval x). + + Lemma Fdecode_Fencode_id x : Fdecode (Fencode x) = x. + Proof using div_mod sz_nonzero weight_0 weight_divides weight_nonzero. + cbv [Fdecode Fencode]; rewrite @eval_encode by auto. + apply F.of_Z_to_Z. + Qed. + + Lemma eq_Feq_iff a b : + Logic.eq (Fdecode a) (Fdecode b) <-> eq m a b. + Proof using Type. cbv [Fdecode]; rewrite <-F.eq_of_Z_iff; reflexivity. Qed. + End F. + + + End Positional. + + (* Helper lemmas and definitions for [eval]; this needs to be in a + separate section so the weight function can change. *) + Section EvalHelpers. + Lemma eval_single wt (x:Z) : eval (n:=1) wt x = wt 0%nat * x. + Proof. cbv - [Z.mul Z.add]. ring. Qed. + + Lemma eval_step {n} (x:tuple Z n) : forall wt z, + eval wt (Tuple.append z x) = wt 0%nat * z + eval (fun i => wt (S i)) x. + Proof. + destruct n; [reflexivity|]. + intros; cbv [eval to_associational_cps]. + autorewrite with uncps. rewrite map_S_seq. reflexivity. + Qed. + + Lemma eval_wt_equiv {n} :forall wta wtb (x:tuple Z n), + (forall i, wta i = wtb i) -> eval wta x = eval wtb x. + Proof. + destruct n; [reflexivity|]. + induction n; intros; [rewrite !eval_single, H; reflexivity|]. + simpl tuple in *; destruct x. + change (t, z) with (Tuple.append (n:=S n) z t). + rewrite !eval_step. rewrite (H 0%nat). apply Group.cancel_left. + apply IHn; auto. + Qed. 
+ + Definition eval_from {n} weight (offset:nat) (x : tuple Z n) : Z := + eval (fun i => weight (i+offset)%nat) x. + + Lemma eval_from_0 {n} wt x : @eval_from n wt 0 x = eval wt x. + Proof. cbv [eval_from]. auto using eval_wt_equiv. Qed. + End EvalHelpers. + + End Positional. + Hint Unfold + Positional.add_cps + Positional.mul_cps + Positional.reduce_cps + Positional.carry_reduce_cps + Positional.negate_snd_cps + Positional.opp_cps + . + Hint Rewrite + @Associational.carry_cps_id + @Associational.carryterm_cps_id + @Associational.reduce_cps_id + @Associational.split_cps_id + @Associational.mul_cps_id + @Positional.carry_cps_id + @Positional.from_associational_cps_id + @Positional.place_cps_id + @Positional.add_to_nth_cps_id + @Positional.to_associational_cps_id + @Positional.chained_carries_id + @Positional.sub_id + : uncps. + Hint Rewrite + @Associational.eval_mul + @Positional.eval_to_associational + @Associational.eval_carry + @Associational.eval_carryterm + @Associational.eval_reduce + @Associational.eval_split + @Positional.eval_zeros + @Positional.eval_carry + @Positional.eval_from_associational + @Positional.eval_add_to_nth + @Positional.eval_chained_carries + @Positional.eval_sub + using (assumption || vm_decide) : push_basesystem_eval. +End B. + +(* Modulo and div that do shifts if possible, otherwise normal mod/div *) +Section DivMod. + Definition modulo (a b : Z) : Z := + if dec (2 ^ (Z.log2 b) = b) + then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT + else Z.modulo a b. + + Definition div (a b : Z) : Z := + if dec (2 ^ (Z.log2 b) = b) + then let x := Z.log2 b in (a >> x)%RT + else Z.div a b. + + Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b. + Proof. + cbv [div modulo]; intros. break_match; auto using Z.div_mod. + rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg. + pose proof (Z.div_mod a b H). congruence. + Qed. +End DivMod. + +Import B. 
+ +Ltac basesystem_partial_evaluation_RHS := + let t0 := match goal with |- _ _ ?t => t end in + let t := (eval cbv delta [ + (* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], [runtime_opp], [runtime_mul], [runtime_shr], or [runtime_and] *) +Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries_cps Positional.chained_carries Positional.sub_cps Positional.sub Positional.negate_snd_cps Positional.add_cps Positional.opp_cps Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry Associational.negate_snd_cps Associational.negate_snd div modulo + ] in t0) in + let t := (eval pattern @runtime_mul in t) in + let t := match t with ?t _ => t end in + let t := (eval pattern @runtime_add in t) in + let t := match t with ?t _ => t end in + let t := (eval pattern @runtime_opp in t) in + let t := match t with ?t _ => t end in + let t := (eval pattern @runtime_shr in t) in + let t := match t with ?t _ => t end in + let t := (eval pattern @runtime_and in t) in + let t := match t with ?t _ => t end in + let t := (eval pattern @Let_In in t) in + let t := match t with ?t _ => t end in + let t1 := fresh "t1" in + pose t as t1; + transitivity (t1 + (@Let_In) + (@runtime_and) + (@runtime_shr) + (@runtime_opp) + (@runtime_add) + (@runtime_mul)); + [replace_with_vm_compute t1; clear t1|reflexivity]. + +(** This block of tactic code works around bug #5434 + (https://coq.inria.fr/bugs/show_bug.cgi?id=5434), that + [vm_compute] breaks an invariant in pretyping/constr_matching.ml. 
+ So we refresh all of the names in match statements in the goal by + crawling it. + + In particular, [replace_with_vm_compute] creates a [vm_compute]d + term which has anonymous binders where pretyping expects there to + be named binders. This shows up when you try to match on the + function (the branch statement of the match) with an Ltac pattern + like [(fun x : ?T => ?C)] rather than [(fun x : ?T => @?C x)]; we + use the former in reification to save the cost of many extra + invocations of [cbv beta]. Luckily, patterns like [(fun x : ?T => + @?C x)] don't trigger this anomaly, so we can walk the term, + fixing all match statements whose branches are functions whose + binder names were eaten by [vm_compute] (note that in a match, + every branch where the corresponding constructor takes arguments + is represented internally as a function (lambda term)). We fix + the match statements by pulling out the branch with the [@?] + pattern that doesn't trigger the anomaly, and then recreating the + match with a destructuring [let] that hasn't been through + [vm_compute], and therefore has name information that + constr_matching is happy with. *) +Ltac replace_match_with_destructuring_match T := + match T with + | ?F ?X + => let F' := replace_match_with_destructuring_match F in + let X' := replace_match_with_destructuring_match X in + constr:(F' X') + (* we must use [@?f a b] here and not [?f], or else we get an anomaly *) + | match ?d with pair a b => @?f a b end + => let d' := replace_match_with_destructuring_match d in + let T' := fresh in + constr:(let '(a, b) := d' in + match f a b with + | T' => ltac:(let v := (eval cbv beta delta [T'] in T') in + let v := replace_match_with_destructuring_match v in + exact v) + end) + | ?x => x + end. +Ltac do_replace_match_with_destructuring_match_in_goal := + let G := get_goal in + let G' := replace_match_with_destructuring_match G in + change G'. + +(* TODO : move *) +Lemma F_of_Z_opp {m} x : F.of_Z m (- x) = F.opp (F.of_Z m x). 
+Proof. + cbv [F.opp]; intros. rewrite F.to_Z_of_Z, <-Z.sub_0_l. + etransitivity; rewrite F.of_Z_mod; + [rewrite Z.opp_mod_mod|]; reflexivity. +Qed. + +Hint Rewrite <-@F.of_Z_add : pull_FofZ. +Hint Rewrite <-@F.of_Z_mul : pull_FofZ. +Hint Rewrite <-@F.of_Z_sub : pull_FofZ. +Hint Rewrite <-@F_of_Z_opp : pull_FofZ. + +Ltac F_mod_eq := + cbv [Positional.Fdecode]; autorewrite with pull_FofZ; + apply mod_eq_Z2F_iff. + +Ltac solve_op_mod_eq wt x := + transitivity (Positional.eval wt x); repeat autounfold; + [|autorewrite with uncps push_id push_basesystem_eval; + reflexivity]; + cbv [mod_eq]; apply f_equal2; [|reflexivity]; + apply f_equal; + basesystem_partial_evaluation_RHS; + do_replace_match_with_destructuring_match_in_goal. + +Ltac solve_op_F wt x := F_mod_eq; solve_op_mod_eq wt x. diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v new file mode 100644 index 000000000..0f20bb238 --- /dev/null +++ b/src/Arithmetic/Karatsuba.v @@ -0,0 +1,49 @@ +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Algebra.Nsatz. +Require Import Crypto.Util.ZUtil. +Local Open Scope Z_scope. + +Section Karatsuba. + Context {T : Type} (eval : T -> Z) + (sub : T -> T -> T) + (eval_sub : forall x y, eval (sub x y) = eval x - eval y) + (mul : T -> T -> T) + (eval_mul : forall x y, eval (mul x y) = eval x * eval y) + (add : T -> T -> T) + (eval_add : forall x y, eval (add x y) = eval x + eval y) + (scmul : Z -> T -> T) + (eval_scmul : forall c x, eval (scmul c x) = c * eval x) + (split : Z -> T -> T * T) + (eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval x) + . + + Definition karatsuba_mul s (x y : T) : T := + let xab := split s x in + let yab := split s y in + let xy0 := mul (fst xab) (fst yab) in + let xy2 := mul (snd xab) (snd yab) in + let xy1 := sub (mul (add (fst xab) (snd xab)) (add (fst yab) (snd yab))) (add xy2 xy0) in + add (add (scmul (s^2) xy2) (scmul s xy1)) xy0. 
+ + Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) : + eval (karatsuba_mul s x y) = eval x * eval y. + Proof using Type*. cbv [karatsuba_mul]; repeat rewrite ?eval_sub, ?eval_mul, ?eval_add, ?eval_scmul. + rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. Qed. + + + Definition goldilocks_mul s (xs ys : T) : T := + let a_b := split s xs in + let c_d := split s ys in + let ac := mul (fst a_b) (fst c_d) in + (add (add ac (mul (snd a_b) (snd c_d))) + (scmul s (sub (mul (add (fst a_b) (snd a_b)) (add (fst c_d) (snd c_d))) ac))). + + Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper Z.modulo_equiv_modulo_Proper. + + Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys : + (eval (goldilocks_mul s xs ys)) mod p = (eval xs * eval ys) mod p. + Proof using Type*. cbv [goldilocks_mul]; Zmod_to_equiv_modulo. + repeat rewrite ?eval_mul, ?eval_add, ?eval_sub, ?eval_scmul, <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify. + setoid_rewrite s2_modp. + apply f_equal2; nsatz. Qed. +End Karatsuba. diff --git a/src/Arithmetic/ModularArithmeticPre.v b/src/Arithmetic/ModularArithmeticPre.v new file mode 100644 index 000000000..b27ffd16d --- /dev/null +++ b/src/Arithmetic/ModularArithmeticPre.v @@ -0,0 +1,139 @@ +Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.Numbers.BinNums Coq.ZArith.Zdiv Coq.ZArith.Znumtheory. +Require Import Coq.Logic.Eqdep_dec. +Require Import Coq.Logic.EqdepFacts. +Require Import Coq.omega.Omega. +Require Import Crypto.Util.NumTheoryUtil. +Require Export Crypto.Util.FixCoqMistakes. + +Lemma Z_mod_mod x m : x mod m = (x mod m) mod m. + symmetry. + destruct (BinInt.Z.eq_dec m 0). + - subst; rewrite !Zdiv.Zmod_0_r; reflexivity. + - apply BinInt.Z.mod_mod; assumption. +Qed. 
+ +Lemma exist_reduced_eq: forall (m : Z) (a b : Z), a = b -> forall pfa pfb, + exist (fun z : Z => z = z mod m) a pfa = + exist (fun z : Z => z = z mod m) b pfb. +Proof. + intuition; simpl in *; try congruence. + subst. + f_equal. + eapply UIP_dec, Z.eq_dec. +Qed. + +Definition mulmod m := fun a b => a * b mod m. +Definition powmod_pos m := Pos.iter_op (mulmod m). +Definition powmod m a x := match x with N0 => 1 mod m | Npos p => powmod_pos m p (a mod m) end. + +Lemma mulmod_assoc: + forall m x y z : Z, mulmod m x (mulmod m y z) = mulmod m (mulmod m x y) z. +Proof. + unfold mulmod; intros. + rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal. + apply Z.mul_assoc. +Qed. + +Lemma powmod_1plus: + forall m a : Z, + forall x : N, powmod m a (1 + x) = (a * (powmod m a x mod m)) mod m. +Proof. + intros m a x. + rewrite N.add_1_l. + cbv beta delta [powmod N.succ]. + destruct x. simpl; rewrite ?Zdiv.Zmult_mod_idemp_r, Z.mul_1_r; auto. + unfold powmod_pos. + rewrite Pos.iter_op_succ by (apply mulmod_assoc). + unfold mulmod. + rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal. +Qed. + + +Lemma N_pos_1plus : forall p, (N.pos p = 1 + (N.pred (N.pos p)))%N. + intros. + rewrite <-N.pos_pred_spec. + rewrite N.add_1_l. + rewrite N.pos_pred_spec. + rewrite N.succ_pred; eauto. + discriminate. +Qed. + +Lemma powmod_Zpow_mod : forall m a n, powmod m a n = (a^Z.of_N n) mod m. +Proof. + induction n using N.peano_ind; [auto|]. + rewrite <-N.add_1_l. + rewrite powmod_1plus. + rewrite IHn. + rewrite Zmod_mod. + rewrite N.add_1_l. + rewrite N2Z.inj_succ. + rewrite Z.pow_succ_r by (apply N2Z.is_nonneg). + rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal. +Qed. + +Local Obligation Tactic := idtac. + +Program Definition pow_impl_sig {m} (a:{z : Z | z = z mod m}) (x:N) : {z : Z | z = z mod m} + := powmod m (proj1_sig a) x. +Next Obligation. + intros; destruct x; [simpl; rewrite Zmod_mod; reflexivity|]. + rewrite N_pos_1plus. 
+ rewrite powmod_1plus. + rewrite Zmod_mod; reflexivity. +Qed. + +Program Definition pow_impl {m} : + {pow0 + : {z : BinNums.Z | z = z mod m} -> BinNums.N -> {z : BinNums.Z | z = z mod m} + | + forall a : {z : BinNums.Z | z = z mod m}, + pow0 a 0%N = + exist (fun z : BinNums.Z => z = z mod m) (1 mod m) (Z_mod_mod 1 m) /\ + (forall x : BinNums.N, + pow0 a (1 + x)%N = + exist (fun z : BinNums.Z => z = z mod m) + ((proj1_sig a * proj1_sig (pow0 a x)) mod m) + (Z_mod_mod (proj1_sig a * proj1_sig (pow0 a x)) m))} := pow_impl_sig. +Next Obligation. + split; intros; apply exist_reduced_eq; + rewrite ?powmod_1plus, ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; reflexivity. +Qed. + +Program Definition mod_inv_sig {m} (a:{z : Z | z = z mod m}) : {z : Z | z = z mod m} := + let (a, _) := a in + match a return _ with + | 0%Z => 0 (* m = 2 *) + | _ => powmod m a (Z.to_N (m-2)) + end. +Next Obligation. + intros; break_match; rewrite ?powmod_Zpow_mod, ?Zmod_mod, ?Zmod_0_l; reflexivity. +Qed. + +Program Definition inv_impl {m : BinNums.Z} : + {inv0 : {z : BinNums.Z | z = z mod m} -> {z : BinNums.Z | z = z mod m} | + inv0 (exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m)) = + exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m) /\ + (Znumtheory.prime m -> + forall a : {z : BinNums.Z | z = z mod m}, + a <> exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m) -> + exist (fun z : BinNums.Z => z = z mod m) + ((proj1_sig (inv0 a) * proj1_sig a) mod m) + (Z_mod_mod (proj1_sig (inv0 a) * proj1_sig a) m) = + exist (fun z : BinNums.Z => z = z mod m) (1 mod m) (Z_mod_mod 1 m))} + := mod_inv_sig. +Next Obligation. + split. + { apply exist_reduced_eq; rewrite Zmod_0_l; reflexivity. } + intros Hm [a pfa] Ha'. apply exist_reduced_eq. + assert (Hm':0 <= m - 2) by (pose proof prime_ge_2 m Hm; omega). + assert (Ha:a mod m<>0) by (intro; apply Ha', exist_reduced_eq; congruence). + cbv [proj1_sig mod_inv_sig]. 
+ transitivity ((a*powmod m a (Z.to_N (m - 2))) mod m); [destruct a; f_equal; ring|]. + rewrite !powmod_Zpow_mod. + rewrite Z2N.id by assumption. + rewrite Zmult_mod_idemp_r. + rewrite <-Z.pow_succ_r by assumption. + replace (Z.succ (m - 2)) with (m-1) by omega. + rewrite (Zmod_small 1) by omega. + apply (fermat_little m Hm a Ha). +Qed.
\ No newline at end of file diff --git a/src/Arithmetic/ModularArithmeticTheorems.v b/src/Arithmetic/ModularArithmeticTheorems.v new file mode 100644 index 000000000..990aa9dc8 --- /dev/null +++ b/src/Arithmetic/ModularArithmeticTheorems.v @@ -0,0 +1,347 @@ +Require Import Coq.omega.Omega. +Require Import Crypto.Spec.ModularArithmetic. +Require Import Crypto.Arithmetic.ModularArithmeticPre. + +Require Import Coq.ZArith.BinInt Coq.ZArith.Zdiv Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *) +Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid. +Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Ring_tac. + +Require Import Crypto.Algebra.Hierarchy Crypto.Algebra.Ring Crypto.Algebra.Field. +Require Import Crypto.Util.Decidable Crypto.Util.ZUtil. +Require Export Crypto.Util.FixCoqMistakes. + +Module F. + Ltac unwrap_F := + intros; + repeat match goal with [ x : F _ |- _ ] => destruct x end; + lazy iota beta delta [F.add F.sub F.mul F.opp F.to_Z F.of_Z proj1_sig] in *; + try apply eqsig_eq; + pull_Zmod. + + (* FIXME: remove the pose proof once [monoid] no longer contains decidable equality *) + Global Instance eq_dec {m} : DecidableRel (@eq (F m)). pose proof dec_eq_Z. exact _. Defined. + + Global Instance commutative_ring_modulo m + : @Algebra.Hierarchy.commutative_ring (F m) Logic.eq 0%F 1%F F.opp F.add F.sub F.mul. + Proof. + repeat (split || intro); unwrap_F; + autorewrite with zsimplify; solve [ exact _ | auto with zarith | congruence]. + Qed. + + Lemma pow_spec {m} a : F.pow a 0%N = 1%F :> F m /\ forall x, F.pow a (1 + x)%N = F.mul a (F.pow a x). + Proof. change (@F.pow m) with (proj1_sig (@F.pow_with_spec m)); destruct (@F.pow_with_spec m); eauto. Qed. + + Global Instance char_gt {m} : + @Ring.char_ge + (F m) Logic.eq F.zero F.one F.opp F.add F.sub F.mul + m. + Proof. + Admitted. + + Section FandZ. + Context {m:positive}. + Local Open Scope F_scope. + + Theorem eq_to_Z_iff (x y : F m) : x = y <-> F.to_Z x = F.to_Z y. 
+ Proof using Type. destruct x, y; intuition; simpl in *; try apply (eqsig_eq _ _); congruence. Qed. + + Lemma eq_of_Z_iff : forall x y : Z, x mod m = y mod m <-> F.of_Z m x = F.of_Z m y. + Proof using Type. split; unwrap_F; congruence. Qed. + + + Lemma to_Z_of_Z : forall z, F.to_Z (F.of_Z m z) = z mod m. + Proof using Type. unwrap_F; trivial. Qed. + + Lemma of_Z_to_Z x : F.of_Z m (F.to_Z x) = x :> F m. + Proof using Type. unwrap_F; congruence. Qed. + + + Lemma of_Z_mod : forall x, F.of_Z m x = F.of_Z m (x mod m). + Proof using Type. unwrap_F; trivial. Qed. + + Lemma mod_to_Z : forall (x:F m), F.to_Z x mod m = F.to_Z x. + Proof using Type. unwrap_F. congruence. Qed. + + Lemma to_Z_0 : F.to_Z (0:F m) = 0%Z. + Proof using Type. unwrap_F. apply Zmod_0_l. Qed. + + Lemma of_Z_small_nonzero z : (0 < z < m)%Z -> F.of_Z m z <> 0. + Proof using Type. intros Hrange Hnz. inversion Hnz. rewrite Zmod_small, Zmod_0_l in *; omega. Qed. + + Lemma to_Z_nonzero (x:F m) : x <> 0 -> F.to_Z x <> 0%Z. + Proof using Type. intros Hnz Hz. rewrite <- Hz, of_Z_to_Z in Hnz; auto. Qed. + + Lemma to_Z_range (x : F m) : 0 < m -> 0 <= F.to_Z x < m. + Proof using Type. intros. rewrite <- mod_to_Z. apply Z.mod_pos_bound. trivial. Qed. + + Lemma to_Z_nonzero_range (x : F m) : (x <> 0) -> 0 < m -> (1 <= F.to_Z x < m)%Z. + Proof using Type. + unfold not; intros Hnz Hlt. + rewrite eq_to_Z_iff, to_Z_0 in Hnz; pose proof (to_Z_range x Hlt). + omega. + Qed. + + Lemma of_Z_add : forall (x y : Z), + F.of_Z m (x + y) = F.of_Z m x + F.of_Z m y. + Proof using Type. unwrap_F; trivial. Qed. + + Lemma to_Z_add : forall x y : F m, + F.to_Z (x + y) = ((F.to_Z x + F.to_Z y) mod m)%Z. + Proof using Type. unwrap_F; trivial. Qed. + + Lemma of_Z_mul x y : F.of_Z m (x * y) = F.of_Z _ x * F.of_Z _ y :> F m. + Proof using Type. unwrap_F. trivial. Qed. + + Lemma to_Z_mul : forall x y : F m, + F.to_Z (x * y) = ((F.to_Z x * F.to_Z y) mod m)%Z. + Proof using Type. unwrap_F; trivial. Qed. 
+
+    (** [F.of_Z] commutes with subtraction. *)
+    Lemma of_Z_sub x y : F.of_Z _ (x - y) = F.of_Z _ x - F.of_Z _ y :> F m.
+    Proof using Type. unwrap_F. trivial. Qed.
+
+    (** The integer representative of [F.opp x] is [(- F.to_Z x) mod m]. *)
+    Lemma to_Z_opp : forall x : F m, F.to_Z (F.opp x) = (- F.to_Z x) mod m.
+    Proof using Type. unwrap_F; trivial. Qed.
+
+    (** Raising [F.of_Z _ x] to the power [n] agrees with injecting the
+        integer [x ^ (Z.of_N n) mod m].  The proof is a Peano induction on
+        [n] driven by the two equations provided by [pow_spec]. *)
+    Lemma of_Z_pow x n : F.of_Z _ x ^ n = F.of_Z _ (x ^ (Z.of_N n) mod m) :> F m.
+    Proof using Type.
+      intros.
+      induction n using N.peano_ind;
+        destruct (pow_spec (F.of_Z m x)) as [pow_0 pow_succ] . {
+        rewrite pow_0.
+        unwrap_F; trivial.
+      } {
+        rewrite N2Z.inj_succ.
+        rewrite Z.pow_succ_r by apply N2Z.is_nonneg.
+        rewrite <- N.add_1_l.
+        rewrite pow_succ.
+        rewrite IHn.
+        unwrap_F; trivial.
+      }
+    Qed.
+
+    (** The integer representative of [x ^ n] is [F.to_Z x ^ Z.of_N n mod m].
+        Same induction scheme as [of_Z_pow]; the successor case finishes
+        with [to_Z_mul]. *)
+    Lemma to_Z_pow : forall (x : F m) n,
+        F.to_Z (x ^ n)%F = (F.to_Z x ^ Z.of_N n mod m)%Z.
+    Proof using Type.
+      intros.
+      symmetry.
+      induction n using N.peano_ind;
+        destruct (pow_spec x) as [pow_0 pow_succ] . {
+        rewrite pow_0, Z.pow_0_r; auto.
+      } {
+        rewrite N2Z.inj_succ.
+        rewrite Z.pow_succ_r by apply N2Z.is_nonneg.
+        rewrite <- N.add_1_l.
+        rewrite pow_succ.
+        rewrite <- Zmult_mod_idemp_r.
+        rewrite IHn.
+        apply to_Z_mul.
+      }
+    Qed.
+
+    (** [x] is a square in [F m] exactly when some integer [y] satisfies
+        [y * y mod m = F.to_Z x]. *)
+    Lemma square_iff (x:F m) :
+      (exists y : F m, y * y = x) <-> (exists y : Z, y * y mod m = F.to_Z x)%Z.
+    Proof using Type.
+      setoid_rewrite eq_to_Z_iff; setoid_rewrite to_Z_mul; split; intro H; destruct H as [x' H].
+      - eauto.
+      - exists (F.of_Z _ x'); rewrite !to_Z_of_Z; pull_Zmod; auto.
+    Qed.
+  End FandZ.
+
+  (** Lemmas relating [F m] to [nat] through [F.to_nat] and [F.of_nat]. *)
+  Section FandNat.
+    Import NPeano Nat.
+    Local Infix "mod" := modulo : nat_scope.
+    Local Open Scope nat_scope.
+
+    Context {m:BinPos.positive}.
+
+    (** Round-tripping a natural number through [F m] reduces it modulo
+        [Z.to_nat m]. *)
+    Lemma to_nat_of_nat (n:nat) : F.to_nat (F.of_nat m n) = (n mod (Z.to_nat m))%nat.
+    Proof using Type.
+      unfold F.to_nat, F.of_nat.
+      rewrite F.to_Z_of_Z.
+      (* [mod_Zmod] needs the modulus, as a nat, to be nonzero. *)
+      assert (Pos.to_nat m <> 0)%nat as HA by (pose proof Pos2Nat.is_pos m; omega).
+      pose proof (mod_Zmod n (Pos.to_nat m) HA) as Hmod.
+      rewrite positive_nat_Z in Hmod.
+      rewrite <- Hmod.
+      rewrite <-Nat2Z.id, Z2Nat.inj_pos; omega.
+    Qed.
+
+    (** Converting an element to [nat] and back is the identity. *)
+    Lemma of_nat_to_nat x : F.of_nat m (F.to_nat x) = x.
+    Proof using Type.
+
+      unfold F.to_nat, F.of_nat.
+      rewrite Z2Nat.id; [ eapply F.of_Z_to_Z | eapply F.to_Z_range; reflexivity].
+    Qed.
+
+    (** A positive number never converts to the natural number zero.
+        This was previously [Admitted]; it follows directly from
+        [Pos2Nat.is_pos], using the same [pose proof ...; omega] pattern
+        that [to_nat_of_nat] uses in this section. *)
+    Lemma Pos_to_nat_nonzero p : Pos.to_nat p <> 0%nat.
+    Proof using Type. pose proof (Pos2Nat.is_pos p); omega. Qed.
+
+    (** Reducing a natural number modulo [Z.to_nat m] before injecting it
+        does not change the resulting element of [F m]. *)
+    Lemma of_nat_mod (n:nat) : F.of_nat m (n mod (Z.to_nat m)) = F.of_nat m n.
+    Proof using Type.
+      unfold F.of_nat.
+      rewrite (F.of_Z_mod (Z.of_nat n)), ?mod_Zmod, ?Z2Nat.id; [reflexivity|..].
+      { apply Pos2Z.is_nonneg. }
+      { rewrite Z2Nat.inj_pos. apply Pos_to_nat_nonzero. }
+    Qed.
+
+    (** The [nat] representative of [x] is already reduced modulo
+        [Z.to_nat m]. *)
+    Lemma to_nat_mod (x:F m) (Hm:(0 < m)%Z) : F.to_nat x mod (Z.to_nat m) = F.to_nat x.
+    Proof using Type.
+
+      unfold F.to_nat.
+      rewrite <-F.mod_to_Z at 2.
+      apply Z.mod_to_nat; [assumption|].
+      apply F.to_Z_range; assumption.
+    Qed.
+
+    (** [F.of_nat] commutes with addition. *)
+    Lemma of_nat_add x y :
+      F.of_nat m (x + y) = (F.of_nat m x + F.of_nat m y)%F.
+    Proof using Type. unfold F.of_nat; rewrite Nat2Z.inj_add, F.of_Z_add; reflexivity. Qed.
+
+    (** [F.of_nat] commutes with multiplication. *)
+    Lemma of_nat_mul x y :
+      F.of_nat m (x * y) = (F.of_nat m x * F.of_nat m y)%F.
+    Proof using Type. unfold F.of_nat; rewrite Nat2Z.inj_mul, F.of_Z_mul; reflexivity. Qed.
+  End FandNat.
+
+  (** Theories in the shape consumed by [Add Ring] / [Add Field]. *)
+  Section RingTacticGadgets.
+    Context (m:positive).
+
+    (** Standard-library [ring_theory] for [F m], obtained from
+        [Algebra.Ring.ring_theory_for_stdlib_tactic]. *)
+    Definition ring_theory : ring_theory 0%F 1%F (@F.add m) (@F.mul m) (@F.sub m) (@F.opp m) eq
+      := Algebra.Ring.ring_theory_for_stdlib_tactic.
+
+    (** [F.pow] agrees with the standard-library [pow_N] built from
+        [F.mul] and [1].  After peeling off one factor with [pow_spec],
+        the positive case goes by Peano induction on the exponent. *)
+    Lemma pow_pow_N (x : F m) : forall (n : N), (x ^ id n)%F = pow_N 1%F F.mul x n.
+    Proof using Type.
+      destruct (pow_spec x) as [HO HS]; intros.
+      destruct n; auto; unfold id.
+      rewrite ModularArithmeticPre.N_pos_1plus at 1.
+      rewrite HS.
+      simpl.
+      induction p using Pos.peano_ind.
+      - simpl. rewrite HO. apply Algebra.Hierarchy.right_identity.
+      - rewrite (@pow_pos_succ (F m) (@F.mul m) eq _ _ associative x).
+        rewrite <-IHp, Pos.pred_N_succ, ModularArithmeticPre.N_pos_1plus, HS.
+        trivial.
+    Qed.
+
+    (** [power_theory] for [F.pow], with [id] as the exponent conversion. *)
+    Lemma power_theory : power_theory 1%F (@F.mul m) eq id (@F.pow m).
+    Proof using Type. split; apply pow_pow_N. Qed.
+
+    (***** Division Theory *****)
+    (** Quotient and remainder on [F m]: [Z.quotrem] is applied to the
+        integer representatives and both results are injected back with
+        [F.of_Z]. *)
+    Definition quotrem(a b: F m): F m * F m :=
+      let '(q, r) := (Z.quotrem (F.to_Z a) (F.to_Z b)) in (F.of_Z _ q , F.of_Z _ r).
+    (** [quotrem] satisfies [div_theory], with the identity as the
+        coefficient morphism. *)
+    Lemma div_theory : div_theory eq (@F.add m) (@F.mul m) (@id _) quotrem.
+    Proof using Type.
+      constructor; intros; unfold quotrem, id.
+
+      (* Expose the two components of [Z.quotrem] as [Z.quot] and [Z.rem]. *)
+      replace (Z.quotrem (F.to_Z a) (F.to_Z b)) with (Z.quot (F.to_Z a) (F.to_Z b), Z.rem (F.to_Z a) (F.to_Z b)) by
+        try (unfold Z.quot, Z.rem; rewrite <- surjective_pairing; trivial).
+
+      unwrap_F; rewrite <-Z.quot_rem'; trivial.
+    Qed.
+
+    (* Define a "ring morphism" between F and Z, i.e. an equivalence
+     * between 'inject (ZFunction (X))' and 'FFunction (inject (X))',
+     * where the injection is [F.of_Z m].
+     *
+     * Doing this allows the [ring] tactic to do coefficient
+     * manipulations in Z rather than F, because we know it's equivalent
+     * to inject the result afterward. *)
+    Lemma ring_morph: ring_morph 0%F 1%F F.add F.mul F.sub F.opp eq
+                                 0%Z 1%Z Z.add Z.mul Z.sub Z.opp Z.eqb (F.of_Z m).
+    Proof using Type. split; intros; unwrap_F; solve [ auto | rewrite (proj1 (Z.eqb_eq x y)); trivial]. Qed.
+
+    (* Redefine our division theory under the ring morphism *)
+    (** Same statement as [div_theory], but with coefficients in [Z],
+        [F.of_Z m] as the morphism and [Z.quotrem] as the division. *)
+    Lemma morph_div_theory:
+      Ring_theory.div_theory eq Zplus Zmult (F.of_Z m) Z.quotrem.
+    Proof using Type.
+      split; intros.
+      replace (Z.quotrem a b) with (Z.quot a b, Z.rem a b);
+        try (unfold Z.quot, Z.rem; rewrite <- surjective_pairing; trivial).
+      unwrap_F; rewrite <- (Z.quot_rem' a b); trivial.
+    Qed.
+
+  End RingTacticGadgets.
+
+  (** Constant recognizer passed to [Add Ring]: a term [F.of_Z _ x] yields
+      [x], and anything else yields [NotConstant]. *)
+  Ltac is_constant t := match t with F.of_Z _ ?x => x | _ => NotConstant end.
+  (** Exponent-constant recognizer passed to [power_tac]; delegates to [Ncst]. *)
+  Ltac is_pow_constant t := Ncst t.
+
+  Section VariousModulo.
+    Context {m:positive}.
+    Local Open Scope F_scope.
+
+    (* Register [F m] with the [ring] tactic, using the gadgets above. *)
+    Add Ring _theory : (ring_theory m)
+                         (morphism (ring_morph m),
+                          constants [is_constant],
+                          div (morph_div_theory m),
+                          power_tac (power_theory m) [is_pow_constant]).
+
+    (** If a product is nonzero, so is its left factor. *)
+    Lemma mul_nonzero_l : forall a b : F m, a*b <> 0 -> a <> 0.
+    Proof using Type. intros a b Hnz Hz. rewrite Hz in Hnz; apply Hnz; ring. Qed.
+
+    (** If a product is nonzero, so is its right factor. *)
+    Lemma mul_nonzero_r : forall a b : F m, a*b <> 0 -> b <> 0.
+    Proof using Type. intros a b Hnz Hz. rewrite Hz in Hnz; apply Hnz; ring. Qed.
+  End VariousModulo.
+
+  (** Algebraic laws of exponentiation in [F m]. *)
+  Section Pow.
+    Context {m:positive}.
+    (* Register [F m] with the [ring] tactic for this section. *)
+    Add Ring _theory' : (ring_theory m)
+                         (morphism (ring_morph m),
+                          constants [is_constant],
+                          div (morph_div_theory m),
+                          power_tac (power_theory m) [is_pow_constant]).
+    Local Open Scope F_scope.
+
+    Import Algebra.ScalarMult.
+    (** Exponentiation by [N.of_nat n] is an [is_scalarmult] instance in
+        which [F.mul] plays the role of addition and [1] the role of zero. *)
+    Global Instance pow_is_scalarmult
+      : is_scalarmult (G:=F m) (eq:=eq) (add:=F.mul) (zero:=1%F) (mul := fun n x => x ^ (N.of_nat n)).
+    Proof using Type.
+      split; intros; rewrite ?Nat2N.inj_succ, <-?N.add_1_l;
+        match goal with
+        | [x:F m |- _ ] => solve [destruct (@pow_spec m P); auto]
+        | |- Proper _ _ => solve_proper
+        end.
+    Qed.
+
+    (* TODO: move this somewhere? *)
+    (* Rewrite database that pulls [N.of_nat] outward through N operations
+       and re-expresses the literals 0..3 as [N.of_nat _]. *)
+    Create HintDb nat2N discriminated.
+    Hint Rewrite Nat2N.inj_iff
+         (eq_refl _ : (0%N = N.of_nat 0))
+         (eq_refl _ : (1%N = N.of_nat 1))
+         (eq_refl _ : (2%N = N.of_nat 2))
+         (eq_refl _ : (3%N = N.of_nat 3))
+      : nat2N.
+    Hint Rewrite <- Nat2N.inj_double Nat2N.inj_succ_double Nat2N.inj_succ
+         Nat2N.inj_add Nat2N.inj_mul Nat2N.inj_sub Nat2N.inj_pred
+         Nat2N.inj_div2 Nat2N.inj_max Nat2N.inj_min Nat2N.id
+      : nat2N.
+
+    (** Normalizes goals that mention [^]: an exponent [n : N] is replaced
+        by [N.of_nat] of a fresh [nat] (via [N2Nat.id]), and powers of the
+        form [_ ^ N.of_nat n] are then rewritten with [scalarmult_ext].
+        NOTE(review): presumably this turns them into the reference scalar
+        multiplication of [Algebra.ScalarMult] -- confirm against that file. *)
+    Ltac pow_to_scalarmult_ref :=
+      repeat (autorewrite with nat2N;
+              match goal with
+              | |- context [ (_^?n)%F ] =>
+                rewrite <-(N2Nat.id n); generalize (N.to_nat n); clear n;
+                let m := fresh n in intro m
+              | |- context [ (_^N.of_nat ?n)%F ] =>
+                let rw := constr:(scalarmult_ext(zero:=F.of_Z m 1) n) in
+                setoid_rewrite rw (* rewriting modulo a reduction *)
+              end).
+
+    (** Anything to the power zero is one. *)
+    Lemma pow_0_r (x:F m) : x^0 = 1.
+    Proof using Type. pow_to_scalarmult_ref. apply scalarmult_0_l. Qed.
+
+    (** Exponents add under multiplication. *)
+    Lemma pow_add_r (x:F m) (a b:N) : x^(a+b) = x^a * x^b.
+    Proof using Type. pow_to_scalarmult_ref; apply scalarmult_add_l. Qed.
+
+    (** Zero to a nonzero power is zero. *)
+    Lemma pow_0_l (n:N) : n <> 0%N -> 0^n = 0 :> F m.
+    Proof using Type. pow_to_scalarmult_ref; destruct n; simpl; intros; [congruence|ring]. Qed.
+
+    (** Iterated exponentiation multiplies the exponents. *)
+    Lemma pow_pow_l (x:F m) (a b:N) : (x^a)^b = x^(a*b).
+    Proof using Type. pow_to_scalarmult_ref. apply scalarmult_assoc. Qed.
+
+    (** Unfoldings of [^] for the literal exponents 1, 2 and 3. *)
+    Lemma pow_1_r (x:F m) : x^1 = x.
+    Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+
+    Lemma pow_2_r (x:F m) : x^2 = x*x.
+    Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+
+    Lemma pow_3_r (x:F m) : x^3 = x*x*x.
+    Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+  End Pow.
+End F.
diff --git a/src/Arithmetic/MontgomeryReduction/Definition.v b/src/Arithmetic/MontgomeryReduction/Definition.v
new file mode 100644
index 000000000..78d3c037f
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/Definition.v
@@ -0,0 +1,179 @@
+(*** Montgomery Multiplication *)
+(** This file implements Montgomery Form, Montgomery Reduction, and
+    Montgomery Multiplication on [Z].  We follow Wikipedia. *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope Z_scope.
+Delimit Scope montgomery_naive_scope with montgomery_naive.
+Delimit Scope montgomery_scope with montgomery.
+(** Integers in Montgomery form are plain [Z]; the alias exists so that
+    [montgomery_scope] can be bound to it. *)
+Definition montgomeryZ := Z.
+Bind Scope montgomery_scope with montgomeryZ.
+
+Section montgomery.
+  Context (N : Z)
+          (R : Z)
+          (R' : Z). (* R' is R⁻¹ mod N *)
+  Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
+  Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
+  (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Montgomery_modular_multiplication>: *)
+  (** In modular arithmetic computation, Montgomery modular
+      multiplication, more commonly referred to as Montgomery
+      multiplication, is a method for performing fast modular
+      multiplication, introduced in 1985 by the American mathematician
+      Peter L. Montgomery. *)
+  (** Given two integers [a] and [b], the classical modular
+      multiplication algorithm computes [ab mod N]. Montgomery
+      multiplication works by transforming [a] and [b] into a special
+      representation known as Montgomery form. For a modulus [N], the
+      Montgomery form of [a] is defined to be [aR mod N] for some
+      constant [R] depending only on [N] and the underlying computer
+      architecture. If [aR mod N] and [bR mod N] are the Montgomery
+      forms of [a] and [b], then their Montgomery product is [abR mod
+      N]. Montgomery multiplication is a fast algorithm to compute the
+      Montgomery product. Transforming the result out of Montgomery
+      form yields the classical modular product [ab mod N]. *)
+
+  (** Naive conversions: multiply by [R] to enter Montgomery form and by
+      [R'] to leave it.  Neither performs any reduction modulo [N]. *)
+  Definition to_montgomery_naive (x : Z) : montgomeryZ := x * R.
+  Definition from_montgomery_naive (x : montgomeryZ) : Z := x * R'.
+
+  (** * Modular arithmetic and Montgomery form *)
+  Section general.
+    (** Addition and subtraction in Montgomery form are the same as
+        ordinary modular addition and subtraction because of the
+        distributive law: *)
+    (** [aR + bR = (a+b)R], *)
+    (** [aR - bR = (a-b)R]. *)
+    (** This is a consequence of the fact that, because [gcd(R, N) =
+        1], multiplication by [R] is an isomorphism on the additive
+        group [ℤ/Nℤ]. *)
+
+    (** Plain integer addition and subtraction; no reduction is performed. *)
+    Definition add : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR + bR.
+    Definition sub : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR - bR.
+
+    (** Multiplication in Montgomery form, however, is seemingly more
+        complicated. The usual product of [aR] and [bR] does not
+        represent the product of [a] and [b] because it has an extra
+        factor of R: *)
+    (** [(aR mod N)(bR mod N) mod N = (abR)R mod N]. *)
+    (** Computing products in Montgomery form requires removing the
+        extra factor of [R]. While division by [R] is cheap, the
+        intermediate product [(aR mod N)(bR mod N)] is not divisible
+        by [R] because the modulo operation has destroyed that
+        property. *)
+    (** Removing the extra factor of R can be done by multiplying by
+        an integer [R′] such that [RR' ≡ 1 (mod N)], that is, by an
+        [R′] whose residue class is the modular inverse of [R] mod
+        [N]. Then, working modulo [N], *)
+    (** [(aR mod N)(bR mod N)R' ≡ (aR)(bR)R⁻¹ ≡ (ab)R (mod N)]. *)
+
+    (** Naive Montgomery product: the integer product times [R'], with no
+        reduction modulo [N]. *)
+    Definition mul_naive : montgomeryZ -> montgomeryZ -> montgomeryZ
+      := fun aR bR => aR * bR * R'.
+  End general.
+
+  (** * The REDC algorithm *)
+  Section redc.
+    (** While the above algorithm is correct, it is slower than
+        multiplication in the standard representation because of the
+        need to multiply by [R′] and divide by [N]. Montgomery
+        reduction, also known as REDC, is an algorithm that
+        simultaneously computes the product by [R′] and reduces modulo
+        [N] more quickly than the naive method. The speed is because
+        all computations are done using only reduction and divisions
+        with respect to [R], not [N]: *)
+    (**
+<<
+function REDC is
+    input: Integers R and N with gcd(R, N) = 1,
+           Integer N′ in [0, R − 1] such that NN′ ≡ −1 mod R,
+           Integer T in the range [0, RN − 1]
+    output: Integer S in the range [0, N − 1] such that S ≡ TR⁻¹ mod N
+
+    m ← ((T mod R)N′) mod R
+    t ← (T + mN) / R
+    if t ≥ N then
+        return t − N
+    else
+        return t
+    end if
+end function
+>> *)
+    Context (N' : Z). (* N' is (-N⁻¹) mod R *)
+    Section redc.
+      Context (T : Z).
+
+      (* [m] and [t] are the two intermediate values of the pseudocode above. *)
+      Let m := ((T mod R) * N') mod R.
+      Let t := (T + m * N) / R.
+      (** The value [t] before any conditional subtraction. *)
+      Definition prereduce : montgomeryZ := t.
+
+      (** Subtracts [N] only when [t] has reached [R] (not [N]). *)
+      Definition partial_reduce : montgomeryZ
+        := if R <=? t then
+             prereduce - N
+           else
+             prereduce.
+
+      (** Variant of [partial_reduce] that first reduces [T + m * N]
+          modulo [R * R] and tests that sum against [R * R]; the
+          subtracting branch is additionally reduced modulo [R]. *)
+      Definition partial_reduce_alt : montgomeryZ
+        := let v0 := (T + m * N) in
+           let v := (v0 mod (R * R)) / R in
+           if R * R <=? v0 then
+             (v - N) mod R
+           else
+             v.
+
+      (** The final step of REDC as in the pseudocode: subtract [N] when
+          [t] has reached [N]. *)
+      Definition reduce : montgomeryZ
+        := if N <=? t then
+             prereduce - N
+           else
+             prereduce.
+
+      (** Applies the compare-with-[N]-and-subtract step to the result of
+          [partial_reduce]. *)
+      Definition reduce_via_partial : montgomeryZ
+        := if N <=? partial_reduce then
+             partial_reduce - N
+           else
+             partial_reduce.
+
+      (** NOTE(review): this body is identical to [reduce_via_partial]; it
+          uses [partial_reduce], not [partial_reduce_alt].  Judging by the
+          name it was presumably meant to go through [partial_reduce_alt]
+          -- confirm before relying on it being a distinct definition. *)
+      Definition reduce_via_partial_alt : montgomeryZ
+        := if N <=? partial_reduce then
+             partial_reduce - N
+           else
+             partial_reduce.
+    End redc.
+
+    (** * Arithmetic in Montgomery form *)
+    Section arithmetic.
+      (** Many operations of interest modulo [N] can be expressed
+          equally well in Montgomery form. Addition, subtraction,
+          negation, comparison for equality, multiplication by an
+          integer not in Montgomery form, and greatest common divisors
+          with [N] may all be done with the standard algorithms. *)
+      (** When [R > N], most other arithmetic operations can be
+          expressed in terms of REDC. This assumption implies that the
+          product of two representatives [mod N] is less than [RN],
+          the exact hypothesis necessary for REDC to generate correct
+          output. In particular, the product of [aR mod N] and [bR mod
+          N] is [REDC((aR mod N)(bR mod N))]. The combined operation
+          of multiplication and REDC is often called Montgomery
+          multiplication. *)
+      (** Montgomery product: [reduce] applied to the integer product. *)
+      Definition mul : montgomeryZ -> montgomeryZ -> montgomeryZ
+        := fun aR bR => reduce (aR * bR).
+
+      (** Conversion into Montgomery form is done by computing
+          [REDC((a mod N)(R² mod N))]. Conversion out of Montgomery
+          form is done by computing [REDC(aR mod N)]. The modular
+          inverse of [aR mod N] is [REDC((aR mod N)⁻¹(R³ mod
+          N))]. Modular exponentiation can be done using
+          exponentiation by squaring by initializing the initial
+          product to the Montgomery representation of 1, that is, to
+          [R mod N], and by replacing the multiply and square steps by
+          Montgomery multiplies. *)
+      (* Unlike the quoted formulas, these do not reduce their argument
+         modulo [N] before calling [reduce]. *)
+      Definition to_montgomery (a : Z) : montgomeryZ := reduce (a * (R*R mod N)).
+      Definition from_montgomery (aR : montgomeryZ) : Z := reduce aR.
+    End arithmetic.
+  End redc.
+End montgomery.
+
+Infix "+" := add : montgomery_scope.
+Infix "-" := sub : montgomery_scope.
+Infix "*" := mul_naive : montgomery_naive_scope.
+Infix "*" := mul : montgomery_scope.
diff --git a/src/Arithmetic/MontgomeryReduction/Proofs.v b/src/Arithmetic/MontgomeryReduction/Proofs.v new file mode 100644 index 000000000..d5de00213 --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/Proofs.v @@ -0,0 +1,296 @@ +(*** Montgomery Multiplication *) +(** This file implements the proofs for Montgomery Form, Montgomery + Reduction, and Montgomery Multiplication on [Z]. We follow + Wikipedia. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.Structures.Equalities. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SimplifyRepeatedIfs. +Require Import Crypto.Util.Notations. + +Declare Module Nop : Nop. +Module Import ImportEquivModuloInstances := Z.EquivModuloInstances Nop. + +Local Existing Instance eq_Reflexive. (* speed up setoid_rewrite as per https://coq.inria.fr/bugs/show_bug.cgi?id=4978 *) + +Local Open Scope Z_scope. + +Section montgomery. + Context (N : Z) + (N_reasonable : N <> 0) + (R : Z) + (R_good : Z.gcd N R = 1). + Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope. + Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope. + Context (R' : Z) + (R'_good : R * R' ≡ 1). + + Lemma R'_good' : R' * R ≡ 1. + Proof using R'_good. rewrite <- R'_good; apply f_equal2; lia. Qed. + + Local Notation to_montgomery_naive := (to_montgomery_naive R) (only parsing). + Local Notation from_montgomery_naive := (from_montgomery_naive R') (only parsing). + + Lemma to_from_montgomery_naive x : to_montgomery_naive (from_montgomery_naive x) ≡ x. + Proof using R'_good. + unfold to_montgomery_naive, from_montgomery_naive. + rewrite <- Z.mul_assoc, R'_good'. + autorewrite with zsimplify; reflexivity. + Qed. + Lemma from_to_montgomery_naive x : from_montgomery_naive (to_montgomery_naive x) ≡ x. + Proof using R'_good. + unfold to_montgomery_naive, from_montgomery_naive. + rewrite <- Z.mul_assoc, R'_good. 
+ autorewrite with zsimplify; reflexivity. + Qed. + + (** * Modular arithmetic and Montgomery form *) + Section general. + Local Infix "+" := add : montgomery_scope. + Local Infix "-" := sub : montgomery_scope. + Local Infix "*" := (mul_naive R') : montgomery_scope. + + Lemma add_correct_naive x y : from_montgomery_naive (x + y) = from_montgomery_naive x + from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, add; lia. Qed. + Lemma add_correct_naive_to x y : to_montgomery_naive (x + y) = (to_montgomery_naive x + to_montgomery_naive y)%montgomery. + Proof using Type. unfold to_montgomery_naive, add; autorewrite with push_Zmul; reflexivity. Qed. + Lemma sub_correct_naive x y : from_montgomery_naive (x - y) = from_montgomery_naive x - from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, sub; lia. Qed. + Lemma sub_correct_naive_to x y : to_montgomery_naive (x - y) = (to_montgomery_naive x - to_montgomery_naive y)%montgomery. + Proof using Type. unfold to_montgomery_naive, sub; autorewrite with push_Zmul; reflexivity. Qed. + + Theorem mul_correct_naive x y : from_montgomery_naive (x * y) = from_montgomery_naive x * from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, mul_naive; lia. Qed. + Theorem mul_correct_naive_to x y : to_montgomery_naive (x * y) ≡ (to_montgomery_naive x * to_montgomery_naive y)%montgomery. + Proof using R'_good. + unfold to_montgomery_naive, mul_naive. + rewrite <- !Z.mul_assoc, R'_good. + autorewrite with zsimplify; apply (f_equal2 Z.modulo); lia. + Qed. + End general. + + (** * The REDC algorithm *) + Section redc. + Context (N' : Z) + (N'_in_range : 0 <= N' < R) + (N'_good : N * N' ≡ᵣ -1). + + Lemma N'_good' : N' * N ≡ᵣ -1. + Proof using N'_good. rewrite <- N'_good; apply f_equal2; lia. Qed. + + Lemma N'_good'_alt x : (((x mod R) * (N' mod R)) mod R) * (N mod R) ≡ᵣ x * -1. + Proof using N'_good. + rewrite <- N'_good', Z.mul_assoc. + unfold Z.equiv_modulo; push_Zmod. + reflexivity. 
+ Qed. + + Section redc. + Context (T : Z). + + Local Notation m := (((T mod R) * N') mod R). + Local Notation prereduce := (prereduce N R N'). + + Local Ltac t_fin_correct := + unfold Z.equiv_modulo; push_Zmod; autorewrite with zsimplify; reflexivity. + + Lemma prereduce_correct : prereduce T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + transitivity ((T + m * N) * R'). + { unfold prereduce. + autorewrite with zstrip_div; push_Zmod. + rewrite N'_good'_alt. + autorewrite with zsimplify pull_Zmod. + reflexivity. } + t_fin_correct. + Qed. + + Lemma reduce_correct : reduce N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold reduce. + break_match; rewrite prereduce_correct; t_fin_correct. + Qed. + + Lemma partial_reduce_correct : partial_reduce N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold partial_reduce. + break_match; rewrite prereduce_correct; t_fin_correct. + Qed. + + Lemma reduce_via_partial_correct : reduce_via_partial N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold reduce_via_partial. + break_match; rewrite partial_reduce_correct; t_fin_correct. + Qed. + + Let m_small : 0 <= m < R. Proof. auto with zarith. Qed. + + Section generic. + Lemma prereduce_in_range_gen B + : 0 <= N + -> 0 <= T <= R * B + -> 0 <= prereduce T < B + N. + Proof using N_reasonable m_small. unfold prereduce; auto with zarith nia. Qed. + End generic. + + Section N_very_small. + Context (N_very_small : 0 <= 4 * N < R). + + Lemma prereduce_in_range_very_small + : 0 <= T <= (2 * N - 1) * (2 * N - 1) + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable N_very_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed. + End N_very_small. + + Section N_small. + Context (N_small : 0 <= 2 * N < R). + + Lemma prereduce_in_range_small + : 0 <= T <= (2 * N - 1) * (N - 1) + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable N_small m_small. 
pose proof (prereduce_in_range_gen N); nia. Qed. + + Lemma prereduce_in_range_small_fully_reduced + : 0 <= T <= 2 * N + -> 0 <= prereduce T <= N. + Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen 1); nia. Qed. + End N_small. + + Section N_small_enough. + Context (N_small_enough : 0 <= N < R). + + Lemma prereduce_in_range_small_enough + : 0 <= T <= R * R + -> 0 <= prereduce T < R + N. + Proof using N_reasonable N_small_enough m_small. pose proof (prereduce_in_range_gen R); nia. Qed. + + Lemma reduce_in_range_R + : 0 <= T <= R * R + -> 0 <= reduce N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma partial_reduce_in_range_R + : 0 <= T <= R * R + -> 0 <= partial_reduce N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma reduce_via_partial_in_range_R + : 0 <= T <= R * R + -> 0 <= reduce_via_partial N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold reduce_via_partial, partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + End N_small_enough. + + Section unconstrained. + Lemma prereduce_in_range + : 0 <= T <= R * N + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable m_small. pose proof (prereduce_in_range_gen N); nia. Qed. + + Lemma reduce_in_range + : 0 <= T <= R * N + -> 0 <= reduce N R N' T < N. + Proof using N_reasonable m_small. + intro H; pose proof (prereduce_in_range H). + unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma partial_reduce_in_range + : 0 <= T <= R * N + -> Z.min 0 (R - N) <= partial_reduce N R N' T < 2 * N. + Proof using N_reasonable m_small. 
+ intro H; pose proof (prereduce_in_range H). + unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; + apply Z.min_case_strong; nia. + Qed. + + Lemma reduce_via_partial_in_range + : 0 <= T <= R * N + -> Z.min 0 (R - N) <= reduce_via_partial N R N' T < N. + Proof using N_reasonable m_small. + intro H; pose proof (partial_reduce_in_range H). + unfold reduce_via_partial in *; break_match; Z.ltb_to_lt; lia. + Qed. + End unconstrained. + + Section alt. + Context (N_in_range : 0 <= N < R) + (T_representable : 0 <= T < R * R). + Lemma partial_reduce_alt_eq : partial_reduce_alt N R N' T = partial_reduce N R N' T. + Proof using N_in_range N_reasonable T_representable m_small. + assert (0 <= T + m * N < 2 * (R * R)) by nia. + assert (0 <= T + m * N < R * (R + N)) by nia. + assert (0 <= (T + m * N) / R < R + N) by auto with zarith. + assert ((T + m * N) / R - N < R) by lia. + assert (R * R <= T + m * N -> R <= (T + m * N) / R) by auto with zarith. + assert (T + m * N < R * R -> (T + m * N) / R < R) by auto with zarith. + assert (H' : (T + m * N) mod (R * R) = if R * R <=? T + m * N then T + m * N - R * R else T + m * N) + by (break_match; Z.ltb_to_lt; autorewrite with zsimplify; lia). + unfold partial_reduce, partial_reduce_alt, prereduce. + rewrite H'; clear H'. + simplify_repeated_ifs. + set (m' := m) in *. + autorewrite with zsimplify; push_Zmod; autorewrite with zsimplify; pull_Zmod. + break_match; Z.ltb_to_lt; autorewrite with zsimplify; try reflexivity; lia. + Qed. + End alt. + End redc. + + (** * Arithmetic in Montgomery form *) + Section arithmetic. + Local Infix "*" := (mul N R N') : montgomery_scope. + + Local Notation to_montgomery := (to_montgomery N R N'). + Local Notation from_montgomery := (from_montgomery N R N'). + Lemma to_from_montgomery a : to_montgomery (from_montgomery a) ≡ a. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, from_montgomery. + transitivity ((a * 1) * 1); [ | apply f_equal2; lia ]. 
+ rewrite <- !R'_good, !reduce_correct. + unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + apply f_equal2; lia. + Qed. + Lemma from_to_montgomery a : from_montgomery (to_montgomery a) ≡ a. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, from_montgomery. + rewrite !reduce_correct. + transitivity (a * ((R * (R * R' mod N) * R') mod N)). + { unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + apply f_equal2; lia. } + { repeat first [ rewrite R'_good + | reflexivity + | push_Zmod; pull_Zmod; progress autorewrite with zsimplify + | progress unfold Z.equiv_modulo ]. } + Qed. + + Theorem mul_correct x y : from_montgomery (x * y) ≡ from_montgomery x * from_montgomery y. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold from_montgomery, mul. + rewrite !reduce_correct; apply f_equal2; lia. + Qed. + Theorem mul_correct_to x y : to_montgomery (x * y) ≡ (to_montgomery x * to_montgomery y)%montgomery. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, mul. + rewrite !reduce_correct. + transitivity (x * y * R * 1 * 1 * 1); + [ rewrite <- R'_good at 1 + | rewrite <- R'_good at 1 2 3 ]; + autorewrite with zsimplify; + unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + { apply f_equal2; lia. } + { apply f_equal2; lia. } + Qed. + End arithmetic. + End redc. +End montgomery. + +Module Import LocalizeEquivModuloInstances := Z.RemoveEquivModuloInstances Nop. diff --git a/src/Arithmetic/PrimeFieldTheorems.v b/src/Arithmetic/PrimeFieldTheorems.v new file mode 100644 index 000000000..c253752c5 --- /dev/null +++ b/src/Arithmetic/PrimeFieldTheorems.v @@ -0,0 +1,294 @@ +Require Export Crypto.Spec.ModularArithmetic. +Require Export Crypto.Arithmetic.ModularArithmeticTheorems. +Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Field_theory Coq.setoid_ring.Field_tac. + +Require Import Coq.nsatz.Nsatz. +Require Import Crypto.Arithmetic.ModularArithmeticPre. +Require Import Crypto.Util.NumTheoryUtil. 
+Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid. +Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.ZArith.ZArith Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *) +Require Import Coq.Logic.Eqdep_dec. +Require Import Crypto.Util.NumTheoryUtil Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Decidable. +Require Export Crypto.Util.FixCoqMistakes. +Require Crypto.Algebra.Hierarchy Crypto.Algebra.Field. + +Existing Class prime. +Local Open Scope F_scope. + +Module F. + Section Field. + Context (q:positive) {prime_q:prime q}. + Lemma inv_spec : F.inv 0%F = (0%F : F q) + /\ (prime q -> forall x : F q, x <> 0%F -> (F.inv x * x)%F = 1%F). + Proof using Type. change (@F.inv q) with (proj1_sig (@F.inv_with_spec q)); destruct (@F.inv_with_spec q); eauto. Qed. + + Lemma inv_0 : F.inv 0%F = F.of_Z q 0. Proof using Type. destruct inv_spec; auto. Qed. + Lemma inv_nonzero (x:F q) : (x <> 0 -> F.inv x * x%F = 1)%F. Proof using Type*. destruct inv_spec; auto. Qed. + + Global Instance field_modulo : @Algebra.Hierarchy.field (F q) Logic.eq 0%F 1%F F.opp F.add F.sub F.mul F.inv F.div. + Proof using Type*. + repeat match goal with + | _ => solve [ solve_proper + | apply F.commutative_ring_modulo + | apply inv_nonzero + | cbv [not]; pose proof prime_ge_2 q prime_q; + rewrite F.eq_to_Z_iff, !F.to_Z_of_Z, !Zmod_small; omega ] + | _ => split + end. + Qed. + End Field. + + Section NumberThoery. + Context {q:positive} {prime_q:prime q} {two_lt_q: 2 < q}. + + (* TODO: move to PrimeFieldTheorems *) + Lemma to_Z_1 : @F.to_Z q 1 = 1%Z. + Proof using two_lt_q. simpl. rewrite Zmod_small; omega. Qed. + + Lemma Fq_inv_fermat (x:F q) : F.inv x = x ^ Z.to_N (q - 2)%Z. + Proof using Type*. + destruct (dec (x = 0%F)) as [?|Hnz]. + { subst x; rewrite inv_0, F.pow_0_l; trivial. + change (0%N) with (Z.to_N 0%Z); rewrite Z2N.inj_iff; omega. } + erewrite <-Algebra.Field.inv_unique; try reflexivity. 
+ rewrite F.eq_to_Z_iff, F.to_Z_mul, F.to_Z_pow, Z2N.id, to_Z_1 by omega. + apply (fermat_inv q _ (F.to_Z x)); rewrite F.mod_to_Z; eapply F.to_Z_nonzero; trivial. + Qed. + + Lemma euler_criterion (a : F q) (a_nonzero : a <> 0) : + (a ^ (Z.to_N (q / 2)) = 1) <-> (exists b, b*b = a). + Proof using Type*. + pose proof F.to_Z_nonzero_range a; pose proof (odd_as_div q). + specialize_by (destruct (Z.prime_odd_or_2 _ prime_q); try omega; trivial). + rewrite F.eq_to_Z_iff, !F.to_Z_pow, !to_Z_1, !Z2N.id by omega. + rewrite F.square_iff, <-(euler_criterion (q/2)) by (trivial || omega); reflexivity. + Qed. + + Global Instance Decidable_square (x:F q) : Decidable (exists y, y*y = x). + Proof. + destruct (dec (x = 0)). + { left. abstract (exists 0; subst; apply Ring.mul_0_l). } + { eapply Decidable_iff_to_impl; [eapply euler_criterion; assumption | exact _]. } + Defined. + End NumberThoery. + + Section SquareRootsPrime3Mod4. + Context {q:positive} {prime_q: prime q} {q_3mod4 : q mod 4 = 3}. + + Add Field _field2 : (Algebra.Field.field_theory_for_stdlib_tactic(T:=F q)) + (morphism (F.ring_morph q), + constants [F.is_constant], + div (F.morph_div_theory q), + power_tac (F.power_theory q) [F.is_pow_constant]). + + Definition sqrt_3mod4 (a : F q) : F q := a ^ Z.to_N (q / 4 + 1). + + Global Instance Proper_sqrt_3mod4 : Proper (eq ==> eq ) sqrt_3mod4. + Proof using Type. repeat intro; subst; reflexivity. Qed. + + Lemma two_lt_q_3mod4 : 2 < q. + Proof using Type*. + pose proof (prime_ge_2 q _) as two_le_q. + destruct (Zle_lt_or_eq _ _ two_le_q) as [H|H]; [exact H|]. + rewrite <-H in q_3mod4; discriminate. + Qed. + Local Hint Resolve two_lt_q_3mod4. + + Lemma sqrt_3mod4_correct (x:F q) : + ((exists y, y*y = x) <-> (sqrt_3mod4 x)*(sqrt_3mod4 x) = x)%F. + Proof using Type*. + cbv [sqrt_3mod4]; intros. 
+ destruct (F.eq_dec x 0); + repeat match goal with + | |- _ => progress subst + | |- _ => progress rewrite ?F.pow_0_l, <-?F.pow_add_r + | |- _ => progress rewrite <-?Z2N.inj_0, <-?Z2N.inj_add by zero_bounds + | |- _ => rewrite <-@euler_criterion by auto + | |- ?x ^ (?f _) = ?a <-> ?x ^ (?f _) = ?a => do 3 f_equiv; [ ] + | |- _ => rewrite !Zmod_odd in *; repeat (break_match; break_match_hyps); omega + | |- _ => rewrite Z.rem_mul_r in * by omega + | |- (exists x, _) <-> ?B => assert B by field; solve [intuition eauto] + | |- (?x ^ Z.to_N ?a = 1) <-> _ => + transitivity (x ^ Z.to_N a * x ^ Z.to_N 1 = x); + [ rewrite F.pow_1_r, Algebra.Field.mul_cancel_l_iff by auto; reflexivity | ] + | |- (_ <> _)%N => rewrite Z2N.inj_iff by zero_bounds + | |- (?a <> 0)%Z => assert (0 < a) by zero_bounds; omega + | |- (_ = _)%Z => replace 4 with (2 * 2)%Z in * by ring; + rewrite <-Z.div_div by zero_bounds; + rewrite Z.add_diag, Z.mul_add_distr_l, Z.mul_div_eq by omega + end. + Qed. + End SquareRootsPrime3Mod4. + + Section SquareRootsPrime5Mod8. + Context {q:positive} {prime_q: prime q} {q_5mod8 : q mod 8 = 5}. + Local Open Scope F_scope. + Add Field _field3 : (Algebra.Field.field_theory_for_stdlib_tactic(T:=F q)) + (morphism (F.ring_morph q), + constants [F.is_constant], + div (F.morph_div_theory q), + power_tac (F.power_theory q) [F.is_pow_constant]). + + (* Any nonsquare element raised to (q-1)/4 (real implementations use 2 ^ ((q-1)/4) ) + would work for sqrt_minus1 *) + Context (sqrt_minus1 : F q) (sqrt_minus1_valid : sqrt_minus1 * sqrt_minus1 = F.opp 1). + + Lemma two_lt_q_5mod8 : 2 < q. + Proof using prime_q q_5mod8. + pose proof (prime_ge_2 q _) as two_le_q. + destruct (Zle_lt_or_eq _ _ two_le_q) as [H|H]; [exact H|]. + rewrite <-H in *. discriminate. + Qed. + Local Hint Resolve two_lt_q_5mod8. + + Definition sqrt_5mod8 (a : F q) : F q := + let b := a ^ Z.to_N (q / 8 + 1) in + if dec (b ^ 2 = a) + then b + else sqrt_minus1 * b. 
+ + Global Instance Proper_sqrt_5mod8 : Proper (eq ==> eq ) sqrt_5mod8. + Proof using Type. repeat intro; subst; reflexivity. Qed. + + (** If [x] is a square, the fourth power of [x ^ (q / 8 + 1)] equals [x ^ 2]. *) Lemma eq_b4_a2 (x : F q) (Hex:exists y, y*y = x) : + ((x ^ Z.to_N (q / 8 + 1)) ^ 2) ^ 2 = x ^ 2. + Proof using prime_q q_5mod8. + pose proof two_lt_q_5mod8. + assert (0 <= q/8)%Z by (apply Z.div_le_lower_bound; rewrite ?Z.mul_0_r; omega). + assert (Z.to_N (q / 8 + 1) <> 0%N) by + (intro Hbad; change (0%N) with (Z.to_N 0%Z) in Hbad; rewrite Z2N.inj_iff in Hbad; omega). + destruct (dec (x = 0)); [subst; rewrite !F.pow_0_l by (trivial || lazy_decide); reflexivity|]. + rewrite !F.pow_pow_l. + + replace (Z.to_N (q / 8 + 1) * (2*2))%N with (Z.to_N (q / 2 + 2))%N. + Focus 2. { (* this is a boring but gnarly proof :/ *) + change (2*2)%N with (Z.to_N 4). + rewrite <- Z2N.inj_mul by zero_bounds. + apply Z2N.inj_iff; try zero_bounds. + rewrite <- Z.mul_cancel_l with (p := 2) by omega. + ring_simplify. + rewrite Z.mul_div_eq by omega. + rewrite Z.mul_div_eq by omega. + rewrite (Zmod_div_mod 2 8 q) by + (try omega; apply Zmod_divide; omega || auto). + rewrite q_5mod8. + replace (5 mod 2)%Z with 1%Z by auto. + ring. + } Unfocus. + + rewrite Z2N.inj_add, F.pow_add_r by zero_bounds. + replace (x ^ Z.to_N (q / 2)) with (F.of_Z q 1) by + (symmetry; apply @euler_criterion; eauto). + change (Z.to_N 2) with 2%N; ring. + Qed. + + (** Squaring [sqrt_minus1 * x] yields [F.opp (x * x)]. *) Lemma mul_square_sqrt_minus1 : forall x, sqrt_minus1 * x * (sqrt_minus1 * x) = F.opp (x * x). + Proof using prime_q sqrt_minus1_valid. + intros. + transitivity (F.opp 1 * (x * x)); [ | field]. + rewrite <-sqrt_minus1_valid. + field. + Qed. + + (** For nonzero [x], the equation proved in [eq_b4_a2] holds exactly when [x] is a square. *) Lemma eq_b4_a2_iff (x : F q) : x <> 0 -> + ((exists y, y*y = x) <-> ((x ^ Z.to_N (q / 8 + 1)) ^ 2) ^ 2 = x ^ 2). + Proof using Type*. + split; try apply eq_b4_a2. + intro Hyy. + rewrite !@F.pow_2_r in *. + destruct (Field.only_two_square_roots_choice _ x (x * x) Hyy eq_refl); clear Hyy; + [ eexists; eassumption | ]. 
+ match goal with H : ?a * ?a = F.opp _ |- _ => exists (sqrt_minus1 * a); + rewrite mul_square_sqrt_minus1; rewrite H end. + field. + Qed. + + (** [x] has a square root exactly when [sqrt_5mod8 x] is one. *) Lemma sqrt_5mod8_correct : forall x, + ((exists y, y*y = x) <-> (sqrt_5mod8 x)*(sqrt_5mod8 x) = x). + Proof using Type*. + cbv [sqrt_5mod8]; intros. + destruct (F.eq_dec x 0). + { + repeat match goal with + | |- _ => progress subst + | |- _ => progress rewrite ?F.pow_0_l + | |- (_ <> _)%N => rewrite <-Z2N.inj_0, Z2N.inj_iff by zero_bounds + | |- (?a <> 0)%Z => assert (0 < a) by zero_bounds; omega + | |- _ => congruence + end. + break_match; + match goal with |- _ <-> ?G => assert G by field end; intuition eauto. + } { + rewrite eq_b4_a2_iff by auto. + rewrite !@F.pow_2_r in *. + break_match. + intuition (f_equal; eauto). + split; intro A. { + destruct (Field.only_two_square_roots_choice _ x (x * x) A eq_refl) as [B | B]; + clear A; try congruence. + rewrite mul_square_sqrt_minus1, B; field. + } { + rewrite mul_square_sqrt_minus1 in A. + transitivity (F.opp x * F.opp x); [ | field ]. + f_equal; rewrite <-A at 3; field. + } + } + Qed. + End SquareRootsPrime5Mod8. + + Module Iso. + Section IsomorphicRings. + Context {q:positive} {prime_q:prime q} {two_lt_q:2 < Z.pos q}. 
+ Context + {H : Type} {eq : H -> H -> Prop} {zero one : H} + {opp : H -> H} {add sub mul : H -> H -> H} + {phi : F q -> H} {phi' : H -> F q} + {phi'_phi : forall A : F q, Logic.eq (phi' (phi A)) A} + {phi'_iff : forall a b : H, iff (Logic.eq (phi' a) (phi' b)) (eq a b)} + {phi'_zero : Logic.eq (phi' zero) F.zero} {phi'_one : Logic.eq (phi' one) F.one} + {phi'_opp : forall a : H, Logic.eq (phi' (opp a)) (F.opp (phi' a))} + {phi'_add : forall a b : H, Logic.eq (phi' (add a b)) (F.add (phi' a) (phi' b))} + {phi'_sub : forall a b : H, Logic.eq (phi' (sub a b)) (F.sub (phi' a) (phi' b))} + {phi'_mul : forall a b : H, Logic.eq (phi' (mul a b)) (F.mul (phi' a) (phi' b))} + {P:Type} {pow : H -> P -> H} {NtoP:N->P} + {pow_is_scalarmult:ScalarMult.is_scalarmult(G:=H)(eq:=eq)(add:=mul)(zero:=one)(mul:=fun (n:nat) (k:H) => pow k (NtoP (N.of_nat n)))}. + (** Inversion on [H], defined as exponentiation by [q - 2]. *) Definition inv (x:H) := pow x (NtoP (Z.to_N (q - 2)%Z)). + (** [div x y] multiplies [x] by [inv y]. *) Definition div x y := mul (inv y) x. + + (** [H] is a ring, and [phi] and [phi'] are ring homomorphisms between [F q] and [H]. *) Lemma ring : + @Algebra.Hierarchy.ring H eq zero one opp add sub mul + /\ @Ring.is_homomorphism (F q) Logic.eq F.one F.add F.mul H eq one add mul phi + /\ @Ring.is_homomorphism H eq one add mul (F q) Logic.eq F.one F.add F.mul phi'. + Proof using phi'_add phi'_iff phi'_mul phi'_one phi'_opp phi'_phi phi'_sub phi'_zero. eapply @Ring.ring_by_isomorphism; assumption || exact _. Qed. + Local Instance _iso_ring : Algebra.Hierarchy.ring := proj1 ring. + Local Instance _iso_hom1 : Ring.is_homomorphism := proj1 (proj2 ring). + Local Instance _iso_hom2 : Ring.is_homomorphism := proj2 (proj2 ring). + + (** [phi'] maps [inv] on [H] to [F.inv] on [F q]. *) Let inv_proof : forall a : H, phi' (inv a) = F.inv (phi' a). + Proof. + intros. + cbv [inv]. rewrite (Fq_inv_fermat(q:=q)(two_lt_q:=two_lt_q)). + rewrite <-Z_nat_N at 1 2. + rewrite (ScalarMult.homomorphism_scalarmult(phi:=phi')(MUL_is_scalarmult:=pow_is_scalarmult)(mul_is_scalarmult:=F.pow_is_scalarmult)). + reflexivity. + assumption. + Qed. + + (** [phi'] maps the body of [div] to division in [F q]. *) Let div_proof : forall a b : H, phi' (mul (inv b) a) = phi' a / phi' b. + Proof. + intros. 
+ rewrite phi'_mul, inv_proof, Algebra.Hierarchy.field_div_definition, Algebra.Hierarchy.commutative. + reflexivity. + Qed. + + (** [H] equipped with [inv] and [div] is a field, and [phi] and [phi'] are ring homomorphisms. *) Lemma field_and_iso : + @Algebra.Hierarchy.field H eq zero one opp add sub mul inv div + /\ @Ring.is_homomorphism (F q) Logic.eq F.one F.add F.mul H eq one add mul phi + /\ @Ring.is_homomorphism H eq one add mul (F q) Logic.eq F.one F.add F.mul phi'. + Proof using Type*. eapply @Field.field_and_homomorphism_from_redundant_representation; + assumption || exact _ || exact inv_proof || exact div_proof. Qed. + End IsomorphicRings. + End Iso. +End F. diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v new file mode 100644 index 000000000..cb37fb1f9 --- /dev/null +++ b/src/Arithmetic/Saturated.v @@ -0,0 +1,285 @@ +Require Import Coq.Init.Nat. +Require Import Coq.ZArith.ZArith. +Require Import Coq.Lists.List. +Local Open Scope Z_scope. + +Require Import Crypto.Algebra.Nsatz. +Require Import Crypto.Arithmetic.Core. +Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil. +Require Import Crypto.Util.Tuple Crypto.Util.ListUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Local Notation "A ^ n" := (tuple A n) : type_scope. + +(*** + +Arithmetic on bignums that handles carry bits; this is useful for +saturated limbs. Compatible with mixed-radix bases. + + ***) + +Module Columns. + Section Columns. + Context {weight : nat->Z} + {weight_0 : weight 0%nat = 1} + {weight_nonzero : forall i, weight i <> 0} + {weight_multiples : forall i, weight (S i) mod weight i = 0} + (* add_get_carry takes in a number at which to split output *) + {add_get_carry: Z ->Z -> Z -> (Z * Z)} + {add_get_carry_correct : forall s x y, + fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)} + . + + (** Value of a tuple of columns: each column is summed, then the sums are evaluated positionally with [weight]. *) Definition eval {n} (x : (list Z)^n) : Z := + B.Positional.eval weight (Tuple.map sum x). + + (** Like [eval], except that position [i] is weighted by [weight (i+offset)]. *) Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z := + B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x). 
+ + (** With offset 0, [eval_from] coincides with [eval]. *) Lemma eval_from_0 {n} x : @eval_from n 0 x = eval x. + Proof using Type. cbv [eval_from eval]. auto using B.Positional.eval_wt_equiv. Qed. + + (** Peels off the first column: it contributes [weight i] times its sum, and the remaining columns are evaluated from offset [S i]. *) Lemma eval_from_S {n}: forall i (inp : (list Z)^(S n)), + eval_from i inp = eval_from (S i) (tl inp) + weight i * sum (hd inp). + Proof using Type. + intros; cbv [eval_from]. + replace inp with (append (hd inp) (tl inp)) + by (simpl in *; destruct n; destruct inp; reflexivity). + rewrite map_append, B.Positional.eval_step, hd_append, tl_append. + autorewrite with natsimplify; ring_simplify; rewrite Group.cancel_left. + apply B.Positional.eval_wt_equiv; intros; f_equal; omega. + Qed. + + (* Sums a list of integers using carry bits. + Output (a pair, passed to the continuation f) : carry, sum + *) + Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) := + match digit with + | nil => f (0, 0) + | x :: nil => f (0, x) + | x :: tl => + compact_digit_cps n tl (fun rec => + dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in + dlet carry' := (fst rec + snd sum_carry)%RT in + f (carry', fst sum_carry)) + end. + + (** Non-CPS form of [compact_digit_cps], obtained with the identity continuation. *) Definition compact_digit n digit := compact_digit_cps n digit id. + Lemma compact_digit_id n digit: forall {T} f, + @compact_digit_cps n digit T f = f (compact_digit n digit). + Proof using Type. + induction digit; intros; cbv [compact_digit]; [reflexivity|]; + simpl compact_digit_cps; break_match; [reflexivity|]. + rewrite !IHdigit; reflexivity. + Qed. + Hint Opaque compact_digit : uncps. + Hint Rewrite compact_digit_id : uncps. + + (** One compaction step: the incoming [carry] is prepended to [digit] and the result is compacted at [index]. *) Definition compact_step_cps (index:nat) (carry:Z) (digit: list Z) + {T} (f:Z * Z->T) := + compact_digit_cps index (carry::digit) f. + + (** Non-CPS form of [compact_step_cps], obtained with the identity continuation. *) Definition compact_step i c d := compact_step_cps i c d id. + Lemma compact_step_id i c d T f : + @compact_step_cps i c d T f = f (compact_step i c d). + Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed. + Hint Opaque compact_step : uncps. + Hint Rewrite compact_step_id : uncps. 
+ + (** Compacts every column by threading [compact_step_cps] through [mapi_with_cps]; the literal [0] is the starting value handed to the first step (presumably the initial carry, compare [compact_invariant_end]). *) Definition compact_cps {n} (xs : (list Z)^n) {T} (f:Z * Z^n->T) := + mapi_with_cps compact_step_cps 0 xs f. + + (** Non-CPS form of [compact_cps], obtained with the identity continuation. *) Definition compact {n} xs := @compact_cps n xs _ id. + Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs). + Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed. + + (** The second output of [compact_digit] equals the column sum minus [weight (S i) / weight i] times the first output. *) Lemma compact_digit_correct i (xs : list Z) : + snd (compact_digit i xs) = sum xs - (weight (S i) / weight i) * (fst (compact_digit i xs)). + Proof using add_get_carry_correct weight_0. + induction xs; cbv [compact_digit]; simpl compact_digit_cps; + cbv [Let_In]; + repeat match goal with + | _ => rewrite add_get_carry_correct + | _ => progress (rewrite ?sum_cons, ?sum_nil in * ) + | _ => progress (autorewrite with uncps push_id in * ) + | _ => progress (autorewrite with cancel_pair in * ) + | _ => progress break_match; try discriminate + | _ => progress ring_simplify + | _ => reflexivity + | _ => nsatz + end. + Qed. + + (** Invariant relating [n] input columns [inp] read from index [i] with incoming value [starter] to the compacted output [out] and the leftover [rem]. *) Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) := + B.Positional.eval_from weight i out + weight (i + n) * (rem) + = eval_from i inp + weight i*starter. + + (** Inductive step: the invariant for the tail columns extends to the invariant for all columns after one [compact_step_cps]. *) Lemma compact_invariant_holds n i starter rem inp out : + compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out -> + compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out). + Proof using Type*. + cbv [compact_invariant B.Positional.eval_from]; intros. 
+ repeat match goal with + | _ => rewrite B.Positional.eval_step + | _ => rewrite eval_from_S + | _ => rewrite sum_cons + | _ => rewrite weight_multiples + | _ => rewrite Nat.add_succ_l in * + | _ => rewrite Nat.add_succ_r in * + | _ => (rewrite fst_fst_compact_step in * ) + | _ => progress ring_simplify + | _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero + | _ => cbv [compact_step_cps] in *; + autorewrite with uncps push_id; + rewrite compact_digit_correct + | _ => progress (autorewrite with natsimplify in * ) + end. + rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega). + nsatz. + Qed. + + (** Base case: with no columns, the leftover equals the incoming value. *) Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt. + Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed. + + (** The invariant holds, from index 0, for the complete run of [mapi_with_cps compact_step_cps]. *) Lemma compact_invariant_end {n} start (input : (list Z)^n): + compact_invariant n 0%nat start (fst (mapi_with_cps compact_step_cps start input id)) input (snd (mapi_with_cps compact_step_cps start input id)). + Proof using Type*. + autorewrite with uncps push_id. + apply (mapi_with_invariant _ compact_invariant + compact_invariant_holds compact_invariant_base). + Qed. + + (** Compaction preserves the value: the positional result plus [weight n] times the first output of [compact] equals [eval xs]. *) Lemma eval_compact {n} (xs : tuple (list Z) n) : + B.Positional.eval weight (snd (compact xs)) + (weight n * fst (compact xs)) = eval xs. + Proof using Type*. + pose proof (compact_invariant_end 0 xs) as Hinv. + cbv [compact_invariant] in Hinv. + simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv. + rewrite eval_from_0, B.Positional.eval_from_0 in Hinv; apply Hinv. + Qed. + + (** Conses [x] onto the column at position [i] of [t], in continuation-passing style. *) Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n) + {T} (f:(list Z)^n->T) := + @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f. + + (** Non-CPS form of [cons_to_nth_cps], obtained with the identity continuation. *) Definition cons_to_nth {n} i x t := @cons_to_nth_cps n i x t _ id. + Lemma cons_to_nth_id {n} i x t T f : + @cons_to_nth_cps n i x t T f = f (cons_to_nth i x t). + Proof using Type. + cbv [cons_to_nth_cps cons_to_nth]. 
+ assert (forall xs : list (list Z), length xs = n -> + length (update_nth_cps i (cons x) xs id) = n) as Hlen. + { intros. autorewrite with uncps push_id distr_length. assumption. } + rewrite !on_tuple_cps_correct with (H:=Hlen) + by (intros; autorewrite with uncps push_id; reflexivity). reflexivity. + Qed. + Hint Opaque cons_to_nth : uncps. + Hint Rewrite @cons_to_nth_id : uncps. + + (** Summing the columns after consing [x] onto column [i] is the same as adding [x] to entry [i] of the column sums. *) Lemma map_sum_update_nth l : forall i x, + List.map sum (update_nth i (cons x) l) = + update_nth i (Z.add x) (List.map sum l). + Proof using Type. + induction l; intros; destruct i; simpl; rewrite ?IHl; reflexivity. + Qed. + + (** Tuple version of [map_sum_update_nth], stated with [B.Positional.add_to_nth]. *) Lemma cons_to_nth_add_to_nth n i x t : + map sum (@cons_to_nth n i x t) = B.Positional.add_to_nth i x (map sum t). + Proof using weight. + cbv [B.Positional.add_to_nth B.Positional.add_to_nth_cps cons_to_nth cons_to_nth_cps on_tuple_cps]. + induction n; [simpl; rewrite !update_nth_cps_correct; reflexivity|]. + specialize (IHn (tl t)). autorewrite with uncps push_id in *. + apply to_list_ext. rewrite <-!map_to_list. + erewrite !from_list_default_eq, !to_list_from_list. + rewrite map_sum_update_nth. reflexivity. + Unshelve. + distr_length. + distr_length. + Qed. + + (** For [i < n], consing [x] onto column [i] adds [weight i * x] to the value. *) Lemma eval_cons_to_nth n i x t : (i < n)%nat -> + eval (@cons_to_nth n i x t) = weight i * x + eval t. + Proof using Type. + cbv [eval]; intros. rewrite cons_to_nth_add_to_nth. + auto using B.Positional.eval_add_to_nth. + Qed. + Hint Rewrite eval_cons_to_nth using omega : push_basesystem_eval. + + (** A tuple of [n] empty columns. *) Definition nils n : (list Z)^n := Tuple.repeat nil n. + + Lemma map_sum_nils n : map sum (nils n) = B.Positional.zeros n. + Proof using Type. + cbv [nils B.Positional.zeros]; induction n; [reflexivity|]. + change (repeat nil (S n)) with (@nil Z :: repeat nil n). + rewrite map_repeat, sum_nil. reflexivity. + Qed. + + (** Empty columns evaluate to 0. *) Lemma eval_nils n : eval (nils n) = 0. + Proof using Type. cbv [eval]. rewrite map_sum_nils, B.Positional.eval_zeros. reflexivity. Qed. Hint Rewrite eval_nils : push_basesystem_eval. 
+ + (** Builds [n] columns from an associational list [p]: starting from [nils n], each limb is passed to [B.Positional.place_cps] with bound [pred n], and the resulting value is consed onto the resulting column. *) Definition from_associational_cps n (p:list B.limb) + {T} (f:(list Z)^n -> T) := + fold_right_cps + (fun t st => + B.Positional.place_cps weight t (pred n) + (fun p=> cons_to_nth_cps (fst p) (snd p) st id)) + (nils n) p f. + + (** Non-CPS form of [from_associational_cps], obtained with the identity continuation. *) Definition from_associational n p := from_associational_cps n p id. + Lemma from_associational_id n p T f : + @from_associational_cps n p T f = f (from_associational n p). + Proof using Type. + cbv [from_associational_cps from_associational]. + autorewrite with uncps push_id; reflexivity. + Qed. + Hint Opaque from_associational : uncps. + Hint Rewrite from_associational_id : uncps. + + (** For [n <> 0], [from_associational] preserves the associational value. *) Lemma eval_from_associational n p (n_nonzero:n<>0%nat): + eval (from_associational n p) = B.Associational.eval p. + Proof using weight_0 weight_nonzero. + cbv [from_associational_cps from_associational]; induction p; + autorewrite with uncps push_id push_basesystem_eval; [reflexivity|]. + pose proof (B.Positional.weight_place_cps weight weight_0 weight_nonzero a (pred n)). + pose proof (B.Positional.place_cps_in_range weight a (pred n)). + rewrite Nat.succ_pred in * by assumption. simpl. + autorewrite with uncps push_id push_basesystem_eval in *. + rewrite eval_cons_to_nth by omega. nsatz. + Qed. + + (** Converts [p] and [q] to associational form, multiplies them with [B.Associational.mul_cps], and collects the product into [m] columns. *) Definition mul_cps {n m} (p q : Z^n) {T} (f : (list Z)^m->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => B.Associational.mul_cps P Q + (fun PQ => from_associational_cps m PQ f))). + + (** Converts [p] and [q] to associational form, concatenates the two lists, and collects the result into [n] columns. *) Definition add_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) := + B.Positional.to_associational_cps weight p + (fun P => B.Positional.to_associational_cps weight q + (fun Q => from_associational_cps n (P++Q) f)). + + End Columns. +End Columns. + +(* +(* Just some pretty-printing *) +Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing). +Local Notation "snd~ a" := (let (_,y) := a in y) (at level 40, only printing). 
+ +(* Simple example : base 10, multiply two bignums and compact them *) +Definition base10 i := Eval compute in 10^(Z.of_nat i). +Eval cbv -[runtime_add runtime_mul Let_In] in + (fun adc a0 a1 a2 b0 b1 b2 => + Columns.mul_cps (weight := base10) (n:=3) (a2,a1,a0) (b2,b1,b0) (fun ab => Columns.compact (n:=5) (add_get_carry:=adc) (weight:=base10) ab)). + +(* More complex example : base 2^56, 8 limbs *) +Definition base2pow56 i := Eval compute in 2^(56*Z.of_nat i). +Time Eval cbv -[runtime_add runtime_mul Let_In] in + (fun adc a0 a1 a2 a3 a4 a5 a6 a7 b0 b1 b2 b3 b4 b5 b6 b7 => + Columns.mul_cps (weight := base2pow56) (n:=8) (a7,a6,a5,a4,a3,a2,a1,a0) (b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=15) (add_get_carry:=adc) (weight:=base2pow56) ab)). (* Finished transaction in 151.392 secs *) + +(* Mixed-radix example : base 2^25.5, 10 limbs *) +Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1) / 2)). +Time Eval cbv -[runtime_add runtime_mul Let_In] in + (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 => + Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *) +*) |