aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@mit.edu>2017-04-06 22:53:07 -0400
committerGravatar Andres Erbsen <andreser@mit.edu>2017-04-06 22:53:07 -0400
commitc9fc5a3cdf1f5ea2d104c150c30d1b1a6ac64239 (patch)
treedb7187f6984acff324ca468e7b33d9285806a1eb /src/Arithmetic
parent21198245dab432d3c0ba2bb8a02254e7d0594382 (diff)
rename-everything
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/BarrettReduction/Generalized.v140
-rw-r--r--src/Arithmetic/BarrettReduction/HAC.v158
-rw-r--r--src/Arithmetic/BarrettReduction/Wikipedia.v122
-rw-r--r--src/Arithmetic/Core.v980
-rw-r--r--src/Arithmetic/Karatsuba.v49
-rw-r--r--src/Arithmetic/ModularArithmeticPre.v139
-rw-r--r--src/Arithmetic/ModularArithmeticTheorems.v347
-rw-r--r--src/Arithmetic/MontgomeryReduction/Definition.v179
-rw-r--r--src/Arithmetic/MontgomeryReduction/Proofs.v296
-rw-r--r--src/Arithmetic/PrimeFieldTheorems.v294
-rw-r--r--src/Arithmetic/Saturated.v285
11 files changed, 2989 insertions, 0 deletions
diff --git a/src/Arithmetic/BarrettReduction/Generalized.v b/src/Arithmetic/BarrettReduction/Generalized.v
new file mode 100644
index 000000000..76058463c
--- /dev/null
+++ b/src/Arithmetic/BarrettReduction/Generalized.v
@@ -0,0 +1,140 @@
+(*** Barrett Reduction *)
+(** This file implements a slightly-generalized version of Barrett
+ Reduction on [Z]. This version follows a middle path between the
+ Handbook of Applied Cryptography (Algorithm 14.42) and Wikipedia.
+ We split up the shifting and the multiplication so that we don't
+ need to store numbers that are quite so large, but we don't do
+ early reduction modulo [b^(k+offset)] (we generalize from HAC's [k
+ ± 1] to [k ± offset]). This leads to weaker conditions on the
+ base ([b]), exponent ([k]), and the [offset] than those given in
+ the HAC. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
+Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.BreakMatch.
+
+Local Open Scope Z_scope.
+
+Section barrett.
+ Context (n a : Z)
+ (n_reasonable : n <> 0).
+ (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *)
+ (** In modular arithmetic, Barrett reduction is a reduction
+ algorithm introduced in 1986 by P.D. Barrett. A naive way of
+ computing *)
+ (** [c = a mod n] *)
+ (** would be to use a fast division algorithm. Barrett reduction is
+ an algorithm designed to optimize this operation assuming [n] is
+ constant, and [a < n²], replacing divisions by
+ multiplications. *)
+
+ (** * General idea *)
+ Section general_idea.
+ (** Let [m = 1 / n] be the inverse of [n] as a floating point
+ number. Then *)
+ (** [a mod n = a - ⌊a m⌋ n] *)
+ (** where [⌊ x ⌋] denotes the floor function. The result is exact,
+ as long as [m] is computed with sufficient accuracy. *)
+
+ (* [/] is [Z.div], which is floor division (it rounds toward negative infinity; [Z.quot] is the truncating one) *)
+ Local Notation "⌊am⌋" := (a / n) (only parsing).
+
+ Theorem naive_barrett_reduction_correct
+ : a mod n = a - ⌊am⌋ * n.
+ Proof using n_reasonable.
+ apply Zmod_eq_full; assumption.
+ Qed.
+ End general_idea.
+
+ (** * Barrett algorithm *)
+ Section barrett_algorithm.
+ (** Barrett algorithm is a fixed-point analog which expresses
+ everything in terms of integers. Let [k] be the smallest
+ integer such that [2ᵏ > n]. Think of [n] as representing the
+ fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m =
+ ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number
+ [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *)
+ (** N.B. We don't need [k] to be the smallest such integer. *)
+ (** N.B. We generalize to an arbitrary base. *)
+ (** N.B. We generalize from [k ± 1] to [k ± offset]. *)
+ Context (b : Z)
+ (base_good : 0 < b)
+ (k : Z)
+ (k_good : n < b ^ k)
+ (m : Z)
+ (m_good : m = b^(2*k) / n) (* [/] is [Z.div], which is floor division; it agrees with truncation here because both operands are positive *)
+ (offset : Z)
+ (offset_nonneg : 0 <= offset).
+ (** Wikipedia neglects to mention non-negativity, but we need it.
+ It might be possible to do with a relaxed assumption, such as
+ the sign of [a] and the sign of [n] being the same; but I
+ figured it wasn't worth it. *)
+ Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *)
+ (a_nonneg : 0 <= a).
+
+ Context (k_big_enough : offset <= k)
+ (a_small : a < b^(2*k))
+ (** We also need that [n] is large enough; [n] larger than
+ [bᵏ⁻¹] works, but we ask for something more precise. *)
+ (n_large : a mod b^(k-offset) <= n).
+
+ (** Now *)
+
+ Let q := (m * (a / b^(k-offset))) / b^(k+offset). (* approximate quotient: [a] is first divided by [b^(k-offset)], then multiplied by [m], then divided by [b^(k+offset)] *)
+ Let r := a - q * n. (* candidate remainder obtained from the approximate quotient [q] *)
+ (** Because of the floor function (in Coq, [/] is [Z.div], which
+ is floor division), [q] is an integer and [r ≡ a mod n]. *)
+ Theorem barrett_reduction_equivalent
+ : r mod n = a mod n.
+ Proof using m_good offset.
+ subst r q m.
+ rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption.
+ reflexivity.
+ Qed.
+
+ Lemma qn_small
+ : q * n <= a.
+ Proof using a_nonneg a_small base_good k_big_enough m_good n_pos n_reasonable offset_nonneg.
+ subst q r m.
+ assert (0 < b^(k-offset)). zero_bounds.
+ assert (0 < b^(k+offset)) by zero_bounds.
+ assert (0 < b^(2 * k)) by zero_bounds.
+ Z.simplify_fractions_le.
+ autorewrite with pull_Zpow pull_Zdiv zsimplify; reflexivity.
+ Qed.
+
+ Lemma q_nice : { b : bool * bool | q = a / n + (if fst b then -1 else 0) + (if snd b then -1 else 0) }.
+ Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg.
+ assert (0 < b^(k+offset)) by zero_bounds.
+ assert (0 < b^(k-offset)) by zero_bounds.
+ assert (a / b^(k-offset) <= b^(2*k) / b^(k-offset)) by auto with zarith lia.
+ assert (a / b^(k-offset) <= b^(k+offset)) by (autorewrite with pull_Zpow zsimplify in *; assumption).
+ subst q r m.
+ rewrite (Z.div_mul_diff_exact''' (b^(2*k)) n (a/b^(k-offset))) by auto with lia zero_bounds.
+ rewrite (Z_div_mod_eq (b^(2*k) * _ / n) (b^(k+offset))) by lia.
+ autorewrite with push_Zmul push_Zopp zsimplify zstrip_div zdiv_to_mod.
+ rewrite Z.div_sub_mod_cond, !Z.div_sub_small by auto with zero_bounds zarith.
+ eexists (_, _); reflexivity.
+ Qed.
+
+ Lemma r_small : r < 3 * n.
+ Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg q.
+ Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div.
+ assert (a mod n < n) by auto with zarith lia.
+ unfold r; rewrite (proj2_sig q_nice); generalize (proj1_sig q_nice); intro; subst q m.
+ autorewrite with push_Zmul zsimplify zstrip_div.
+ break_match; auto with lia.
+ Qed.
+
+ (** In that case, we have *)
+ Theorem barrett_reduction_small
+ : a mod n = let r := if r <? n then r else r-n in
+ let r := if r <? n then r else r-n in
+ r.
+ Proof using a_nonneg a_small base_good k_big_enough m_good n_large n_pos n_reasonable offset_nonneg q.
+ pose proof r_small. pose proof qn_small. cbv zeta.
+ destruct (r <? n) eqn:Hr, (r-n <? n) eqn:?; try rewrite Hr; Z.ltb_to_lt; try lia.
+ { symmetry; apply (Zmod_unique a n q); subst r; lia. }
+ { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. }
+ { symmetry; apply (Zmod_unique a n (q + 2)); subst r; lia. }
+ Qed.
+ End barrett_algorithm.
+End barrett.
diff --git a/src/Arithmetic/BarrettReduction/HAC.v b/src/Arithmetic/BarrettReduction/HAC.v
new file mode 100644
index 000000000..70661ee96
--- /dev/null
+++ b/src/Arithmetic/BarrettReduction/HAC.v
@@ -0,0 +1,158 @@
+(*** Barrett Reduction *)
+(** This file implements a slightly-generalized version of Barrett
+ Reduction on [Z]. This version follows the Handbook of Applied
+ Cryptography (Algorithm 14.42) rather closely; the only deviations
+ are that we generalize from [k ± 1] to [k ± offset] for an
+ arbitrary offset, and we weaken the conditions on the base [b] in
+ [bᵏ] slightly. Contrasted with some other versions, this version
+ does reduction modulo [b^(k+offset)] early (ensuring that we don't
+ have to carry around extra precision), but requires more stringent
+ conditions on the base ([b]), exponent ([k]), and the [offset]. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
+Require Import Crypto.Util.ZUtil Crypto.Util.Tactics.BreakMatch.
+
+Local Open Scope Z_scope.
+
+Section barrett.
+ (** Quoting the Handbook of Applied Cryptography <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>: *)
+ (** Barrett reduction (Algorithm 14.42) computes [r = x mod m] given
+ [x] and [m]. The algorithm requires the precomputation of the
+ quantity [µ = ⌊b²ᵏ/m⌋]; it is advantageous if many reductions
+ are performed with a single modulus. For example, each RSA
+ encryption for one entity requires reduction modulo that
+ entity’s public key modulus. The precomputation takes a fixed
+ amount of work, which is negligible in comparison to modular
+ exponentiation cost. Typically, the radix [b] is chosen to be
+ close to the word-size of the processor. Hence, assume [b > 3] in
+ Algorithm 14.42 (see Note 14.44 (ii)). *)
+
+ (** * Barrett modular reduction *)
+ Section barrett_modular_reduction.
+ Context (m b x k μ offset : Z)
+ (m_pos : 0 < m)
+ (base_pos : 0 < b)
+ (k_good : m < b^k)
+ (μ_good : μ = b^(2*k) / m) (* [/] is [Z.div], which is floor division; it agrees with truncation here because both operands are positive *)
+ (x_nonneg : 0 <= x)
+ (offset_nonneg : 0 <= offset)
+ (k_big_enough : offset <= k)
+ (x_small : x < b^(2*k))
+ (m_small : 3 * m <= b^(k+offset))
+ (** We also need that [m] is large enough; [m] larger than
+ [bᵏ⁻¹] works, but we ask for something more precise. *)
+ (m_large : x mod b^(k-offset) <= m).
+
+ Let q1 := x / b^(k-offset). Let q2 := q1 * μ. Let q3 := q2 / b^(k+offset).
+ Let r1 := x mod b^(k+offset). Let r2 := (q3 * m) mod b^(k+offset).
+ (** At this point, the HAC says "If [r < 0] then [r ← r + bᵏ⁺¹]".
+ This is equivalent to reduction modulo [b^(k+offset)], as we
+ prove below. The version involving modular reduction has the
+ benefit of being cheaper to implement, and making the proofs
+ simpler, so we primarily use that version. *)
+ Let r_mod_3m := (r1 - r2) mod b^(k+offset).
+ Let r_mod_3m_orig := let r := r1 - r2 in
+ if r <? 0 then r + b^(k+offset) else r.
+
+ Lemma r_mod_3m_eq_orig : r_mod_3m = r_mod_3m_orig.
+ Proof using base_pos k_big_enough m_pos m_small offset_nonneg r1 r2.
+ assert (0 <= r1 < b^(k+offset)) by (subst r1; auto with zarith).
+ assert (0 <= r2 < b^(k+offset)) by (subst r2; auto with zarith).
+ subst r_mod_3m r_mod_3m_orig; cbv zeta.
+ break_match; Z.ltb_to_lt.
+ { symmetry; apply (Zmod_unique (r1 - r2) _ (-1)); lia. }
+ { symmetry; apply (Zmod_unique (r1 - r2) _ 0); lia. }
+ Qed.
+
+ (** 14.43 Fact By the division algorithm (Definition 2.82), there
+ exist integers [Q] and [R] such that [x = Qm + R] and [0 ≤ R <
+ m]. In step 1 of Algorithm 14.42 (Barrett modular reduction),
+ the following inequality is satisfied: [Q - 2 ≤ q₃ ≤ Q]. *)
+ (** We prove this by providing a more useful form for [q₃]. *)
+ Let Q := x / m.
+ Let R := x mod m.
+ Lemma q3_nice : { b : bool * bool | q3 = Q + (if fst b then -1 else 0) + (if snd b then -1 else 0) }.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg x_nonneg x_small μ_good.
+ assert (0 < b^(k+offset)) by zero_bounds.
+ assert (0 < b^(k-offset)) by zero_bounds.
+ assert (x / b^(k-offset) <= b^(2*k) / b^(k-offset)) by auto with zarith lia.
+ assert (x / b^(k-offset) <= b^(k+offset)) by (autorewrite with pull_Zpow zsimplify in *; assumption).
+ subst q1 q2 q3 Q r_mod_3m r_mod_3m_orig r1 r2 R μ.
+ rewrite (Z.div_mul_diff_exact' (b^(2*k)) m (x/b^(k-offset))) by auto with lia zero_bounds.
+ rewrite (Z_div_mod_eq (_ * b^(2*k) / m) (b^(k+offset))) by lia.
+ autorewrite with push_Zmul push_Zopp zsimplify zstrip_div zdiv_to_mod.
+ rewrite Z.div_sub_mod_cond, !Z.div_sub_small; auto with zero_bounds zarith.
+ eexists (_, _); reflexivity.
+ Qed.
+
+ Fact q3_in_range : Q - 2 <= q3 <= Q.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good.
+ rewrite (proj2_sig q3_nice).
+ break_match; lia.
+ Qed.
+
+ (** 14.44 Note (partial justification of correctness of Barrett reduction) *)
+ (** (i) Algorithm 14.42 is based on the observation that [⌊x/m⌋]
+ can be written as [Q =
+ ⌊(x/bᵏ⁻¹)(b²ᵏ/m)(1/bᵏ⁺¹)⌋]. Moreover, [Q] can be
+ approximated by the quantity [q₃ = ⌊⌊x/bᵏ⁻¹⌋µ/bᵏ⁺¹⌋].
+ Fact 14.43 guarantees that [q₃] is never larger than the
+ true quotient [Q], and is at most 2 smaller. *)
+ Lemma x_minus_q3_m_in_range : 0 <= x - q3 * m < 3 * m.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good.
+ pose proof q3_in_range.
+ assert (0 <= R < m) by (subst R; auto with zarith).
+ assert (0 <= (Q - q3) * m + R < 3 * m) by nia.
+ subst Q R; autorewrite with push_Zmul zdiv_to_mod in *; lia.
+ Qed.
+
+ Lemma r_mod_3m_eq_alt : r_mod_3m = x - q3 * m.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg q2 x_nonneg x_small μ_good.
+ pose proof x_minus_q3_m_in_range.
+ subst r_mod_3m r_mod_3m_orig r1 r2.
+ autorewrite with pull_Zmod zsimplify; reflexivity.
+ Qed.
+
+ (** This version uses reduction modulo [b^(k+offset)]. *)
+ Theorem barrett_reduction_equivalent
+ : r_mod_3m mod m = x mod m.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r1 r2 x_nonneg x_small μ_good.
+ rewrite r_mod_3m_eq_alt.
+ autorewrite with zsimplify push_Zmod; reflexivity.
+ Qed.
+
+ (** This version, which matches the original in the HAC, uses
+ conditional addition of [b^(k+offset)]. *)
+ Theorem barrett_reduction_orig_equivalent
+ : r_mod_3m_orig mod m = x mod m.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r_mod_3m x_nonneg x_small μ_good. rewrite <- r_mod_3m_eq_orig; apply barrett_reduction_equivalent. Qed.
+
+ Lemma r_small : 0 <= r_mod_3m < 3 * m.
+ Proof using Q R base_pos k_big_enough m_large m_pos m_small offset_nonneg q3 x_nonneg x_small μ_good.
+ pose proof x_minus_q3_m_in_range.
+ subst Q R r_mod_3m r_mod_3m_orig r1 r2.
+ autorewrite with pull_Zmod zsimplify; lia.
+ Qed.
+
+
+ (** This version uses reduction modulo [b^(k+offset)]. *)
+ Theorem barrett_reduction_small (r := r_mod_3m)
+ : x mod m = let r := if r <? m then r else r-m in
+ let r := if r <? m then r else r-m in
+ r.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r1 r2 x_nonneg x_small μ_good.
+ pose proof r_small. cbv zeta.
+ destruct (r <? m) eqn:Hr, (r-m <? m) eqn:?; subst r; rewrite !r_mod_3m_eq_alt, ?Hr in *; Z.ltb_to_lt; try lia.
+ { symmetry; eapply (Zmod_unique x m q3); lia. }
+ { symmetry; eapply (Zmod_unique x m (q3 + 1)); lia. }
+ { symmetry; eapply (Zmod_unique x m (q3 + 2)); lia. }
+ Qed.
+
+ (** This version, which matches the original in the HAC, uses
+ conditional addition of [b^(k+offset)]. *)
+ Theorem barrett_reduction_small_orig (r := r_mod_3m_orig)
+ : x mod m = let r := if r <? m then r else r-m in
+ let r := if r <? m then r else r-m in
+ r.
+ Proof using base_pos k_big_enough m_large m_pos m_small offset_nonneg r_mod_3m x_nonneg x_small μ_good. subst r; rewrite <- r_mod_3m_eq_orig; apply barrett_reduction_small. Qed.
+ End barrett_modular_reduction.
+End barrett.
diff --git a/src/Arithmetic/BarrettReduction/Wikipedia.v b/src/Arithmetic/BarrettReduction/Wikipedia.v
new file mode 100644
index 000000000..69ce10c4b
--- /dev/null
+++ b/src/Arithmetic/BarrettReduction/Wikipedia.v
@@ -0,0 +1,122 @@
+(*** Barrett Reduction *)
+(** This file implements Barrett Reduction on [Z]. We follow Wikipedia. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+
+Local Open Scope Z_scope.
+
+Section barrett.
+ Context (n a : Z)
+ (n_reasonable : n <> 0).
+ (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *)
+ (** In modular arithmetic, Barrett reduction is a reduction
+ algorithm introduced in 1986 by P.D. Barrett. A naive way of
+ computing *)
+ (** [c = a mod n] *)
+ (** would be to use a fast division algorithm. Barrett reduction is
+ an algorithm designed to optimize this operation assuming [n] is
+ constant, and [a < n²], replacing divisions by
+ multiplications. *)
+
+ (** * General idea *)
+ Section general_idea.
+ (** Let [m = 1 / n] be the inverse of [n] as a floating point
+ number. Then *)
+ (** [a mod n = a - ⌊a m⌋ n] *)
+ (** where [⌊ x ⌋] denotes the floor function. The result is exact,
+ as long as [m] is computed with sufficient accuracy. *)
+
+ (* [/] is [Z.div], which is floor division (it rounds toward negative infinity; [Z.quot] is the truncating one) *)
+ Local Notation "⌊am⌋" := (a / n) (only parsing).
+
+ Theorem naive_barrett_reduction_correct
+ : a mod n = a - ⌊am⌋ * n.
+ Proof using n_reasonable.
+ apply Zmod_eq_full; assumption.
+ Qed.
+ End general_idea.
+
+ (** * Barrett algorithm *)
+ Section barrett_algorithm.
+ (** Barrett algorithm is a fixed-point analog which expresses
+ everything in terms of integers. Let [k] be the smallest
+ integer such that [2ᵏ > n]. Think of [n] as representing the
+ fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m =
+ ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number
+ [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *)
+ (** N.B. We don't need [k] to be the smallest such integer. *)
+ Context (k : Z)
+ (k_good : n < 2 ^ k)
+ (m : Z)
+ (m_good : m = 4^k / n). (* [/] is [Z.div], which is floor division *)
+ (** Wikipedia neglects to mention non-negativity, but we need it.
+ It might be possible to do with a relaxed assumption, such as
+ the sign of [a] and the sign of [n] being the same; but I
+ figured it wasn't worth it. *)
+ Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *)
+ (a_nonneg : 0 <= a).
+
+ Lemma k_nonnegative : 0 <= k.
+ Proof using Type*.
+ destruct (Z_lt_le_dec k 0); try assumption.
+ rewrite !Z.pow_neg_r in * by lia; lia.
+ Qed.
+
+ (** Now *)
+ Let q := (m * a) / 4^k. (* approximate quotient: multiply by the precomputed [m], then divide by [4^k] *)
+ Let r := a - q * n. (* candidate remainder obtained from the approximate quotient [q] *)
+ (** Because of the floor function (in Coq, [/] is [Z.div], which
+ is floor division), [q] is an integer and [r ≡ a mod n]. *)
+ Theorem barrett_reduction_equivalent
+ : r mod n = a mod n.
+ Proof using m_good.
+ subst r q m.
+ rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption.
+ reflexivity.
+ Qed.
+
+ Lemma qn_small
+ : q * n <= a.
+ Proof using a_nonneg k_good m_good n_pos n_reasonable.
+ pose proof k_nonnegative; subst q r m.
+ assert (0 <= 2^(k-1)) by zero_bounds.
+ Z.simplify_fractions_le.
+ Qed.
+
+ (** Also, if [a < n²] then [r < 2n]. *)
+ (** N.B. It turns out that it is sufficient to assume [a < 4ᵏ]. *)
+ Context (a_small : a < 4^k).
+ Lemma q_nice : { b : bool | q = a / n + if b then -1 else 0 }.
+ Proof using a_nonneg a_small k_good m_good n_pos n_reasonable.
+ assert (0 <= (4 ^ k * a / n) mod 4 ^ k < 4 ^ k) by auto with zarith lia.
+ assert (0 <= a * (4 ^ k mod n) / n < 4 ^ k) by (auto with zero_bounds zarith lia).
+ subst q r m.
+ rewrite (Z.div_mul_diff_exact''' (4^k) n a) by lia.
+ rewrite (Z_div_mod_eq (4^k * _ / n) (4^k)) by lia.
+ autorewrite with push_Zmul push_Zopp zsimplify zstrip_div.
+ eexists; reflexivity.
+ Qed.
+
+ Lemma r_small : r < 2 * n.
+ Proof using a_nonneg a_small k_good m_good n_pos n_reasonable q.
+ Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div.
+ assert (a mod n < n) by auto with zarith lia.
+ unfold r; rewrite (proj2_sig q_nice); generalize (proj1_sig q_nice); intro; subst q m.
+ autorewrite with push_Zmul zsimplify zstrip_div.
+ break_match; auto with lia.
+ Qed.
+
+ (** In that case, we have *)
+ Theorem barrett_reduction_small
+ : a mod n = if r <? n
+ then r
+ else r - n.
+ Proof using a_nonneg a_small k_good m_good n_pos n_reasonable q.
+ pose proof r_small. pose proof qn_small.
+ destruct (r <? n) eqn:rlt; Z.ltb_to_lt.
+ { symmetry; apply (Zmod_unique a n q); subst r; lia. }
+ { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. }
+ Qed.
+ End barrett_algorithm.
+End barrett.
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
new file mode 100644
index 000000000..2613765d0
--- /dev/null
+++ b/src/Arithmetic/Core.v
@@ -0,0 +1,980 @@
+(*****
+
+This file provides a generalized version of arithmetic with "mixed
+radix" numerical systems. Later, parameters are entered into the
+general functions, and they are partially evaluated until only runtime
+basic arithmetic operations remain.
+
+CPS
+---
+
+Functions are written in continuation passing style (CPS). This means
+that each operation is passed a "continuation" function, which it is
+expected to call on its own output (like a callback). See the end of
+this comment for a motivating example explaining why we do CPS,
+despite a fair amount of resulting boilerplate code for each
+operation. The code block for an operation called A would look like
+this:
+
+```
+Definition A_cps x y {T} f : T := ...
+
+Definition A x y := A_cps x y id.
+Lemma A_cps_id x y : forall {T} f, @A_cps x y T f = f (A x y).
+Hint Opaque A : uncps.
+Hint Rewrite A_cps_id : uncps.
+
+Lemma eval_A x y : eval (A x y) = ...
+Hint Rewrite eval_A : push_basesystem_eval.
+```
+
+`A_cps` is the main, CPS-style definition of the operation (`f` is the
+continuation function). `A` is the non-CPS version of `A_cps`, simply
+defined by passing an identity function to `A_cps`. `A_cps_id` states
+that we can replace the CPS version with the non-cps version. `eval_A`
+is the actual correctness lemma for the operation, stating that it has
+the correct arithmetic properties. In general, the middle block
+containing `A` and `A_cps_id` is boring boilerplate and can be safely
+ignored.
+
+HintDbs
+-------
+
++ `uncps` : Converts CPS operations to their non-CPS versions.
++ `push_basesystem_eval` : Contains all the correctness lemmas for
+ operations in this file, which are in terms of the `eval` function.
+
+Positional/Associational
+------------------------
+
+We represent mixed-radix numbers in a few different ways:
+
++ "Positional" : a tuple of numbers and a weight function (nat->Z),
+which is evaluated by multiplying the `i`th element of the tuple by
+`weight i`, and then summing the products.
++ "Associational" : a list of pairs of numbers--the first is the
+weight, the second is the runtime value. Evaluated by multiplying each
+pair and summing the products.
+
+The associational representation is good for basic operations like
+addition and multiplication; for addition, one can simply just append
+two associational lists. But the end-result code should use the
+positional representation (with each digit representing a machine
+word). Since converting to and fro can be easily compiled away once
+the weight function is known, we use associational to write most of
+the operations and liberally convert back and forth to ensure correct
+output. In particular, it is important to convert before carrying.
+
+Runtime Operations
+------------------
+
+Since some instances of e.g. Z.add or Z.mul operate on (compile-time)
+weights, and some operate on runtime values, we need a way to
+differentiate these cases before partial evaluation. We define a
+runtime_scope to mark certain additions/multiplications as runtime
+values, so they will not be unfolded during partial evaluation. For
+instance, if we have:
+
+```
+Definition f (x y : Z * Z) := (fst x + fst y, (snd x + snd y)%RT).
+```
+
+then when we are partially evaluating `f`, we can easily exclude the
+runtime operations (`cbv - [runtime_add]`) and prevent Coq from trying
+to simplify the second addition.
+
+
+Why CPS?
+--------
+
+Let's suppose we want to add corresponding elements of two `list Z`s
+(so on inputs `[1,2,3]` and `[2,3,1]`, we get `[3,5,4]`). We might
+write our function like this :
+
+```
+Fixpoint add_lists (p q : list Z) :=
+ match p, q with
+ | p0 :: p', q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ | _, _ => nil
+ end.
+```
+
+(Note : `dlet` is a notation for `Let_In`, which is just a dumb
+wrapper for `let`. This allows us to `cbv - [Let_In]` if we want to
+not simplify certain `let`s.)
+
+A CPS equivalent of `add_lists` would look like this:
+
+```
+Fixpoint add_lists_cps (p q : list Z) {T} (f:list Z->T) :=
+ match p, q with
+ | p0 :: p', q0 :: q' =>
+ dlet sum := p0 + q0 in
+ add_lists_cps p' q' (fun r => f (sum :: r))
+ | _, _ => f nil
+ end.
+```
+
+Now let's try some partial evaluation. The expression we'll evaluate is:
+
+```
+Definition x :=
+ (fun a0 a1 a2 b0 b1 b2 =>
+ let r := add_lists [a0;a1;a2] [b0;b1;b2] in
+ let rr := add_lists r r in
+ add_lists rr rr).
+```
+
+Or, using `add_lists_cps`:
+
+```
+Definition y :=
+ (fun a0 a1 a2 b0 b1 b2 =>
+ add_lists_cps [a0;a1;a2] [b0;b1;b2]
+ (fun r => add_lists_cps r r
+ (fun rr => add_lists_cps rr rr id))).
+```
+
+If we run `Eval cbv -[Z.add] in x` and `Eval cbv -[Z.add] in y`, we get
+identical output:
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ [a0 + b0 + (a0 + b0) + (a0 + b0 + (a0 + b0));
+ a1 + b1 + (a1 + b1) + (a1 + b1 + (a1 + b1));
+ a2 + b2 + (a2 + b2) + (a2 + b2 + (a2 + b2))]
+```
+
+However, there are a lot of common subexpressions here--this is what
+the `dlet` we put into the functions should help us avoid. Let's try
+`Eval cbv -[Let_In Z.add] in x`:
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ (fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ ((fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1])))
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1]))))
+ ((fix add_lists (p q : list Z) {struct p} :
+ list Z :=
+ match p with
+ | [] => []
+ | p0 :: p' =>
+ match q with
+ | [] => []
+ | q0 :: q' =>
+ dlet sum := p0 + q0 in
+ sum :: add_lists p' q'
+ end
+ end)
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1])))
+ (dlet sum := a0 + b0 in
+ sum
+ :: (dlet sum0 := a1 + b1 in
+ sum0 :: (dlet sum1 := a2 + b2 in
+ [sum1]))))
+```
+
+Not so great. Because the `dlet`s are stuck in the inner terms, we
+can't simplify the expression very nicely. Let's try that on the CPS
+version (`Eval cbv -[Let_In Z.add] in y`):
+
+```
+fun a0 a1 a2 b0 b1 b2 : Z =>
+ dlet sum := a0 + b0 in
+ dlet sum0 := a1 + b1 in
+ dlet sum1 := a2 + b2 in
+ dlet sum2 := sum + sum in
+ dlet sum3 := sum0 + sum0 in
+ dlet sum4 := sum1 + sum1 in
+ dlet sum5 := sum2 + sum2 in
+ dlet sum6 := sum3 + sum3 in
+ dlet sum7 := sum4 + sum4 in
+ [sum5; sum6; sum7]
+```
+
+Isn't that lovely? Since we can push continuation functions "under"
+the `dlet`s, we can end up with a nice, concise, simplified
+expression.
+
+One might suggest that we could just inline the `dlet`s and do common
+subexpression elimination. But some of our terms have so many `dlet`s
+that inlining them all would make a term too huge to process in
+reasonable time, so this is not really an option.
+
+*****)
+
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
+Require Import Coq.ZArith.BinIntDef.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Algebra.Nsatz.
+Require Import Crypto.Util.Decidable Crypto.Util.LetIn.
+Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
+Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.
+Require Import Crypto.Arithmetic.PrimeFieldTheorems.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.UniquePose.
+Require Import Crypto.Util.Tactics.VM.
+
+Require Import Coq.Lists.List. Import ListNotations.
+Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
+
+Local Ltac prove_id := (* helper tactic used below for the [*_cps_id] lemmas: repeatedly runs the first branch that succeeds *)
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress simpl
+ | _ => progress cbv [Let_In]
+ | _ => progress (autorewrite with uncps push_id in * ) (* rewrite with the [uncps] and [push_id] hint databases, in the goal and all hypotheses *)
+ | _ => break_innermost_match_step
+ | _ => contradiction
+ | _ => reflexivity
+ | _ => nsatz
+ | _ => solve [auto]
+ end.
+
+Create HintDb push_basesystem_eval discriminated.
+Local Ltac prove_eval := (* helper tactic used below for the [eval_*] lemmas: repeatedly runs the first branch that succeeds *)
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress simpl
+ | _ => progress cbv [Let_In]
+ | _ => progress (autorewrite with push_basesystem_eval uncps push_id cancel_pair in * ) (* push [eval] through operations and drop CPS wrappers *)
+ | _ => break_innermost_match_step
+ | _ => split
+ | H : _ /\ _ |- _ => destruct H
+ | H : Some _ = Some _ |- _ => progress (inversion H; subst)
+ | _ => discriminate
+ | _ => reflexivity
+ | _ => nsatz
+ end.
+
+Definition mod_eq (m:positive) a b := a mod m = b mod m. (* [a] and [b] leave the same remainder modulo [m] *)
+Global Instance mod_eq_equiv m : RelationClasses.Equivalence (mod_eq m). (* [mod_eq m] is an equivalence relation *)
+Proof. constructor; congruence. Qed.
+Definition mod_eq_dec m a b : {mod_eq m a b} + {~ mod_eq m a b} (* [mod_eq] is decidable: it is an equality, decided by [Z.eq_dec] *)
+ := Z.eq_dec _ _.
+Lemma mod_eq_Z2F_iff m a b : (* [mod_eq m] coincides with equality of the [F.of_Z m] images *)
+ mod_eq m a b <-> Logic.eq (F.of_Z m a) (F.of_Z m b).
+Proof. rewrite <-F.eq_of_Z_iff; reflexivity. Qed.
+
+Delimit Scope runtime_scope with RT.
+
+Definition runtime_mul := Z.mul. (* [Z.mul] under a separate name, so it can be excluded from unfolding, e.g. [cbv - [runtime_mul]] *)
+Global Notation "a * b" := (runtime_mul a%RT b%RT) : runtime_scope.
+Definition runtime_add := Z.add. (* [Z.add] under a separate name, for the same reason *)
+Global Notation "a + b" := (runtime_add a%RT b%RT) : runtime_scope.
+Definition runtime_opp := Z.opp. (* [Z.opp] under a separate name, for the same reason *)
+Global Notation "- a" := (runtime_opp a%RT) : runtime_scope.
+Definition runtime_and := Z.land. (* [Z.land] under a separate name, for the same reason *)
+Global Notation "a &' b" := (runtime_and a%RT b%RT) : runtime_scope.
+Definition runtime_shr := Z.shiftr. (* [Z.shiftr] under a separate name, for the same reason *)
+Global Notation "a >> b" := (runtime_shr a%RT b%RT) : runtime_scope.
+
+Module B.
+ Definition limb := (Z*Z)%type. (* position coefficient and run-time value *)
+ Module Associational.
+ Definition eval (p:list limb) : Z := (* value of an associational list: the sum over all limbs of [fst t * snd t]; the empty list evaluates to [0] *)
+ List.fold_right Z.add 0%Z (List.map (fun t => fst t * snd t) p).
+
+ Lemma eval_nil : eval nil = 0. Proof. reflexivity. Qed.
+ Lemma eval_cons p q : eval (p::q) = (fst p) * (snd p) + eval q. Proof. reflexivity. Qed.
+ Lemma eval_app p q: eval (p++q) = eval p + eval q.
+ Proof. induction p; simpl eval; rewrite ?eval_nil, ?eval_cons; nsatz. Qed.
+ Hint Rewrite eval_nil eval_cons eval_app : push_basesystem_eval.
+
+ Definition multerm (t t' : limb) : limb := (* product of two limbs: first components are multiplied with [Z.mul], second components with [runtime_mul] *)
+ (fst t * fst t', (snd t * snd t')%RT).
+ Lemma eval_map_multerm (a:limb) (q:list limb)
+ : eval (List.map (multerm a) q) = fst a * snd a * eval q.
+ Proof.
+ induction q; cbv [multerm]; simpl List.map;
+ autorewrite with push_basesystem_eval cancel_pair; nsatz.
+ Qed. Hint Rewrite eval_map_multerm : push_basesystem_eval.
+
+ Definition mul_cps (p q:list limb) {T} (f : list limb->T) := (* CPS multiplication: for each term [t] of [p], maps [multerm t] over [q]; the continuation [f] receives the result *)
+ flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f. (* [flat_map_cps] and [map_cps] are defined outside this file; presumably CPS versions of [flat_map] and [map] -- confirm in CPSUtil *)
+
+ Definition mul (p q:list limb) := mul_cps p q id.
+ Lemma mul_cps_id p q: forall {T} f, @mul_cps p q T f = f (mul p q).
+ Proof. cbv [mul_cps mul]; prove_id. Qed.
+ Hint Opaque mul : uncps.
+ Hint Rewrite mul_cps_id : uncps.
+
+ Lemma eval_mul p q: eval (mul p q) = eval p * eval q.
+ Proof. cbv [mul mul_cps]; induction p; prove_eval. Qed.
+ Hint Rewrite eval_mul : push_basesystem_eval.
+
+ Fixpoint split_cps (s:Z) (xs:list limb) (* partitions [xs] by whether each limb's first component is a multiple of [s] *)
+ {T} (f :list limb*list limb->T) := (* the continuation [f] receives the resulting pair of lists *)
+ match xs with
+ | nil => f (nil, nil)
+ | cons x xs' =>
+ split_cps s xs'
+ (fun sxs' =>
+ if dec (fst x mod s = 0)
+ then f (fst sxs', cons (fst x / s, snd x) (snd sxs')) (* multiple of [s]: goes to the second list with its first component divided by [s] *)
+ else f (cons x (fst sxs'), snd sxs')) (* otherwise: kept unchanged in the first list *)
+ end.
+
+ Definition split s xs := split_cps s xs id.
+ Lemma split_cps_id s p: forall {T} f,
+ @split_cps s p T f = f (split s p).
+ Proof.
+ induction p;
+ repeat match goal with
+ | _ => rewrite IHp
+ | _ => progress (cbv [split]; prove_id)
+ end.
+ Qed.
+ Hint Opaque split : uncps.
+ Hint Rewrite split_cps_id : uncps.
+
+ Lemma eval_split s p (s_nonzero:s<>0):
+ eval (fst (split s p)) + s*eval (snd (split s p)) = eval p.
+ Proof.
+ cbv [split]; induction p; prove_eval.
+ match goal with
+ H:_ |- _ =>
+ unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H)
+ end; nsatz.
+ Qed. Hint Rewrite @eval_split using auto : push_basesystem_eval.
+
+ Definition reduce_cps (s:Z) (c:list limb) (p:list limb) (* splits [p] at [s], multiplies the second (divided-by-[s]) part by [c], and appends that product to the first part *)
+ {T} (f : list limb->T) := (* the continuation [f] receives the combined list *)
+ split_cps s p
+ (fun ab => mul_cps c (snd ab)
+ (fun rr =>f (fst ab ++ rr))).
+
+ Definition reduce s c p := reduce_cps s c p id.
+ Lemma reduce_cps_id s c p {T} f:
+ @reduce_cps s c p T f = f (reduce s c p).
+ Proof. cbv [reduce_cps reduce]; prove_id. Qed.
+ Hint Opaque reduce : uncps.
+ Hint Rewrite reduce_cps_id : uncps.
+
+ Lemma reduction_rule a b s c m (m_eq:Z.pos m = s - c):
+ (a + s * b) mod m = (a + c * b) mod m.
+ Proof.
+ rewrite m_eq. pose proof (Pos2Z.is_pos m).
+ replace (a + s * b) with ((a + c*b) + b*(s-c)) by ring.
+ rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod by omega.
+ trivial.
+ Qed.
+ Lemma eval_reduce s c p (s_nonzero:s<>0) m (m_eq : Z.pos m = s - eval c) :
+ mod_eq m (eval (reduce s c p)) (eval p).
+ Proof.
+ cbv [reduce reduce_cps mod_eq]; prove_eval.
+ erewrite <-reduction_rule by eauto; prove_eval.
+ Qed.
+ Hint Rewrite eval_reduce using (omega || assumption) : push_basesystem_eval.
+ (* TODO: find out why this hint gets picked up outside the section, while the other eval_ hints do not. *)
+
+
+ (* CPS negation: keep the first component of every term and negate the
+ second component (in the [RT] notation scope). *)
+ Definition negate_snd_cps (p:list limb) {T} (f:list limb ->T) :=
+ map_cps (fun cx => (fst cx, (-snd cx)%RT)) p f.
+
+ (* Direct-style [negate_snd]: [negate_snd_cps] with the identity
+ continuation. *)
+ Definition negate_snd p := negate_snd_cps p id.
+ (* [negate_snd_cps] with any continuation [f] equals [f (negate_snd p)]. *)
+ Lemma negate_snd_id p {T} f : @negate_snd_cps p T f = f (negate_snd p).
+ Proof. cbv [negate_snd_cps negate_snd]; prove_id. Qed.
+ Hint Opaque negate_snd : uncps.
+ Hint Rewrite negate_snd_id : uncps.
+
+ (* Correctness: negating every second component negates the evaluation. *)
+ Lemma eval_negate_snd p : eval (negate_snd p) = - eval p.
+ Proof.
+ cbv [negate_snd_cps negate_snd]; induction p; prove_eval.
+ Qed. Hint Rewrite eval_negate_snd : push_basesystem_eval.
+
+ (* Carrying on associational lists. The section is parameterized over
+ abstract [modulo] and [div] functions that only need to satisfy the
+ division equation [div_mod]. *)
+ Section Carries.
+ Context {modulo div:Z->Z->Z}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+
+ (* Carry a single term: if the weight of [t] is exactly [w], replace it
+ by the quotient [div (snd t) fw] at weight [w*fw] and the remainder
+ [modulo (snd t) fw] at weight [w]; otherwise pass [t] through as a
+ singleton list. [dlet] binds the value once before it is used twice. *)
+ Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) :=
+ if dec (fst t = w)
+ then dlet t2 := snd t in
+ f ((w*fw, div t2 fw) :: (w, modulo t2 fw) :: @nil limb)
+ else f [t].
+
+ (* Direct-style [carryterm]: [carryterm_cps] with the identity
+ continuation. *)
+ Definition carryterm w fw t := carryterm_cps w fw t id.
+ Lemma carryterm_cps_id w fw t {T} f :
+ @carryterm_cps w fw t T f
+ = f (@carryterm w fw t).
+ Proof using Type. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed.
+ Hint Opaque carryterm : uncps.
+ Hint Rewrite carryterm_cps_id : uncps.
+
+
+ (* Correctness of a single-term carry: for [fw <> 0] the evaluation is
+ unchanged; this is where [div_mod] is used. *)
+ Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0):
+ eval (carryterm w fw t) = eval [t].
+ Proof using Type*.
+ cbv [carryterm_cps carryterm Let_In]; prove_eval.
+ specialize (div_mod (snd t) fw fw_nonzero).
+ nsatz.
+ Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval.
+
+ (* Carry a whole list: apply [carryterm_cps w fw] to every term of [p]
+ and combine the results with [flat_map_cps]. *)
+ Definition carry_cps (w fw:Z) (p:list limb) {T} (f:list limb->T) :=
+ flat_map_cps (carryterm_cps w fw) p f.
+
+ (* Direct-style [carry]: [carry_cps] with the identity continuation. *)
+ Definition carry w fw p := carry_cps w fw p id.
+ Lemma carry_cps_id w fw p {T} f:
+ @carry_cps w fw p T f = f (carry w fw p).
+ Proof using Type. cbv [carry_cps carry]; prove_id. Qed.
+ Hint Opaque carry : uncps.
+ Hint Rewrite carry_cps_id : uncps.
+
+ (* Correctness: carrying preserves the evaluation when [fw <> 0]. *)
+ Lemma eval_carry w fw p (fw_nonzero:fw<>0):
+ eval (carry w fw p) = eval p.
+ Proof using Type*. cbv [carry_cps carry]; induction p; prove_eval. Qed.
+ Hint Rewrite eval_carry using auto : push_basesystem_eval.
+ End Carries.
+
+ End Associational.
+ (* Register the Associational [_cps_id] lemmas in the [uncps] rewrite
+ database under their qualified names. Presumably needed because hints
+ declared inside the section do not survive its end -- TODO confirm. *)
+ Hint Rewrite
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ @Associational.reduce_cps_id
+ @Associational.split_cps_id
+ @Associational.mul_cps_id : uncps.
+
+ Module Positional.
+ Section Positional.
+ Import Associational.
+ Context (weight : nat -> Z) (* [weight i] is the weight of position [i] *)
+ (weight_0 : weight 0%nat = 1%Z)
+ (weight_nonzero : forall i, weight i <> 0).
+
+ (** Converting from positional to associational *)
+ (* Build the associational form of the tuple [xs]: map [weight] over
+ the indices [seq 0 n], turn [xs] into a list, and combine the two
+ lists into (weight, value) pairs that are passed to [f]. *)
+ Definition to_associational_cps {n:nat} (xs:tuple Z n)
+ {T} (f:list limb->T) :=
+ map_cps weight (seq 0 n)
+ (fun r =>
+ to_list_cps n xs (fun rr => combine_cps r rr f)).
+
+ (* Direct-style [to_associational]: the CPS version with the identity
+ continuation. *)
+ Definition to_associational {n} xs :=
+ @to_associational_cps n xs _ id.
+ Lemma to_associational_cps_id {n} x {T} f:
+ @to_associational_cps n x T f = f (to_associational x).
+ Proof using Type. cbv [to_associational_cps to_associational]; prove_id. Qed.
+ Hint Opaque to_associational : uncps.
+ Hint Rewrite @to_associational_cps_id : uncps.
+
+ (* Positional evaluation: convert to associational form and use
+ [Associational.eval] directly as the continuation. *)
+ Definition eval {n} x :=
+ @to_associational_cps n x _ Associational.eval.
+
+ (* Evaluating the associational form of [x] agrees with [eval x]. *)
+ Lemma eval_to_associational {n} x :
+ Associational.eval (@to_associational n x) = eval x.
+ Proof using Type.
+ cbv [to_associational_cps eval to_associational]; prove_eval.
+ Qed. Hint Rewrite @eval_to_associational : push_basesystem_eval.
+
+ (** (modular) equality that tolerates redundancy: two tuples are [eq]
+ modulo [m] when their evaluations are related by [mod_eq m], so
+ different tuples with congruent values count as equal. **)
+ Definition eq {sz} m (a b : tuple Z sz) : Prop :=
+ mod_eq m (eval a) (eval b).
+
+ (** Converting from associational to positional *)
+
+ (* The tuple of length [n] whose entries are all [0]. *)
+ Definition zeros n : tuple Z n := Tuple.repeat 0 n.
+ (* [zeros n] evaluates to [0]. The proof generalizes [seq 0 n] to an
+ arbitrary list of length [n] and inducts on that list. *)
+ Lemma eval_zeros n : eval (zeros n) = 0.
+ Proof using Type.
+ cbv [eval Associational.eval to_associational_cps zeros].
+ pose proof (seq_length n 0). generalize dependent (seq 0 n).
+ intro xs; revert n; induction xs; intros;
+ [autorewrite with uncps; reflexivity|].
+ intros; destruct n; [distr_length|].
+ specialize (IHxs n). autorewrite with uncps in *.
+ rewrite !@Tuple.to_list_repeat in *.
+ simpl List.repeat. rewrite map_cons, combine_cons, map_cons.
+ simpl fold_right. rewrite IHxs by distr_length. ring.
+ Qed. Hint Rewrite eval_zeros : push_basesystem_eval.
+
+ (* CPS update of one tuple position: apply [runtime_add x] to entry [i]
+ of [t] through [update_nth_cps], lifted to tuples by [on_tuple_cps]
+ (with [0] as its default-value argument). *)
+ Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) :=
+ @on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f.
+
+ (* Direct-style [add_to_nth]: the CPS version with the identity
+ continuation. *)
+ Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id.
+ (* The [Unshelve] part discharges the length-preservation side
+ condition left over by [on_tuple_cps_correct]. *)
+ Lemma add_to_nth_cps_id {n} i x xs {T} f:
+ @add_to_nth_cps n i x xs T f = f (add_to_nth i x xs).
+ Proof using weight.
+ cbv [add_to_nth_cps add_to_nth]; erewrite !on_tuple_cps_correct
+ by (intros; autorewrite with uncps; reflexivity); prove_id.
+ Unshelve.
+ intros; subst. autorewrite with uncps push_id. distr_length.
+ Qed.
+ Hint Opaque add_to_nth : uncps.
+ Hint Rewrite @add_to_nth_cps_id : uncps.
+
+ (* Correctness: for an in-range index [i < n], adding [x] at position
+ [i] increases the evaluation by [weight i * x]. *)
+ Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i<n)%nat) (xs:tuple Z n):
+ eval (@add_to_nth n i x xs) = weight i * x + eval xs.
+ Proof using Type.
+ cbv [eval to_associational_cps add_to_nth add_to_nth_cps runtime_add].
+ erewrite on_tuple_cps_correct by (intros; autorewrite with uncps; reflexivity).
+ prove_eval.
+ cbv [Tuple.on_tuple].
+ rewrite !Tuple.to_list_from_list.
+ autorewrite with uncps push_id.
+ rewrite ListUtil.combine_update_nth_r at 1.
+ rewrite <-(update_nth_id i (List.combine _ _)) at 2.
+ rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _ (weight 0, 0)); cbv [ListUtil.splice_nth id];
+ repeat match goal with
+ | _ => progress (apply Zminus_eq; ring_simplify)
+ | _ => progress autorewrite with push_basesystem_eval cancel_pair distr_length
+ | _ => progress rewrite <-?ListUtil.map_nth_default_always, ?map_fst_combine, ?List.firstn_all2, ?ListUtil.map_nth_default_always, ?nth_default_seq_inbouns, ?plus_O_n
+ end; trivial; lia.
+ Unshelve.
+ intros; subst. autorewrite with uncps push_id. distr_length.
+ Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval.
+
+ (* Find a position for the term [t], searching downward from index [i]:
+ the first index whose weight divides [fst t] is returned together
+ with the value [(fst t / weight index) * snd t]. If no index down to
+ [0] qualifies, the fallback is index [O] with value [fst t * snd t]. *)
+ Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) :=
+ if dec (fst t mod weight i = 0)
+ then f (i, let c := fst t / weight i in (c * snd t)%RT)
+ else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+
+ (* Direct-style [place]: [place_cps] with the identity continuation. *)
+ Definition place t i := place_cps t i id.
+ Lemma place_cps_id t i {T} f :
+ @place_cps t i T f = f (place t i).
+ Proof using Type. cbv [place]; induction i; prove_id. Qed.
+ Hint Opaque place : uncps.
+ Hint Rewrite place_cps_id : uncps.
+
+ (* The index chosen by [place_cps t n] never exceeds [n]. *)
+ Lemma place_cps_in_range (t:limb) (n:nat)
+ : (fst (place_cps t n id) < S n)%nat.
+ Proof using Type. induction n; simpl; break_match; simpl; omega. Qed.
+ (* Placing a term preserves its contribution: the weight of the chosen
+ index times the placed value equals [fst t * snd t]. Relies on the
+ section hypotheses about [weight] ([Proof using Type*]). *)
+ Lemma weight_place_cps t i
+ : weight (fst (place_cps t i id)) * snd (place_cps t i id)
+ = fst t * snd t.
+ Proof using Type*.
+ induction i; cbv [id]; simpl place_cps; break_match;
+ autorewrite with cancel_pair;
+ try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end;
+ nsatz || auto.
+ Qed.
+
+ (* Convert an associational list into an [n]-tuple: starting from
+ [zeros n], fold over [p]; each term is placed with [place_cps]
+ (searching from the top index [pred n]) and added into the chosen
+ position with [add_to_nth_cps]. *)
+ Definition from_associational_cps n (p:list limb)
+ {T} (f:tuple Z n->T):=
+ fold_right_cps
+ (fun t st =>
+ place_cps t (pred n)
+ (fun p=> add_to_nth_cps (fst p) (snd p) st id))
+ (zeros n) p f.
+
+ (* Direct-style [from_associational]: the CPS version with the identity
+ continuation. *)
+ Definition from_associational n p := from_associational_cps n p id.
+ Lemma from_associational_cps_id {n} p {T} f:
+ @from_associational_cps n p T f = f (from_associational n p).
+ Proof using Type.
+ cbv [from_associational_cps from_associational]; prove_id.
+ Qed.
+ Hint Opaque from_associational : uncps.
+ Hint Rewrite @from_associational_cps_id : uncps.
+
+ (* Correctness: for [n <> 0] the tuple evaluates to the same value as
+ the associational list it was built from. *)
+ Lemma eval_from_associational {n} p (n_nonzero:n<>O):
+ eval (from_associational n p) = Associational.eval p.
+ Proof using Type*.
+ cbv [from_associational_cps from_associational]; induction p;
+ [|pose proof (place_cps_in_range a (pred n))]; prove_eval.
+ cbv [place]; rewrite weight_place_cps. nsatz.
+ Qed.
+ Hint Rewrite @eval_from_associational using omega
+ : push_basesystem_eval.
+
+ Section Carries.
+ Context {modulo div : Z->Z->Z}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+ (* Positional carry at position [index]: convert [p] to associational
+ form, run [Associational.carry_cps] with weight [weight index] and
+ carry factor [weight (S index) / weight index], then convert the
+ result back into an [m]-tuple. The output length [m] may differ from
+ the input length [n]. *)
+ Definition carry_cps {n m} (index:nat) (p:tuple Z n)
+ {T} (f:tuple Z m->T) :=
+ to_associational_cps p
+ (fun P => @Associational.carry_cps
+ modulo div
+ (weight index)
+ (weight (S index) / weight index)
+ P T
+ (fun R => from_associational_cps m R f)).
+
+ (* Direct-style [carry]: [carry_cps] with the identity continuation. *)
+ Definition carry {n m} i p := @carry_cps n m i p _ id.
+ Lemma carry_cps_id {n m} i p {T} f:
+ @carry_cps n m i p T f = f (carry i p).
+ Proof.
+ cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity.
+ Qed.
+ Hint Opaque carry : uncps. Hint Rewrite @carry_cps_id : uncps.
+
+ (* Correctness: carrying preserves the evaluation when both lengths are
+ nonzero and the carry factor at [i] is nonzero. *)
+ Lemma eval_carry {n m} i p: (n <> 0%nat) -> (m <> 0%nat) ->
+ weight (S i) / weight i <> 0 ->
+ eval (carry (n:=n) (m:=m) i p) = eval p.
+ Proof.
+ cbv [carry_cps carry]; intros. prove_eval.
+ rewrite @eval_carry by eauto.
+ apply eval_to_associational.
+ Qed.
+ Hint Rewrite @eval_carry : push_basesystem_eval.
+
+ (* N.B. It is important to reverse [idxs] here. Like
+ [fold_right], [fold_right_cps2] is written such that the first
+ terms in the list are actually used last in the computation. For
+ example, running:
+
+ `Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).`
+
+ will produce [fun a b c d => (a + (b + (c + d)))].*)
+ (* Apply [carry_cps] once for every index in [idxs], threading the
+ tuple [p] through the fold (over the reversed list, see above). *)
+ Definition chained_carries_cps {n} (p:tuple Z n) (idxs : list nat)
+ {T} (f:tuple Z n->T) :=
+ fold_right_cps2 carry_cps p (rev idxs) f.
+
+ (* Direct-style [chained_carries]: the CPS version with the identity
+ continuation. *)
+ Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id.
+ Lemma chained_carries_id {n} p idxs : forall {T} f,
+ @chained_carries_cps n p idxs T f = f (chained_carries p idxs).
+ Proof using Type. cbv [chained_carries_cps chained_carries]; prove_id. Qed.
+ Hint Opaque chained_carries : uncps.
+ Hint Rewrite @chained_carries_id : uncps.
+
+ (* Correctness: a chain of carries preserves the evaluation, provided
+ the carry factor is nonzero at every index that is carried. *)
+ Lemma eval_chained_carries {n} (p:tuple Z n) idxs :
+ (forall i, In i idxs -> weight (S i) / weight i <> 0) ->
+ eval (chained_carries p idxs) = eval p.
+ Proof using Type*.
+ cbv [chained_carries chained_carries_cps]; intros;
+ autorewrite with uncps push_id.
+ apply fold_right_invariant; [|intro; rewrite <-in_rev];
+ destruct n; prove_eval; auto.
+ Qed. Hint Rewrite @eval_chained_carries : push_basesystem_eval.
+
+ (* Reverse of [eval]; translate from Z to basesystem by putting
+ everything in the first digit and then carrying at every index of
+ [seq 0 n]. This function, like [eval], is not defined using CPS. *)
+ Definition encode {n} (x : Z) : tuple Z n :=
+ chained_carries (from_associational n [(1,x)]) (seq 0 n).
+ (* Round trip: evaluating the encoding of [x] gives back [x], for
+ [n <> 0] and nonzero carry factors at every index of [seq 0 n]. *)
+ Lemma eval_encode {n} x : (n <> 0%nat) ->
+ (forall i, In i (seq 0 n) -> weight (S i) / weight i <> 0) ->
+ eval (@encode n x) = x.
+ Proof using Type*. cbv [encode]; intros; prove_eval; auto. Qed.
+ Hint Rewrite @eval_encode : push_basesystem_eval.
+
+ End Carries.
+
+ Section Wrappers.
+ (* Simple wrappers for Associational definitions; convert to
+ associational, do the operation, convert back. *)
+
+ (* Addition: concatenate the associational forms of [p] and [q] and
+ convert the result back to an [n]-tuple. *)
+ Definition add_cps {n} (p q : tuple Z n) {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => to_associational_cps q
+ (fun Q => from_associational_cps n (P++Q) f)).
+
+ (* Multiplication: [Associational.mul_cps] on the associational forms;
+ the result length [m] is chosen independently of [n]. *)
+ Definition mul_cps {n m} (p q : tuple Z n) {T} (f:tuple Z m->T) :=
+ to_associational_cps p
+ (fun P => to_associational_cps q
+ (fun Q => Associational.mul_cps P Q
+ (fun PQ => from_associational_cps m PQ f))).
+
+ (* Reduction: [Associational.reduce_cps s c] on the associational form
+ of an [m]-tuple, converted back to an [n]-tuple. *)
+ Definition reduce_cps {m n} (s:Z) (c:list B.limb) (p : tuple Z m)
+ {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => Associational.reduce_cps s c P
+ (fun R => from_associational_cps n R f)).
+
+ (* Carry at the top index [pred n] into an [S n]-tuple, then reduce
+ back down to [n] entries with [reduce_cps s c]. *)
+ Definition carry_reduce_cps {n div modulo}
+ (s:Z) (c:list limb) (p : tuple Z n)
+ {T} (f: tuple Z n ->T) :=
+ carry_cps (div:=div) (modulo:=modulo) (n:=n) (m:=S n) (pred n) p
+ (fun r => reduce_cps (m:=S n) (n:=n) s c r f).
+
+ (* Negation: [Associational.negate_snd_cps] on the associational form. *)
+ Definition negate_snd_cps {n} (p : tuple Z n)
+ {T} (f:tuple Z n->T) :=
+ to_associational_cps p
+ (fun P => Associational.negate_snd_cps P
+ (fun R => from_associational_cps n R f)).
+
+ End Wrappers.
+ (* Let [autounfold] expose the wrappers; [sub_id] and [eval_sub] below
+ call [autounfold] in their proofs. *)
+ Hint Unfold
+ Positional.add_cps
+ Positional.mul_cps
+ Positional.reduce_cps
+ Positional.carry_reduce_cps
+ Positional.negate_snd_cps
+ .
+
+ (* Subtraction modulo [m]. [coef] is a tuple whose evaluation is
+ congruent to [0] modulo [m]; it is added in before subtracting.
+ Presumably this keeps the limbs from going negative -- TODO confirm;
+ the code here only relies on [coef_mod]. *)
+ Section Subtraction.
+ Context {m n} {coef : tuple Z n}
+ {coef_mod : mod_eq m (eval coef) 0}.
+
+ (* Compute [(coef + p) + (negation of q)] using the wrappers above. *)
+ Definition sub_cps (p q : tuple Z n) {T} (f:tuple Z n->T):=
+ add_cps coef p
+ (fun cp => negate_snd_cps q
+ (fun _q => add_cps cp _q f)).
+
+ (* Direct-style [sub]: [sub_cps] with the identity continuation. *)
+ Definition sub p q := sub_cps p q id.
+ Lemma sub_id p q {T} f : @sub_cps p q T f = f (sub p q).
+ Proof using Type. cbv [sub_cps sub]; autounfold; prove_id. Qed.
+ Hint Opaque sub : uncps.
+ Hint Rewrite sub_id : uncps.
+
+ (* Correctness: [sub p q] evaluates to [eval p - eval q] modulo [m];
+ the [coef] contribution vanishes by [coef_mod]. *)
+ Lemma eval_sub p q : mod_eq m (eval (sub p q)) (eval p - eval q).
+ Proof using Type*.
+ cbv [sub sub_cps]; autounfold; destruct n; prove_eval.
+ transitivity (eval coef + (eval p - eval q)).
+ { apply f_equal2; ring. }
+ { cbv [mod_eq] in *; rewrite Z.add_mod_full, coef_mod, Z.add_0_l, Zmod_mod. reflexivity. }
+ Qed.
+
+ (* Additive inverse, defined as subtraction from [zeros n]. *)
+ Definition opp_cps (p : tuple Z n) {T} (f:tuple Z n->T):=
+ sub_cps (zeros n) p f.
+ End Subtraction.
+
+ (* Lemmas about converting to/from F. Will be useful in proving
+ that basesystem is isomorphic to F.commutative_ring_modulo.*)
+ Section F.
+ Context {sz:nat} {sz_nonzero : sz<>0%nat} {m :positive}.
+ Context (weight_divides : forall i : nat, weight (S i) / weight i <> 0).
+ Context {modulo div:Z->Z->Z}
+ {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+
+ (* Encode a field element as a tuple: take its integer representative
+ [F.to_Z x] and [encode] it. *)
+ Definition Fencode (x : F m) : tuple Z sz :=
+ encode (div:=div) (modulo:=modulo) (F.to_Z x).
+
+ (* Decode a tuple to a field element: evaluate it and map the integer
+ into [F m] with [F.of_Z]. *)
+ Definition Fdecode (x : tuple Z sz) : F m := F.of_Z m (eval x).
+
+ (* Decoding after encoding is the identity on [F m]. *)
+ Lemma Fdecode_Fencode_id x : Fdecode (Fencode x) = x.
+ Proof using div_mod sz_nonzero weight_0 weight_divides weight_nonzero.
+ cbv [Fdecode Fencode]; rewrite @eval_encode by auto.
+ apply F.of_Z_to_Z.
+ Qed.
+
+ (* Two tuples decode to the same field element exactly when they are
+ [eq] modulo [m]. *)
+ Lemma eq_Feq_iff a b :
+ Logic.eq (Fdecode a) (Fdecode b) <-> eq m a b.
+ Proof using Type. cbv [Fdecode]; rewrite <-F.eq_of_Z_iff; reflexivity. Qed.
+ End F.
+
+
+ End Positional.
+
+ (* Helper lemmas and definitions for [eval]; this needs to be in a
+ separate section so the weight function can change. *)
+ Section EvalHelpers.
+ (* A 1-tuple evaluates to its single entry times the weight of
+ position 0. *)
+ Lemma eval_single wt (x:Z) : eval (n:=1) wt x = wt 0%nat * x.
+ Proof. cbv - [Z.mul Z.add]. ring. Qed.
+
+ (* Peeling one entry [z] off with [Tuple.append]: [z] contributes
+ [wt 0 * z] and the remaining tuple is evaluated with the weight
+ function shifted by one position. *)
+ Lemma eval_step {n} (x:tuple Z n) : forall wt z,
+ eval wt (Tuple.append z x) = wt 0%nat * z + eval (fun i => wt (S i)) x.
+ Proof.
+ destruct n; [reflexivity|].
+ intros; cbv [eval to_associational_cps].
+ autorewrite with uncps. rewrite map_S_seq. reflexivity.
+ Qed.
+
+ (* [eval] only depends on the weight function pointwise. *)
+ Lemma eval_wt_equiv {n} :forall wta wtb (x:tuple Z n),
+ (forall i, wta i = wtb i) -> eval wta x = eval wtb x.
+ Proof.
+ destruct n; [reflexivity|].
+ induction n; intros; [rewrite !eval_single, H; reflexivity|].
+ simpl tuple in *; destruct x.
+ change (t, z) with (Tuple.append (n:=S n) z t).
+ rewrite !eval_step. rewrite (H 0%nat). apply Group.cancel_left.
+ apply IHn; auto.
+ Qed.
+
+ (* Evaluate [x] with the weights shifted by [offset] positions, i.e.
+ position [i] of [x] gets weight [weight (i+offset)]. *)
+ Definition eval_from {n} weight (offset:nat) (x : tuple Z n) : Z :=
+ eval (fun i => weight (i+offset)%nat) x.
+
+ (* With offset 0, [eval_from] coincides with [eval]. *)
+ Lemma eval_from_0 {n} wt x : @eval_from n wt 0 x = eval wt x.
+ Proof. cbv [eval_from]. auto using eval_wt_equiv. Qed.
+ End EvalHelpers.
+
+ End Positional.
+ (* Module-level hint registrations under qualified names. Presumably
+ these repeat the hints declared inside the sections above because
+ section-local hints are not visible here -- TODO confirm. *)
+ (* Wrappers that [autounfold] may expose. *)
+ Hint Unfold
+ Positional.add_cps
+ Positional.mul_cps
+ Positional.reduce_cps
+ Positional.carry_reduce_cps
+ Positional.negate_snd_cps
+ Positional.opp_cps
+ .
+ (* [uncps]: lemmas rewriting a CPS call with continuation [f] into [f]
+ applied to the direct-style function. *)
+ Hint Rewrite
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ @Associational.reduce_cps_id
+ @Associational.split_cps_id
+ @Associational.mul_cps_id
+ @Positional.carry_cps_id
+ @Positional.from_associational_cps_id
+ @Positional.place_cps_id
+ @Positional.add_to_nth_cps_id
+ @Positional.to_associational_cps_id
+ @Positional.chained_carries_id
+ @Positional.sub_id
+ : uncps.
+ (* [push_basesystem_eval]: evaluation lemmas; side conditions are
+ discharged by [assumption] or [vm_decide]. *)
+ Hint Rewrite
+ @Associational.eval_mul
+ @Positional.eval_to_associational
+ @Associational.eval_carry
+ @Associational.eval_carryterm
+ @Associational.eval_reduce
+ @Associational.eval_split
+ @Positional.eval_zeros
+ @Positional.eval_carry
+ @Positional.eval_from_associational
+ @Positional.eval_add_to_nth
+ @Positional.eval_chained_carries
+ @Positional.eval_sub
+ using (assumption || vm_decide) : push_basesystem_eval.
+End B.
+
+(* Modulo and div that do shifts if possible, otherwise normal mod/div *)
+Section DivMod.
+ (* [modulo a b]: when [b] equals [2 ^ Z.log2 b], mask [a] with
+ [Z.ones (Z.log2 b)] using the [&'] notation of scope [RT]; otherwise
+ fall back to [Z.modulo]. *)
+ Definition modulo (a b : Z) : Z :=
+ if dec (2 ^ (Z.log2 b) = b)
+ then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT
+ else Z.modulo a b.
+
+ (* [div a b]: when [b] equals [2 ^ Z.log2 b], shift [a] right by
+ [Z.log2 b] using the [>>] notation of scope [RT]; otherwise fall back
+ to [Z.div]. *)
+ Definition div (a b : Z) : Z :=
+ if dec (2 ^ (Z.log2 b) = b)
+ then let x := Z.log2 b in (a >> x)%RT
+ else Z.div a b.
+
+ (* The pair ([div], [modulo]) satisfies the division equation for
+ nonzero [b], which is the property the [Carries] sections require. *)
+ Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b.
+ Proof.
+ cbv [div modulo]; intros. break_match; auto using Z.div_mod.
+ rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg.
+ pose proof (Z.div_mod a b H). congruence.
+ Qed.
+End DivMod.
+
+Import B.
+
+(* Partially evaluate the right-hand side of a goal of the form [_ _ ?t]:
+ 1. delta-unfold the listed basesystem definitions in [t];
+ 2. abstract the result, one constant at a time, over [runtime_mul],
+ [runtime_add], [runtime_opp], [runtime_shr], [runtime_and] and
+ [Let_In] (each [pattern] step is followed by dropping the argument);
+ 3. pose the abstracted function as [t1] and split the goal with
+ [transitivity] through [t1] re-applied to those six constants, in
+ the reverse order of abstraction;
+ 4. run [replace_with_vm_compute] on [t1] in the first subgoal and
+ close the second one by [reflexivity]. *)
+Ltac basesystem_partial_evaluation_RHS :=
+ let t0 := match goal with |- _ _ ?t => t end in
+ let t := (eval cbv delta [
+ (* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], [runtime_opp], [runtime_mul], [runtime_shr], or [runtime_and] *)
+Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries_cps Positional.chained_carries Positional.sub_cps Positional.sub Positional.negate_snd_cps Positional.add_cps Positional.opp_cps Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry Associational.negate_snd_cps Associational.negate_snd div modulo
+ ] in t0) in
+ let t := (eval pattern @runtime_mul in t) in
+ let t := match t with ?t _ => t end in
+ let t := (eval pattern @runtime_add in t) in
+ let t := match t with ?t _ => t end in
+ let t := (eval pattern @runtime_opp in t) in
+ let t := match t with ?t _ => t end in
+ let t := (eval pattern @runtime_shr in t) in
+ let t := match t with ?t _ => t end in
+ let t := (eval pattern @runtime_and in t) in
+ let t := match t with ?t _ => t end in
+ let t := (eval pattern @Let_In in t) in
+ let t := match t with ?t _ => t end in
+ let t1 := fresh "t1" in
+ pose t as t1;
+ transitivity (t1
+ (@Let_In)
+ (@runtime_and)
+ (@runtime_shr)
+ (@runtime_opp)
+ (@runtime_add)
+ (@runtime_mul));
+ [replace_with_vm_compute t1; clear t1|reflexivity].
+
+(** This block of tactic code works around bug #5434
+ (https://coq.inria.fr/bugs/show_bug.cgi?id=5434), that
+ [vm_compute] breaks an invariant in pretyping/constr_matching.ml.
+ So we refresh all of the names in match statements in the goal by
+ crawling it.
+
+ In particular, [replace_with_vm_compute] creates a [vm_compute]d
+ term which has anonymous binders where pretyping expects there to
+ be named binders. This shows up when you try to match on the
+ function (the branch statement of the match) with an Ltac pattern
+ like [(fun x : ?T => ?C)] rather than [(fun x : ?T => @?C x)]; we
+ use the former in reification to save the cost of many extra
+ invocations of [cbv beta]. Luckily, patterns like [(fun x : ?T =>
+ @?C x)] don't trigger this anomaly, so we can walk the term,
+ fixing all match statements whose branches are functions whose
+ binder names were eaten by [vm_compute] (note that in a match,
+ every branch where the corresponding constructor takes arguments
+ is represented internally as a function (lambda term)). We fix
+ the match statements by pulling out the branch with the [@?]
+ pattern that doesn't trigger the anomaly, and then recreating the
+ match with a destructuring [let] that hasn't been through
+ [vm_compute], and therefore has name information that
+ constr_matching is happy with. *)
+(* Rebuild the term [T], replacing every [match] on a pair by a
+ destructuring [let '(a, b) := ...]:
+ - applications [?F ?X] are rebuilt from their recursively processed
+ parts;
+ - a pair match has its scrutinee processed, and its branch body is
+ beta/delta-reduced and processed recursively inside the new [let];
+ - any other term is returned unchanged. *)
+Ltac replace_match_with_destructuring_match T :=
+ match T with
+ | ?F ?X
+ => let F' := replace_match_with_destructuring_match F in
+ let X' := replace_match_with_destructuring_match X in
+ constr:(F' X')
+ (* we must use [@?f a b] here and not [?f], or else we get an anomaly *)
+ | match ?d with pair a b => @?f a b end
+ => let d' := replace_match_with_destructuring_match d in
+ let T' := fresh in
+ constr:(let '(a, b) := d' in
+ match f a b with
+ | T' => ltac:(let v := (eval cbv beta delta [T'] in T') in
+ let v := replace_match_with_destructuring_match v in
+ exact v)
+ end)
+ | ?x => x
+ end.
+(* Apply [replace_match_with_destructuring_match] to the current goal and
+ install the rewritten statement with [change], so the new goal must be
+ convertible to the old one. *)
+Ltac do_replace_match_with_destructuring_match_in_goal :=
+ let G := get_goal in
+ let G' := replace_match_with_destructuring_match G in
+ change G'.
+
+(* TODO : move *)
+(* [F.of_Z] commutes with negation: injecting [- x] into [F m] gives the
+ field opposite of injecting [x]. *)
+Lemma F_of_Z_opp {m} x : F.of_Z m (- x) = F.opp (F.of_Z m x).
+Proof.
+ cbv [F.opp]; intros. rewrite F.to_Z_of_Z, <-Z.sub_0_l.
+ etransitivity; rewrite F.of_Z_mod;
+ [rewrite Z.opp_mod_mod|]; reflexivity.
+Qed.
+
+(* [pull_FofZ]: these lemmas are registered right-to-left ([<-]), so
+ rewriting with the database replaces field operations on [F.of_Z]
+ images by a single [F.of_Z] around the integer operation. *)
+Hint Rewrite <-@F.of_Z_add : pull_FofZ.
+Hint Rewrite <-@F.of_Z_mul : pull_FofZ.
+Hint Rewrite <-@F.of_Z_sub : pull_FofZ.
+Hint Rewrite <-@F_of_Z_opp : pull_FofZ.
+
+(* Turn a goal about [Positional.Fdecode] into a [mod_eq] goal on
+ integers: unfold [Fdecode], pull [F.of_Z] outward with [pull_FofZ],
+ then apply [mod_eq_Z2F_iff]. *)
+Ltac F_mod_eq :=
+ cbv [Positional.Fdecode]; autorewrite with pull_FofZ;
+ apply mod_eq_Z2F_iff.
+
+(* Work on a [mod_eq] goal for an operation given as the CPS term [x]
+ under weight function [wt]: go through [Positional.eval wt x] by
+ transitivity. The second subgoal is closed by rewriting with [uncps],
+ [push_id] and [push_basesystem_eval]; the first is stripped down to an
+ equation on tuples, whose right-hand side is partially evaluated and
+ then cleaned of pair matches. *)
+Ltac solve_op_mod_eq wt x :=
+ transitivity (Positional.eval wt x); repeat autounfold;
+ [|autorewrite with uncps push_id push_basesystem_eval;
+ reflexivity];
+ cbv [mod_eq]; apply f_equal2; [|reflexivity];
+ apply f_equal;
+ basesystem_partial_evaluation_RHS;
+ do_replace_match_with_destructuring_match_in_goal.
+
+(* Field-level entry point: reduce to a [mod_eq] goal with [F_mod_eq],
+ then run [solve_op_mod_eq]. *)
+Ltac solve_op_F wt x := F_mod_eq; solve_op_mod_eq wt x.
diff --git a/src/Arithmetic/Karatsuba.v b/src/Arithmetic/Karatsuba.v
new file mode 100644
index 000000000..0f20bb238
--- /dev/null
+++ b/src/Arithmetic/Karatsuba.v
@@ -0,0 +1,49 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Algebra.Nsatz.
+Require Import Crypto.Util.ZUtil.
+Local Open Scope Z_scope.
+
+(* Karatsuba-style multiplication over an abstract number type [T]. The
+ section only assumes an evaluation function into [Z] and operations
+ ([sub], [mul], [add], [scmul], [split]) that are correct with respect
+ to it. *)
+Section Karatsuba.
+ Context {T : Type} (eval : T -> Z)
+ (sub : T -> T -> T)
+ (eval_sub : forall x y, eval (sub x y) = eval x - eval y)
+ (mul : T -> T -> T)
+ (eval_mul : forall x y, eval (mul x y) = eval x * eval y)
+ (add : T -> T -> T)
+ (eval_add : forall x y, eval (add x y) = eval x + eval y)
+ (scmul : Z -> T -> T)
+ (eval_scmul : forall c x, eval (scmul c x) = c * eval x)
+ (split : Z -> T -> T * T)
+ (eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval x)
+ .
+
+ (* Split both operands at [s] and form three products: [xy0] from the
+ first halves, [xy2] from the second halves, and one product of the
+ half-sums, from which [xy1] is obtained by subtracting [xy2 + xy0].
+ The result is [s^2 * xy2 + s * xy1 + xy0]. *)
+ Definition karatsuba_mul s (x y : T) : T :=
+ let xab := split s x in
+ let yab := split s y in
+ let xy0 := mul (fst xab) (fst yab) in
+ let xy2 := mul (snd xab) (snd yab) in
+ let xy1 := sub (mul (add (fst xab) (snd xab)) (add (fst yab) (snd yab))) (add xy2 xy0) in
+ add (add (scmul (s^2) xy2) (scmul s xy1)) xy0.
+
+ (* Correctness: for [s <> 0], [karatsuba_mul] evaluates to the product
+ of the evaluations. *)
+ Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
+ eval (karatsuba_mul s x y) = eval x * eval y.
+ Proof using Type*. cbv [karatsuba_mul]; repeat rewrite ?eval_sub, ?eval_mul, ?eval_add, ?eval_scmul.
+ rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. Qed.
+
+
+ (* Variant that folds the [s^2] term away: with [(a, b) = split s xs]
+ and [(c, d) = split s ys] it computes
+ [(a*c + b*d) + s * ((a+b)*(c+d) - a*c)]. This is only correct modulo
+ a [p] satisfying [s^2 = s + 1 (mod p)]; see [goldilocks_mul_correct]. *)
+ Definition goldilocks_mul s (xs ys : T) : T :=
+ let a_b := split s xs in
+ let c_d := split s ys in
+ let ac := mul (fst a_b) (fst c_d) in
+ (add (add ac (mul (snd a_b) (snd c_d)))
+ (scmul s (sub (mul (add (fst a_b) (snd a_b)) (add (fst c_d) (snd c_d))) ac))).
+
+ (* Instances needed for the setoid rewriting modulo [p] in the proof
+ below. *)
+ Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper Z.modulo_equiv_modulo_Proper.
+
+ (* Correctness modulo [p]: for nonzero [p] and [s] with
+ [s^2 mod p = (s+1) mod p], [goldilocks_mul] agrees with the product
+ of the evaluations modulo [p]. *)
+ Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
+ (eval (goldilocks_mul s xs ys)) mod p = (eval xs * eval ys) mod p.
+ Proof using Type*. cbv [goldilocks_mul]; Zmod_to_equiv_modulo.
+ repeat rewrite ?eval_mul, ?eval_add, ?eval_sub, ?eval_scmul, <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify.
+ setoid_rewrite s2_modp.
+ apply f_equal2; nsatz. Qed.
+End Karatsuba.
diff --git a/src/Arithmetic/ModularArithmeticPre.v b/src/Arithmetic/ModularArithmeticPre.v
new file mode 100644
index 000000000..b27ffd16d
--- /dev/null
+++ b/src/Arithmetic/ModularArithmeticPre.v
@@ -0,0 +1,139 @@
+Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.Numbers.BinNums Coq.ZArith.Zdiv Coq.ZArith.Znumtheory.
+Require Import Coq.Logic.Eqdep_dec.
+Require Import Coq.Logic.EqdepFacts.
+Require Import Coq.omega.Omega.
+Require Import Crypto.Util.NumTheoryUtil.
+Require Export Crypto.Util.FixCoqMistakes.
+
+(* [x mod m] is already reduced: reducing it once more modulo [m] changes
+ nothing. Stated for every [m]; the case [m = 0] is handled separately
+ by the first bullet. *)
+Lemma Z_mod_mod x m : x mod m = (x mod m) mod m.
+Proof.
+ symmetry.
+ destruct (BinInt.Z.eq_dec m 0).
+ - subst; rewrite !Zdiv.Zmod_0_r; reflexivity.
+ - apply BinInt.Z.mod_mod; assumption.
+Qed.
+
+(* Two elements of the subset type [{z | z = z mod m}] are equal as soon
+ as their underlying integers are equal; the proof components agree by
+ uniqueness of identity proofs, which holds because equality on [Z] is
+ decidable. *)
+Lemma exist_reduced_eq: forall (m : Z) (a b : Z), a = b -> forall pfa pfb,
+ exist (fun z : Z => z = z mod m) a pfa =
+ exist (fun z : Z => z = z mod m) b pfb.
+Proof.
+ intuition; simpl in *; try congruence.
+ subst.
+ f_equal.
+ eapply UIP_dec, Z.eq_dec.
+Qed.
+
+(* [mulmod m a b]: the product of [a] and [b] reduced modulo [m]. *)
+Definition mulmod m := fun a b => a * b mod m.
+(* Exponentiation by a positive exponent: [Pos.iter_op] over [mulmod m]. *)
+Definition powmod_pos m := Pos.iter_op (mulmod m).
+(* Modular exponentiation with an [N] exponent: exponent [0] gives
+ [1 mod m]; a positive exponent [p] iterates [mulmod m] on [a mod m]. *)
+Definition powmod m a x := match x with N0 => 1 mod m | Npos p => powmod_pos m p (a mod m) end.
+
+(* [mulmod m] is associative; used below as the side condition of
+ [Pos.iter_op_succ]. *)
+Lemma mulmod_assoc:
+ forall m x y z : Z, mulmod m x (mulmod m y z) = mulmod m (mulmod m x y) z.
+Proof.
+ unfold mulmod; intros.
+ rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal.
+ apply Z.mul_assoc.
+Qed.
+
+(* Successor step of [powmod]: raising to the power [1 + x] multiplies
+ the [x]-th power by [a], modulo [m]. The proof splits on whether [x]
+ is zero or positive. *)
+Lemma powmod_1plus:
+ forall m a : Z,
+ forall x : N, powmod m a (1 + x) = (a * (powmod m a x mod m)) mod m.
+Proof.
+ intros m a x.
+ rewrite N.add_1_l.
+ cbv beta delta [powmod N.succ].
+ destruct x. simpl; rewrite ?Zdiv.Zmult_mod_idemp_r, Z.mul_1_r; auto.
+ unfold powmod_pos.
+ rewrite Pos.iter_op_succ by (apply mulmod_assoc).
+ unfold mulmod.
+ rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal.
+Qed.
+
+
+(* Every positive [N] is one plus its predecessor; used to bring
+ [N.pos p] into the shape expected by [powmod_1plus]. *)
+Lemma N_pos_1plus : forall p, (N.pos p = 1 + (N.pred (N.pos p)))%N.
+Proof.
+ intros.
+ rewrite <-N.pos_pred_spec.
+ rewrite N.add_1_l.
+ rewrite N.pos_pred_spec.
+ rewrite N.succ_pred; eauto.
+ discriminate.
+Qed.
+
+(* [powmod] agrees with ordinary integer exponentiation followed by
+ reduction modulo [m]. Proved by Peano induction on the exponent. *)
+Lemma powmod_Zpow_mod : forall m a n, powmod m a n = (a^Z.of_N n) mod m.
+Proof.
+ induction n using N.peano_ind; [auto|].
+ rewrite <-N.add_1_l.
+ rewrite powmod_1plus.
+ rewrite IHn.
+ rewrite Zmod_mod.
+ rewrite N.add_1_l.
+ rewrite N2Z.inj_succ.
+ rewrite Z.pow_succ_r by (apply N2Z.is_nonneg).
+ rewrite ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; f_equal.
+Qed.
+
+Local Obligation Tactic := idtac.
+
+(* Exponentiation on the subset type of reduced integers: the value is
+ [powmod m (proj1_sig a) x]; the obligation shows that this value is
+ itself reduced modulo [m]. *)
+Program Definition pow_impl_sig {m} (a:{z : Z | z = z mod m}) (x:N) : {z : Z | z = z mod m}
+ := powmod m (proj1_sig a) x.
+Next Obligation.
+ intros; destruct x; [simpl; rewrite Zmod_mod; reflexivity|].
+ rewrite N_pos_1plus.
+ rewrite powmod_1plus.
+ rewrite Zmod_mod; reflexivity.
+Qed.
+
+(* Package [pow_impl_sig] with its specification: the power at exponent
+ [0] is [1 mod m], and the power at [1 + x] is [a] times the power at
+ [x], reduced modulo [m]. *)
+Program Definition pow_impl {m} :
+ {pow0
+ : {z : BinNums.Z | z = z mod m} -> BinNums.N -> {z : BinNums.Z | z = z mod m}
+ |
+ forall a : {z : BinNums.Z | z = z mod m},
+ pow0 a 0%N =
+ exist (fun z : BinNums.Z => z = z mod m) (1 mod m) (Z_mod_mod 1 m) /\
+ (forall x : BinNums.N,
+ pow0 a (1 + x)%N =
+ exist (fun z : BinNums.Z => z = z mod m)
+ ((proj1_sig a * proj1_sig (pow0 a x)) mod m)
+ (Z_mod_mod (proj1_sig a * proj1_sig (pow0 a x)) m))} := pow_impl_sig.
+Next Obligation.
+ split; intros; apply exist_reduced_eq;
+ rewrite ?powmod_1plus, ?Zdiv.Zmult_mod_idemp_l, ?Zdiv.Zmult_mod_idemp_r; reflexivity.
+Qed.
+
+(* Candidate modular inverse on reduced integers: [0] is mapped to [0]
+ explicitly, and any other [a] to [a^(m-2) mod m] via [powmod]. That
+ this is an inverse for prime [m] is proved in [inv_impl] using
+ Fermat's little theorem. The explicit zero case presumably matters for
+ [m = 2], where the exponent [m - 2] is [0] and [powmod] would return
+ [1 mod m] -- TODO confirm. *)
+Program Definition mod_inv_sig {m} (a:{z : Z | z = z mod m}) : {z : Z | z = z mod m} :=
+ let (a, _) := a in
+ match a return _ with
+ | 0%Z => 0 (* m = 2 *)
+ | _ => powmod m a (Z.to_N (m-2))
+ end.
+Next Obligation.
+ intros; break_match; rewrite ?powmod_Zpow_mod, ?Zmod_mod, ?Zmod_0_l; reflexivity.
+Qed.
+
+(* Package [mod_inv_sig] with its specification: the inverse of zero is
+ zero, and for prime [m] every nonzero [a] satisfies
+ [inv a * a = 1] modulo [m]. The second part rewrites the product into
+ [a^(m-1) mod m] and concludes with [fermat_little]. *)
+Program Definition inv_impl {m : BinNums.Z} :
+ {inv0 : {z : BinNums.Z | z = z mod m} -> {z : BinNums.Z | z = z mod m} |
+ inv0 (exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m)) =
+ exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m) /\
+ (Znumtheory.prime m ->
+ forall a : {z : BinNums.Z | z = z mod m},
+ a <> exist (fun z : BinNums.Z => z = z mod m) (0 mod m) (Z_mod_mod 0 m) ->
+ exist (fun z : BinNums.Z => z = z mod m)
+ ((proj1_sig (inv0 a) * proj1_sig a) mod m)
+ (Z_mod_mod (proj1_sig (inv0 a) * proj1_sig a) m) =
+ exist (fun z : BinNums.Z => z = z mod m) (1 mod m) (Z_mod_mod 1 m))}
+ := mod_inv_sig.
+Next Obligation.
+ split.
+ { apply exist_reduced_eq; rewrite Zmod_0_l; reflexivity. }
+ intros Hm [a pfa] Ha'. apply exist_reduced_eq.
+ assert (Hm':0 <= m - 2) by (pose proof prime_ge_2 m Hm; omega).
+ assert (Ha:a mod m<>0) by (intro; apply Ha', exist_reduced_eq; congruence).
+ cbv [proj1_sig mod_inv_sig].
+ transitivity ((a*powmod m a (Z.to_N (m - 2))) mod m); [destruct a; f_equal; ring|].
+ rewrite !powmod_Zpow_mod.
+ rewrite Z2N.id by assumption.
+ rewrite Zmult_mod_idemp_r.
+ rewrite <-Z.pow_succ_r by assumption.
+ replace (Z.succ (m - 2)) with (m-1) by omega.
+ rewrite (Zmod_small 1) by omega.
+ apply (fermat_little m Hm a Ha).
+Qed. \ No newline at end of file
diff --git a/src/Arithmetic/ModularArithmeticTheorems.v b/src/Arithmetic/ModularArithmeticTheorems.v
new file mode 100644
index 000000000..990aa9dc8
--- /dev/null
+++ b/src/Arithmetic/ModularArithmeticTheorems.v
@@ -0,0 +1,347 @@
+Require Import Coq.omega.Omega.
+Require Import Crypto.Spec.ModularArithmetic.
+Require Import Crypto.Arithmetic.ModularArithmeticPre.
+
+Require Import Coq.ZArith.BinInt Coq.ZArith.Zdiv Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *)
+Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid.
+Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Ring_tac.
+
+Require Import Crypto.Algebra.Hierarchy Crypto.Algebra.Ring Crypto.Algebra.Field.
+Require Import Crypto.Util.Decidable Crypto.Util.ZUtil.
+Require Export Crypto.Util.FixCoqMistakes.
+
+Module F.
+ (* [unwrap_F]: introduce all hypotheses, destruct every value of type
+    [F _] found in the context, unfold the listed field operations and
+    conversions everywhere, try to turn the goal into an equality of
+    first components via [eqsig_eq], and finish with [pull_Zmod]. *)
+ Ltac unwrap_F :=
+ intros;
+ repeat match goal with [ x : F _ |- _ ] => destruct x end;
+ lazy iota beta delta [F.add F.sub F.mul F.opp F.to_Z F.of_Z proj1_sig] in *;
+ try apply eqsig_eq;
+ pull_Zmod.
+
+ (* FIXME: remove the pose proof once [monoid] no longer contains decidable equality *)
+ Global Instance eq_dec {m} : DecidableRel (@eq (F m)). pose proof dec_eq_Z. exact _. Defined.
+
+ Global Instance commutative_ring_modulo m
+ : @Algebra.Hierarchy.commutative_ring (F m) Logic.eq 0%F 1%F F.opp F.add F.sub F.mul.
+ Proof.
+ repeat (split || intro); unwrap_F;
+ autorewrite with zsimplify; solve [ exact _ | auto with zarith | congruence].
+ Qed.
+
+ Lemma pow_spec {m} a : F.pow a 0%N = 1%F :> F m /\ forall x, F.pow a (1 + x)%N = F.mul a (F.pow a x).
+ Proof. change (@F.pow m) with (proj1_sig (@F.pow_with_spec m)); destruct (@F.pow_with_spec m); eauto. Qed.
+
+ (* FIXME: this instance is [Admitted], i.e. it is currently assumed
+    rather than proved.
+    NOTE(review): the instance is named [char_gt] but it instantiates
+    the class [Ring.char_ge] -- confirm which name is intended. *)
+ Global Instance char_gt {m} :
+ @Ring.char_ge
+ (F m) Logic.eq F.zero F.one F.opp F.add F.sub F.mul
+ m.
+ Proof.
+ Admitted.
+
+ Section FandZ.
+ Context {m:positive}.
+ Local Open Scope F_scope.
+
+ Theorem eq_to_Z_iff (x y : F m) : x = y <-> F.to_Z x = F.to_Z y.
+ Proof using Type. destruct x, y; intuition; simpl in *; try apply (eqsig_eq _ _); congruence. Qed.
+
+ Lemma eq_of_Z_iff : forall x y : Z, x mod m = y mod m <-> F.of_Z m x = F.of_Z m y.
+ Proof using Type. split; unwrap_F; congruence. Qed.
+
+
+ Lemma to_Z_of_Z : forall z, F.to_Z (F.of_Z m z) = z mod m.
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma of_Z_to_Z x : F.of_Z m (F.to_Z x) = x :> F m.
+ Proof using Type. unwrap_F; congruence. Qed.
+
+
+ Lemma of_Z_mod : forall x, F.of_Z m x = F.of_Z m (x mod m).
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma mod_to_Z : forall (x:F m), F.to_Z x mod m = F.to_Z x.
+ Proof using Type. unwrap_F. congruence. Qed.
+
+ Lemma to_Z_0 : F.to_Z (0:F m) = 0%Z.
+ Proof using Type. unwrap_F. apply Zmod_0_l. Qed.
+
+ Lemma of_Z_small_nonzero z : (0 < z < m)%Z -> F.of_Z m z <> 0.
+ Proof using Type. intros Hrange Hnz. inversion Hnz. rewrite Zmod_small, Zmod_0_l in *; omega. Qed.
+
+ Lemma to_Z_nonzero (x:F m) : x <> 0 -> F.to_Z x <> 0%Z.
+ Proof using Type. intros Hnz Hz. rewrite <- Hz, of_Z_to_Z in Hnz; auto. Qed.
+
+ Lemma to_Z_range (x : F m) : 0 < m -> 0 <= F.to_Z x < m.
+ Proof using Type. intros. rewrite <- mod_to_Z. apply Z.mod_pos_bound. trivial. Qed.
+
+ Lemma to_Z_nonzero_range (x : F m) : (x <> 0) -> 0 < m -> (1 <= F.to_Z x < m)%Z.
+ Proof using Type.
+ unfold not; intros Hnz Hlt.
+ rewrite eq_to_Z_iff, to_Z_0 in Hnz; pose proof (to_Z_range x Hlt).
+ omega.
+ Qed.
+
+ Lemma of_Z_add : forall (x y : Z),
+ F.of_Z m (x + y) = F.of_Z m x + F.of_Z m y.
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma to_Z_add : forall x y : F m,
+ F.to_Z (x + y) = ((F.to_Z x + F.to_Z y) mod m)%Z.
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma of_Z_mul x y : F.of_Z m (x * y) = F.of_Z _ x * F.of_Z _ y :> F m.
+ Proof using Type. unwrap_F. trivial. Qed.
+
+ Lemma to_Z_mul : forall x y : F m,
+ F.to_Z (x * y) = ((F.to_Z x * F.to_Z y) mod m)%Z.
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma of_Z_sub x y : F.of_Z _ (x - y) = F.of_Z _ x - F.of_Z _ y :> F m.
+ Proof using Type. unwrap_F. trivial. Qed.
+
+ Lemma to_Z_opp : forall x : F m, F.to_Z (F.opp x) = (- F.to_Z x) mod m.
+ Proof using Type. unwrap_F; trivial. Qed.
+
+ Lemma of_Z_pow x n : F.of_Z _ x ^ n = F.of_Z _ (x ^ (Z.of_N n) mod m) :> F m.
+ Proof using Type.
+ intros.
+ induction n using N.peano_ind;
+ destruct (pow_spec (F.of_Z m x)) as [pow_0 pow_succ] . {
+ rewrite pow_0.
+ unwrap_F; trivial.
+ } {
+ rewrite N2Z.inj_succ.
+ rewrite Z.pow_succ_r by apply N2Z.is_nonneg.
+ rewrite <- N.add_1_l.
+ rewrite pow_succ.
+ rewrite IHn.
+ unwrap_F; trivial.
+ }
+ Qed.
+
+ Lemma to_Z_pow : forall (x : F m) n,
+ F.to_Z (x ^ n)%F = (F.to_Z x ^ Z.of_N n mod m)%Z.
+ Proof using Type.
+ intros.
+ symmetry.
+ induction n using N.peano_ind;
+ destruct (pow_spec x) as [pow_0 pow_succ] . {
+ rewrite pow_0, Z.pow_0_r; auto.
+ } {
+ rewrite N2Z.inj_succ.
+ rewrite Z.pow_succ_r by apply N2Z.is_nonneg.
+ rewrite <- N.add_1_l.
+ rewrite pow_succ.
+ rewrite <- Zmult_mod_idemp_r.
+ rewrite IHn.
+ apply to_Z_mul.
+ }
+ Qed.
+
+ Lemma square_iff (x:F m) :
+ (exists y : F m, y * y = x) <-> (exists y : Z, y * y mod m = F.to_Z x)%Z.
+ Proof using Type.
+ setoid_rewrite eq_to_Z_iff; setoid_rewrite to_Z_mul; split; intro H; destruct H as [x' H].
+ - eauto.
+ - exists (F.of_Z _ x'); rewrite !to_Z_of_Z; pull_Zmod; auto.
+ Qed.
+ End FandZ.
+
+ Section FandNat.
+ Import NPeano Nat.
+ Local Infix "mod" := modulo : nat_scope.
+ Local Open Scope nat_scope.
+
+ Context {m:BinPos.positive}.
+
+ Lemma to_nat_of_nat (n:nat) : F.to_nat (F.of_nat m n) = (n mod (Z.to_nat m))%nat.
+ Proof using Type.
+ unfold F.to_nat, F.of_nat.
+ rewrite F.to_Z_of_Z.
+ assert (Pos.to_nat m <> 0)%nat as HA by (pose proof Pos2Nat.is_pos m; omega).
+ pose proof (mod_Zmod n (Pos.to_nat m) HA) as Hmod.
+ rewrite positive_nat_Z in Hmod.
+ rewrite <- Hmod.
+ rewrite <-Nat2Z.id, Z2Nat.inj_pos; omega.
+ Qed.
+
+ Lemma of_nat_to_nat x : F.of_nat m (F.to_nat x) = x.
+ Proof using Type.
+
+ unfold F.to_nat, F.of_nat.
+ rewrite Z2Nat.id; [ eapply F.of_Z_to_Z | eapply F.to_Z_range; reflexivity].
+ Qed.
+
+ (* The natural number obtained from a positive is never zero: it is
+    strictly positive by [Pos2Nat.is_pos]. *)
+ Lemma Pos_to_nat_nonzero p : Pos.to_nat p <> 0%nat.
+ Proof using Type. pose proof (Pos2Nat.is_pos p); omega. Qed.
+
+ Lemma of_nat_mod (n:nat) : F.of_nat m (n mod (Z.to_nat m)) = F.of_nat m n.
+ Proof using Type.
+ unfold F.of_nat.
+ rewrite (F.of_Z_mod (Z.of_nat n)), ?mod_Zmod, ?Z2Nat.id; [reflexivity|..].
+ { apply Pos2Z.is_nonneg. }
+ { rewrite Z2Nat.inj_pos. apply Pos_to_nat_nonzero. }
+ Qed.
+
+ Lemma to_nat_mod (x:F m) (Hm:(0 < m)%Z) : F.to_nat x mod (Z.to_nat m) = F.to_nat x.
+ Proof using Type.
+
+ unfold F.to_nat.
+ rewrite <-F.mod_to_Z at 2.
+ apply Z.mod_to_nat; [assumption|].
+ apply F.to_Z_range; assumption.
+ Qed.
+
+ Lemma of_nat_add x y :
+ F.of_nat m (x + y) = (F.of_nat m x + F.of_nat m y)%F.
+ Proof using Type. unfold F.of_nat; rewrite Nat2Z.inj_add, F.of_Z_add; reflexivity. Qed.
+
+ Lemma of_nat_mul x y :
+ F.of_nat m (x * y) = (F.of_nat m x * F.of_nat m y)%F.
+ Proof using Type. unfold F.of_nat; rewrite Nat2Z.inj_mul, F.of_Z_mul; reflexivity. Qed.
+ End FandNat.
+
+ Section RingTacticGadgets.
+ Context (m:positive).
+
+ Definition ring_theory : ring_theory 0%F 1%F (@F.add m) (@F.mul m) (@F.sub m) (@F.opp m) eq
+ := Algebra.Ring.ring_theory_for_stdlib_tactic.
+
+ Lemma pow_pow_N (x : F m) : forall (n : N), (x ^ id n)%F = pow_N 1%F F.mul x n.
+ Proof using Type.
+ destruct (pow_spec x) as [HO HS]; intros.
+ destruct n; auto; unfold id.
+ rewrite ModularArithmeticPre.N_pos_1plus at 1.
+ rewrite HS.
+ simpl.
+ induction p using Pos.peano_ind.
+ - simpl. rewrite HO. apply Algebra.Hierarchy.right_identity.
+ - rewrite (@pow_pos_succ (F m) (@F.mul m) eq _ _ associative x).
+ rewrite <-IHp, Pos.pred_N_succ, ModularArithmeticPre.N_pos_1plus, HS.
+ trivial.
+ Qed.
+
+ Lemma power_theory : power_theory 1%F (@F.mul m) eq id (@F.pow m).
+ Proof using Type. split; apply pow_pow_N. Qed.
+
+ (***** Division Theory *****)
+ (* [quotrem a b]: run [Z.quotrem] on the integer representatives
+    [F.to_Z a] and [F.to_Z b], then inject the resulting quotient and
+    remainder back into [F m] with [F.of_Z]. *)
+ Definition quotrem(a b: F m): F m * F m :=
+ let '(q, r) := (Z.quotrem (F.to_Z a) (F.to_Z b)) in (F.of_Z _ q , F.of_Z _ r).
+ Lemma div_theory : div_theory eq (@F.add m) (@F.mul m) (@id _) quotrem.
+ Proof using Type.
+ constructor; intros; unfold quotrem, id.
+
+ replace (Z.quotrem (F.to_Z a) (F.to_Z b)) with (Z.quot (F.to_Z a) (F.to_Z b), Z.rem (F.to_Z a) (F.to_Z b)) by
+ try (unfold Z.quot, Z.rem; rewrite <- surjective_pairing; trivial).
+
+ unwrap_F; rewrite <-Z.quot_rem'; trivial.
+ Qed.
+
+ (* Define a "ring morphism" between GF and Z, i.e. an equivalence
+ * between 'inject (ZFunction (X))' and 'GFFunction (inject (X))'.
+ *
+ * Doing this allows the [ring] tactic to do coefficient
+ * manipulations in Z rather than F, because we know it's equivalent
+ * to inject the result afterward. *)
+ Lemma ring_morph: ring_morph 0%F 1%F F.add F.mul F.sub F.opp eq
+ 0%Z 1%Z Z.add Z.mul Z.sub Z.opp Z.eqb (F.of_Z m).
+ Proof using Type. split; intros; unwrap_F; solve [ auto | rewrite (proj1 (Z.eqb_eq x y)); trivial]. Qed.
+
+ (* Redefine our division theory under the ring morphism *)
+ Lemma morph_div_theory:
+ Ring_theory.div_theory eq Zplus Zmult (F.of_Z m) Z.quotrem.
+ Proof using Type.
+ split; intros.
+ replace (Z.quotrem a b) with (Z.quot a b, Z.rem a b);
+ try (unfold Z.quot, Z.rem; rewrite <- surjective_pairing; trivial).
+ unwrap_F; rewrite <- (Z.quot_rem' a b); trivial.
+ Qed.
+
+ End RingTacticGadgets.
+
+ (* Constant recognizers passed to the [Add Ring] declarations in this
+    file.  [is_constant] returns [x] for a term of the form [F.of_Z _ x]
+    and [NotConstant] for anything else; [is_pow_constant] defers to
+    [Ncst]. *)
+ Ltac is_constant t := match t with F.of_Z _ ?x => x | _ => NotConstant end.
+ Ltac is_pow_constant t := Ncst t.
+
+ Section VariousModulo.
+ Context {m:positive}.
+ Local Open Scope F_scope.
+
+ Add Ring _theory : (ring_theory m)
+ (morphism (ring_morph m),
+ constants [is_constant],
+ div (morph_div_theory m),
+ power_tac (power_theory m) [is_pow_constant]).
+
+ Lemma mul_nonzero_l : forall a b : F m, a*b <> 0 -> a <> 0.
+ Proof using Type. intros a b Hnz Hz. rewrite Hz in Hnz; apply Hnz; ring. Qed.
+
+ Lemma mul_nonzero_r : forall a b : F m, a*b <> 0 -> b <> 0.
+ Proof using Type. intros a b Hnz Hz. rewrite Hz in Hnz; apply Hnz; ring. Qed.
+ End VariousModulo.
+
+ Section Pow.
+ Context {m:positive}.
+ Add Ring _theory' : (ring_theory m)
+ (morphism (ring_morph m),
+ constants [is_constant],
+ div (morph_div_theory m),
+ power_tac (power_theory m) [is_pow_constant]).
+ Local Open Scope F_scope.
+
+ Import Algebra.ScalarMult.
+ (* Exponentiation by [N.of_nat n] is a scalar multiplication for the
+    structure on [F m] whose operation is [F.mul] and whose identity
+    is [1]. *)
+ Global Instance pow_is_scalarmult
+ : is_scalarmult (G:=F m) (eq:=eq) (add:=F.mul) (zero:=1%F) (mul := fun n x => x ^ (N.of_nat n)).
+ Proof using Type.
+ split; intros; rewrite ?Nat2N.inj_succ, <-?N.add_1_l;
+ match goal with
+ (* Destruct the spec for the hypothesis bound by the pattern, not
+    for a name that [intros] happened to auto-generate. *)
+ | [x:F m |- _ ] => solve [destruct (@pow_spec m x); auto]
+ | |- Proper _ _ => solve_proper
+ end.
+ Qed.
+
+ (* TODO: move this somewhere? *)
+ Create HintDb nat2N discriminated.
+ Hint Rewrite Nat2N.inj_iff
+ (eq_refl _ : (0%N = N.of_nat 0))
+ (eq_refl _ : (1%N = N.of_nat 1))
+ (eq_refl _ : (2%N = N.of_nat 2))
+ (eq_refl _ : (3%N = N.of_nat 3))
+ : nat2N.
+ Hint Rewrite <- Nat2N.inj_double Nat2N.inj_succ_double Nat2N.inj_succ
+ Nat2N.inj_add Nat2N.inj_mul Nat2N.inj_sub Nat2N.inj_pred
+ Nat2N.inj_div2 Nat2N.inj_max Nat2N.inj_min Nat2N.id
+ : nat2N.
+
+ (* [pow_to_scalarmult_ref] repeatedly normalizes with the [nat2N]
+    rewrite database and then either
+    - replaces an [N] exponent [n] by [N.of_nat] of a freshly
+      introduced natural number, or
+    - rewrites a power with exponent [N.of_nat n] using
+      [scalarmult_ext]. *)
+ Ltac pow_to_scalarmult_ref :=
+ repeat (autorewrite with nat2N;
+ match goal with
+ | |- context [ (_^?n)%F ] =>
+ rewrite <-(N2Nat.id n); generalize (N.to_nat n); clear n;
+ let m := fresh n in intro m
+ | |- context [ (_^N.of_nat ?n)%F ] =>
+ let rw := constr:(scalarmult_ext(zero:=F.of_Z m 1) n) in
+ setoid_rewrite rw (* rewriting modulo a reduction *)
+ end).
+
+ Lemma pow_0_r (x:F m) : x^0 = 1.
+ Proof using Type. pow_to_scalarmult_ref. apply scalarmult_0_l. Qed.
+
+ Lemma pow_add_r (x:F m) (a b:N) : x^(a+b) = x^a * x^b.
+ Proof using Type. pow_to_scalarmult_ref; apply scalarmult_add_l. Qed.
+
+ Lemma pow_0_l (n:N) : n <> 0%N -> 0^n = 0 :> F m.
+ Proof using Type. pow_to_scalarmult_ref; destruct n; simpl; intros; [congruence|ring]. Qed.
+
+ Lemma pow_pow_l (x:F m) (a b:N) : (x^a)^b = x^(a*b).
+ Proof using Type. pow_to_scalarmult_ref. apply scalarmult_assoc. Qed.
+
+ Lemma pow_1_r (x:F m) : x^1 = x.
+ Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+
+ Lemma pow_2_r (x:F m) : x^2 = x*x.
+ Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+
+ Lemma pow_3_r (x:F m) : x^3 = x*x*x.
+ Proof using Type. pow_to_scalarmult_ref; simpl; ring. Qed.
+ End Pow.
+End F.
diff --git a/src/Arithmetic/MontgomeryReduction/Definition.v b/src/Arithmetic/MontgomeryReduction/Definition.v
new file mode 100644
index 000000000..78d3c037f
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/Definition.v
@@ -0,0 +1,179 @@
+(*** Montgomery Multiplication *)
+(** This file implements Montgomery Form, Montgomery Reduction, and
+ Montgomery Multiplication on [Z]. We follow Wikipedia. *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope Z_scope.
+Delimit Scope montgomery_naive_scope with montgomery_naive.
+Delimit Scope montgomery_scope with montgomery.
+Definition montgomeryZ := Z.
+Bind Scope montgomery_scope with montgomeryZ.
+
+Section montgomery.
+ Context (N : Z)
+ (R : Z)
+ (R' : Z). (* R' is R⁻¹ mod N *)
+ Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
+ Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
+ (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Montgomery_modular_multiplication>: *)
+ (** In modular arithmetic computation, Montgomery modular
+ multiplication, more commonly referred to as Montgomery
+ multiplication, is a method for performing fast modular
+ multiplication, introduced in 1985 by the American mathematician
+ Peter L. Montgomery. *)
+ (** Given two integers [a] and [b], the classical modular
+ multiplication algorithm computes [ab mod N]. Montgomery
+ multiplication works by transforming [a] and [b] into a special
+ representation known as Montgomery form. For a modulus [N], the
+ Montgomery form of [a] is defined to be [aR mod N] for some
+ constant [R] depending only on [N] and the underlying computer
+ architecture. If [aR mod N] and [bR mod N] are the Montgomery
+ forms of [a] and [b], then their Montgomery product is [abR mod
+ N]. Montgomery multiplication is a fast algorithm to compute the
+ Montgomery product. Transforming the result out of Montgomery
+ form yields the classical modular product [ab mod N]. *)
+
+ Definition to_montgomery_naive (x : Z) : montgomeryZ := x * R.
+ Definition from_montgomery_naive (x : montgomeryZ) : Z := x * R'.
+
+ (** * Modular arithmetic and Montgomery form *)
+ Section general.
+ (** Addition and subtraction in Montgomery form are the same as
+ ordinary modular addition and subtraction because of the
+ distributive law: *)
+ (** [aR + bR = (a+b)R], *)
+ (** [aR - bR = (a-b)R]. *)
+ (** This is a consequence of the fact that, because [gcd(R, N) =
+ 1], multiplication by [R] is an isomorphism on the additive
+ group [ℤ/Nℤ]. *)
+
+ Definition add : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR + bR.
+ Definition sub : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR - bR.
+
+ (** Multiplication in Montgomery form, however, is seemingly more
+ complicated. The usual product of [aR] and [bR] does not
+ represent the product of [a] and [b] because it has an extra
+ factor of R: *)
+ (** [(aR mod N)(bR mod N) mod N = (abR)R mod N]. *)
+ (** Computing products in Montgomery form requires removing the
+ extra factor of [R]. While division by [R] is cheap, the
+ intermediate product [(aR mod N)(bR mod N)] is not divisible
+ by [R] because the modulo operation has destroyed that
+ property. *)
+ (** Removing the extra factor of R can be done by multiplying by
+ an integer [R′] such that [RR' ≡ 1 (mod N)], that is, by an
+ [R′] whose residue class is the modular inverse of [R] mod
+ [N]. Then, working modulo [N], *)
+ (** [(aR mod N)(bR mod N)R' ≡ (aR)(bR)R⁻¹ ≡ (ab)R (mod N)]. *)
+
+ Definition mul_naive : montgomeryZ -> montgomeryZ -> montgomeryZ
+ := fun aR bR => aR * bR * R'.
+ End general.
+
+ (** * The REDC algorithm *)
+ Section redc.
+ (** While the above algorithm is correct, it is slower than
+ multiplication in the standard representation because of the
+ need to multiply by [R′] and divide by [N]. Montgomery
+ reduction, also known as REDC, is an algorithm that
+ simultaneously computes the product by [R′] and reduces modulo
+ [N] more quickly than the naive method. The speed is because
+ all computations are done using only reduction and divisions
+ with respect to [R], not [N]: *)
+ (**
+<<
+function REDC is
+ input: Integers R and N with gcd(R, N) = 1,
+ Integer N′ in [0, R − 1] such that NN′ ≡ −1 mod R,
+ Integer T in the range [0, RN − 1]
+ output: Integer S in the range [0, N − 1] such that S ≡ TR⁻¹ mod N
+
+ m ← ((T mod R)N′) mod R
+ t ← (T + mN) / R
+ if t ≥ N then
+ return t − N
+ else
+ return t
+ end if
+end function
+>> *)
+ Context (N' : Z). (* N' is (-N⁻¹) mod R *)
+ Section redc.
+ Context (T : Z).
+
+ Let m := ((T mod R) * N') mod R.
+ Let t := (T + m * N) / R.
+ Definition prereduce : montgomeryZ := t.
+
+ (* Partial reduction: subtract [N] from the pre-reduced value only
+    when [t] is at least [R].  Note the comparison is against [R],
+    whereas [reduce] compares against [N]. *)
+ Definition partial_reduce : montgomeryZ
+ := if R <=? t then
+ prereduce - N
+ else
+ prereduce.
+
+ (* Alternative formulation of partial reduction.  With
+    [v0 = T + m * N] and [v = (v0 mod (R * R)) / R], the result is
+    [(v - N) mod R] when [v0] is at least [R * R], and [v] otherwise. *)
+ Definition partial_reduce_alt : montgomeryZ
+ := let v0 := (T + m * N) in
+ let v := (v0 mod (R * R)) / R in
+ if R * R <=? v0 then
+ (v - N) mod R
+ else
+ v.
+
+ Definition reduce : montgomeryZ
+ := if N <=? t then
+ prereduce - N
+ else
+ prereduce.
+
+ Definition reduce_via_partial : montgomeryZ
+ := if N <=? partial_reduce then
+ partial_reduce - N
+ else
+ partial_reduce.
+
+ (* Full reduction built on [partial_reduce_alt]: subtract [N] once
+    more when the partially reduced value is still at least [N]. *)
+ Definition reduce_via_partial_alt : montgomeryZ
+ := if N <=? partial_reduce_alt then
+ partial_reduce_alt - N
+ else
+ partial_reduce_alt.
+ End redc.
+
+ (** * Arithmetic in Montgomery form *)
+ Section arithmetic.
+ (** Many operations of interest modulo [N] can be expressed
+ equally well in Montgomery form. Addition, subtraction,
+ negation, comparison for equality, multiplication by an
+ integer not in Montgomery form, and greatest common divisors
+ with [N] may all be done with the standard algorithms. *)
+ (** When [R > N], most other arithmetic operations can be
+ expressed in terms of REDC. This assumption implies that the
+ product of two representatives [mod N] is less than [RN],
+ the exact hypothesis necessary for REDC to generate correct
+ output. In particular, the product of [aR mod N] and [bR mod
+ N] is [REDC((aR mod N)(bR mod N))]. The combined operation
+ of multiplication and REDC is often called Montgomery
+ multiplication. *)
+ Definition mul : montgomeryZ -> montgomeryZ -> montgomeryZ
+ := fun aR bR => reduce (aR * bR).
+
+ (** Conversion into Montgomery form is done by computing
+ [REDC((a mod N)(R² mod N))]. Conversion out of Montgomery
+ form is done by computing [REDC(aR mod N)]. The modular
+ inverse of [aR mod N] is [REDC((aR mod N)⁻¹(R³ mod
+ N))]. Modular exponentiation can be done using
+ exponentiation by squaring by initializing the initial
+ product to the Montgomery representation of 1, that is, to
+ [R mod N], and by replacing the multiply and square steps by
+ Montgomery multiplies. *)
+ Definition to_montgomery (a : Z) : montgomeryZ := reduce (a * (R*R mod N)).
+ Definition from_montgomery (aR : montgomeryZ) : Z := reduce aR.
+ End arithmetic.
+ End redc.
+End montgomery.
+
+Infix "+" := add : montgomery_scope.
+Infix "-" := sub : montgomery_scope.
+Infix "*" := mul_naive : montgomery_naive_scope.
+Infix "*" := mul : montgomery_scope.
diff --git a/src/Arithmetic/MontgomeryReduction/Proofs.v b/src/Arithmetic/MontgomeryReduction/Proofs.v
new file mode 100644
index 000000000..d5de00213
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/Proofs.v
@@ -0,0 +1,296 @@
+(*** Montgomery Multiplication *)
+(** This file implements the proofs for Montgomery Form, Montgomery
+ Reduction, and Montgomery Multiplication on [Z]. We follow
+ Wikipedia. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.Structures.Equalities.
+Require Import Crypto.Arithmetic.MontgomeryReduction.Definition.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.SimplifyRepeatedIfs.
+Require Import Crypto.Util.Notations.
+
+Declare Module Nop : Nop.
+Module Import ImportEquivModuloInstances := Z.EquivModuloInstances Nop.
+
+Local Existing Instance eq_Reflexive. (* speed up setoid_rewrite as per https://coq.inria.fr/bugs/show_bug.cgi?id=4978 *)
+
+Local Open Scope Z_scope.
+
+Section montgomery.
+ Context (N : Z)
+ (N_reasonable : N <> 0)
+ (R : Z)
+ (R_good : Z.gcd N R = 1).
+ Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
+ Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
+ Context (R' : Z)
+ (R'_good : R * R' ≡ 1).
+
+ Lemma R'_good' : R' * R ≡ 1.
+ Proof using R'_good. rewrite <- R'_good; apply f_equal2; lia. Qed.
+
+ Local Notation to_montgomery_naive := (to_montgomery_naive R) (only parsing).
+ Local Notation from_montgomery_naive := (from_montgomery_naive R') (only parsing).
+
+ Lemma to_from_montgomery_naive x : to_montgomery_naive (from_montgomery_naive x) ≡ x.
+ Proof using R'_good.
+ unfold to_montgomery_naive, from_montgomery_naive.
+ rewrite <- Z.mul_assoc, R'_good'.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+ Lemma from_to_montgomery_naive x : from_montgomery_naive (to_montgomery_naive x) ≡ x.
+ Proof using R'_good.
+ unfold to_montgomery_naive, from_montgomery_naive.
+ rewrite <- Z.mul_assoc, R'_good.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+
+ (** * Modular arithmetic and Montgomery form *)
+ Section general.
+ Local Infix "+" := add : montgomery_scope.
+ Local Infix "-" := sub : montgomery_scope.
+ Local Infix "*" := (mul_naive R') : montgomery_scope.
+
+ Lemma add_correct_naive x y : from_montgomery_naive (x + y) = from_montgomery_naive x + from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, add; lia. Qed.
+ Lemma add_correct_naive_to x y : to_montgomery_naive (x + y) = (to_montgomery_naive x + to_montgomery_naive y)%montgomery.
+ Proof using Type. unfold to_montgomery_naive, add; autorewrite with push_Zmul; reflexivity. Qed.
+ Lemma sub_correct_naive x y : from_montgomery_naive (x - y) = from_montgomery_naive x - from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, sub; lia. Qed.
+ Lemma sub_correct_naive_to x y : to_montgomery_naive (x - y) = (to_montgomery_naive x - to_montgomery_naive y)%montgomery.
+ Proof using Type. unfold to_montgomery_naive, sub; autorewrite with push_Zmul; reflexivity. Qed.
+
+ Theorem mul_correct_naive x y : from_montgomery_naive (x * y) = from_montgomery_naive x * from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, mul_naive; lia. Qed.
+ Theorem mul_correct_naive_to x y : to_montgomery_naive (x * y) ≡ (to_montgomery_naive x * to_montgomery_naive y)%montgomery.
+ Proof using R'_good.
+ unfold to_montgomery_naive, mul_naive.
+ rewrite <- !Z.mul_assoc, R'_good.
+ autorewrite with zsimplify; apply (f_equal2 Z.modulo); lia.
+ Qed.
+ End general.
+
+ (** * The REDC algorithm *)
+ Section redc.
+ Context (N' : Z)
+ (N'_in_range : 0 <= N' < R)
+ (N'_good : N * N' ≡ᵣ -1).
+
+ Lemma N'_good' : N' * N ≡ᵣ -1.
+ Proof using N'_good. rewrite <- N'_good; apply f_equal2; lia. Qed.
+
+ Lemma N'_good'_alt x : (((x mod R) * (N' mod R)) mod R) * (N mod R) ≡ᵣ x * -1.
+ Proof using N'_good.
+ rewrite <- N'_good', Z.mul_assoc.
+ unfold Z.equiv_modulo; push_Zmod.
+ reflexivity.
+ Qed.
+
+ Section redc.
+ Context (T : Z).
+
+ Local Notation m := (((T mod R) * N') mod R).
+ Local Notation prereduce := (prereduce N R N').
+
+ Local Ltac t_fin_correct :=
+ unfold Z.equiv_modulo; push_Zmod; autorewrite with zsimplify; reflexivity.
+
+ Lemma prereduce_correct : prereduce T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ transitivity ((T + m * N) * R').
+ { unfold prereduce.
+ autorewrite with zstrip_div; push_Zmod.
+ rewrite N'_good'_alt.
+ autorewrite with zsimplify pull_Zmod.
+ reflexivity. }
+ t_fin_correct.
+ Qed.
+
+ Lemma reduce_correct : reduce N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold reduce.
+ break_match; rewrite prereduce_correct; t_fin_correct.
+ Qed.
+
+ Lemma partial_reduce_correct : partial_reduce N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold partial_reduce.
+ break_match; rewrite prereduce_correct; t_fin_correct.
+ Qed.
+
+ Lemma reduce_via_partial_correct : reduce_via_partial N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold reduce_via_partial.
+ break_match; rewrite partial_reduce_correct; t_fin_correct.
+ Qed.
+
+ Let m_small : 0 <= m < R. Proof. auto with zarith. Qed.
+
+ Section generic.
+ (* Generic range bound, parameterized by [B]: if [N] is nonnegative
+    and [T] lies in [0, R * B], then the pre-reduced value lies in
+    [0, B + N).  The range lemmas that follow in this file
+    instantiate [B]. *)
+ Lemma prereduce_in_range_gen B
+ : 0 <= N
+ -> 0 <= T <= R * B
+ -> 0 <= prereduce T < B + N.
+ Proof using N_reasonable m_small. unfold prereduce; auto with zarith nia. Qed.
+ End generic.
+
+ Section N_very_small.
+ Context (N_very_small : 0 <= 4 * N < R).
+
+ Lemma prereduce_in_range_very_small
+ : 0 <= T <= (2 * N - 1) * (2 * N - 1)
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable N_very_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+ End N_very_small.
+
+ Section N_small.
+ Context (N_small : 0 <= 2 * N < R).
+
+ Lemma prereduce_in_range_small
+ : 0 <= T <= (2 * N - 1) * (N - 1)
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+
+ Lemma prereduce_in_range_small_fully_reduced
+ : 0 <= T <= 2 * N
+ -> 0 <= prereduce T <= N.
+ Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen 1); nia. Qed.
+ End N_small.
+
+ Section N_small_enough.
+ Context (N_small_enough : 0 <= N < R).
+
+ Lemma prereduce_in_range_small_enough
+ : 0 <= T <= R * R
+ -> 0 <= prereduce T < R + N.
+ Proof using N_reasonable N_small_enough m_small. pose proof (prereduce_in_range_gen R); nia. Qed.
+
+ Lemma reduce_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= reduce N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma partial_reduce_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= partial_reduce N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma reduce_via_partial_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= reduce_via_partial N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold reduce_via_partial, partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+ End N_small_enough.
+
+ Section unconstrained.
+ Lemma prereduce_in_range
+ : 0 <= T <= R * N
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+
+ Lemma reduce_in_range
+ : 0 <= T <= R * N
+ -> 0 <= reduce N R N' T < N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (prereduce_in_range H).
+ unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma partial_reduce_in_range
+ : 0 <= T <= R * N
+ -> Z.min 0 (R - N) <= partial_reduce N R N' T < 2 * N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (prereduce_in_range H).
+ unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt;
+ apply Z.min_case_strong; nia.
+ Qed.
+
+ Lemma reduce_via_partial_in_range
+ : 0 <= T <= R * N
+ -> Z.min 0 (R - N) <= reduce_via_partial N R N' T < N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (partial_reduce_in_range H).
+ unfold reduce_via_partial in *; break_match; Z.ltb_to_lt; lia.
+ Qed.
+ End unconstrained.
+
+ Section alt.
+ Context (N_in_range : 0 <= N < R)
+ (T_representable : 0 <= T < R * R).
+ (* Under the section hypotheses [0 <= N < R] and [0 <= T < R * R],
+    the alternative formulation agrees with [partial_reduce].  The
+    proof first collects bounds on [T + m * N] and its quotient by
+    [R], rewrites the [mod (R * R)] as a conditional subtraction, and
+    then does a case analysis on the remaining comparisons. *)
+ Lemma partial_reduce_alt_eq : partial_reduce_alt N R N' T = partial_reduce N R N' T.
+ Proof using N_in_range N_reasonable T_representable m_small.
+ assert (0 <= T + m * N < 2 * (R * R)) by nia.
+ assert (0 <= T + m * N < R * (R + N)) by nia.
+ assert (0 <= (T + m * N) / R < R + N) by auto with zarith.
+ assert ((T + m * N) / R - N < R) by lia.
+ assert (R * R <= T + m * N -> R <= (T + m * N) / R) by auto with zarith.
+ assert (T + m * N < R * R -> (T + m * N) / R < R) by auto with zarith.
+ assert (H' : (T + m * N) mod (R * R) = if R * R <=? T + m * N then T + m * N - R * R else T + m * N)
+ by (break_match; Z.ltb_to_lt; autorewrite with zsimplify; lia).
+ unfold partial_reduce, partial_reduce_alt, prereduce.
+ rewrite H'; clear H'.
+ simplify_repeated_ifs.
+ set (m' := m) in *.
+ autorewrite with zsimplify; push_Zmod; autorewrite with zsimplify; pull_Zmod.
+ break_match; Z.ltb_to_lt; autorewrite with zsimplify; try reflexivity; lia.
+ Qed.
+ End alt.
+ End redc.
+
+ (** * Arithmetic in Montgomery form *)
+ Section arithmetic.
+ Local Infix "*" := (mul N R N') : montgomery_scope.
+
+ Local Notation to_montgomery := (to_montgomery N R N').
+ Local Notation from_montgomery := (from_montgomery N R N').
+ Lemma to_from_montgomery a : to_montgomery (from_montgomery a) ≡ a.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, from_montgomery.
+ transitivity ((a * 1) * 1); [ | apply f_equal2; lia ].
+ rewrite <- !R'_good, !reduce_correct.
+ unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ apply f_equal2; lia.
+ Qed.
+ Lemma from_to_montgomery a : from_montgomery (to_montgomery a) ≡ a.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, from_montgomery.
+ rewrite !reduce_correct.
+ transitivity (a * ((R * (R * R' mod N) * R') mod N)).
+ { unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ apply f_equal2; lia. }
+ { repeat first [ rewrite R'_good
+ | reflexivity
+ | push_Zmod; pull_Zmod; progress autorewrite with zsimplify
+ | progress unfold Z.equiv_modulo ]. }
+ Qed.
+
+ Theorem mul_correct x y : from_montgomery (x * y) ≡ from_montgomery x * from_montgomery y.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold from_montgomery, mul.
+ rewrite !reduce_correct; apply f_equal2; lia.
+ Qed.
+ Theorem mul_correct_to x y : to_montgomery (x * y) ≡ (to_montgomery x * to_montgomery y)%montgomery.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, mul.
+ rewrite !reduce_correct.
+ transitivity (x * y * R * 1 * 1 * 1);
+ [ rewrite <- R'_good at 1
+ | rewrite <- R'_good at 1 2 3 ];
+ autorewrite with zsimplify;
+ unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ { apply f_equal2; lia. }
+ { apply f_equal2; lia. }
+ Qed.
+ End arithmetic.
+ End redc.
+End montgomery.
+
+Module Import LocalizeEquivModuloInstances := Z.RemoveEquivModuloInstances Nop.
diff --git a/src/Arithmetic/PrimeFieldTheorems.v b/src/Arithmetic/PrimeFieldTheorems.v
new file mode 100644
index 000000000..c253752c5
--- /dev/null
+++ b/src/Arithmetic/PrimeFieldTheorems.v
@@ -0,0 +1,294 @@
+Require Export Crypto.Spec.ModularArithmetic.
+Require Export Crypto.Arithmetic.ModularArithmeticTheorems.
+Require Export Coq.setoid_ring.Ring_theory Coq.setoid_ring.Field_theory Coq.setoid_ring.Field_tac.
+
+Require Import Coq.nsatz.Nsatz.
+Require Import Crypto.Arithmetic.ModularArithmeticPre.
+Require Import Crypto.Util.NumTheoryUtil.
+Require Import Coq.Classes.Morphisms Coq.Setoids.Setoid.
+Require Import Coq.ZArith.BinInt Coq.NArith.BinNat Coq.ZArith.ZArith Coq.ZArith.Znumtheory Coq.NArith.NArith. (* import Zdiv before Znumtheory *)
+Require Import Coq.Logic.Eqdep_dec.
+Require Import Crypto.Util.NumTheoryUtil Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.SpecializeBy.
+Require Import Crypto.Util.Decidable.
+Require Export Crypto.Util.FixCoqMistakes.
+Require Crypto.Algebra.Hierarchy Crypto.Algebra.Field.
+
+Existing Class prime.
+Local Open Scope F_scope.
+
+Module F.
+ (* Field structure on [F q] for a prime modulus [q]. *)
+ Section Field.
+ Context (q:positive) {prime_q:prime q}.
+ (* Exposes the specification packaged with [F.inv] in [F.inv_with_spec]:
+ [F.inv 0 = 0], and, when [q] is prime, [F.inv x * x = 1] for nonzero [x]. *)
+ Lemma inv_spec : F.inv 0%F = (0%F : F q)
+ /\ (prime q -> forall x : F q, x <> 0%F -> (F.inv x * x)%F = 1%F).
+ Proof using Type. change (@F.inv q) with (proj1_sig (@F.inv_with_spec q)); destruct (@F.inv_with_spec q); eauto. Qed.
+
+ (* The two conjuncts of [inv_spec] as separate lemmas; [inv_nonzero]
+ uses the section's primality hypothesis ([Proof using Type*]). *)
+ Lemma inv_0 : F.inv 0%F = F.of_Z q 0. Proof using Type. destruct inv_spec; auto. Qed.
+ Lemma inv_nonzero (x:F q) : (x <> 0 -> F.inv x * x%F = 1)%F. Proof using Type*. destruct inv_spec; auto. Qed.
+
+ (* [F q] with Leibniz equality and the [F] operations is a field.
+ The proof repeatedly splits the goal and closes each piece with one of:
+ [solve_proper], the commutative-ring instance, [inv_nonzero], or an
+ integer computation that relies on [prime_ge_2]. *)
+ Global Instance field_modulo : @Algebra.Hierarchy.field (F q) Logic.eq 0%F 1%F F.opp F.add F.sub F.mul F.inv F.div.
+ Proof using Type*.
+ repeat match goal with
+ | _ => solve [ solve_proper
+ | apply F.commutative_ring_modulo
+ | apply inv_nonzero
+ | cbv [not]; pose proof prime_ge_2 q prime_q;
+ rewrite F.eq_to_Z_iff, !F.to_Z_of_Z, !Zmod_small; omega ]
+ | _ => split
+ end.
+ Qed.
+ End Field.
+
+ (* Number-theoretic facts about [F q] for a prime [q] with [2 < q]:
+ inversion by Fermat exponentiation, Euler's criterion, and
+ decidability of being a square. *)
+ Section NumberTheory.
+ Context {q:positive} {prime_q:prime q} {two_lt_q: 2 < q}.
+
+ (* [1 : F q] maps to the integer 1; [two_lt_q] lets [omega] discharge the
+ side conditions of [Zmod_small].
+ (The former "TODO: move to PrimeFieldTheorems" note was stale: this
+ lemma already lives in PrimeFieldTheorems.) *)
+ Lemma to_Z_1 : @F.to_Z q 1 = 1%Z.
+ Proof using two_lt_q. simpl. rewrite Zmod_small; omega. Qed.
+
+ (* The field inverse equals exponentiation by [q - 2]. *)
+ Lemma Fq_inv_fermat (x:F q) : F.inv x = x ^ Z.to_N (q - 2)%Z.
+ Proof using Type*.
+ destruct (dec (x = 0%F)) as [?|Hnz].
+ (* Zero case: [F.inv 0 = 0] and [0 ^ n = 0]; the side goal is that the
+ exponent [Z.to_N (q - 2)] is not zero. *)
+ { subst x; rewrite inv_0, F.pow_0_l; trivial.
+ change (0%N) with (Z.to_N 0%Z); rewrite Z2N.inj_iff; omega. }
+ (* Nonzero case: by uniqueness of inverses it suffices to show the power
+ is an inverse; this is moved to [Z] and closed with [fermat_inv]. *)
+ erewrite <-Algebra.Field.inv_unique; try reflexivity.
+ rewrite F.eq_to_Z_iff, F.to_Z_mul, F.to_Z_pow, Z2N.id, to_Z_1 by omega.
+ apply (fermat_inv q _ (F.to_Z x)); rewrite F.mod_to_Z; eapply F.to_Z_nonzero; trivial.
+ Qed.
+
+ (* Euler's criterion in [F q]: a nonzero [a] is a square iff
+ [a ^ (q / 2) = 1] ([q / 2] is integer division).
+ The [euler_criterion] applied inside the proof is the previously
+ defined integer-level lemma, not this one. *)
+ Lemma euler_criterion (a : F q) (a_nonzero : a <> 0) :
+ (a ^ (Z.to_N (q / 2)) = 1) <-> (exists b, b*b = a).
+ Proof using Type*.
+ pose proof F.to_Z_nonzero_range a; pose proof (odd_as_div q).
+ specialize_by (destruct (Z.prime_odd_or_2 _ prime_q); try omega; trivial).
+ rewrite F.eq_to_Z_iff, !F.to_Z_pow, !to_Z_1, !Z2N.id by omega.
+ rewrite F.square_iff, <-(euler_criterion (q/2)) by (trivial || omega); reflexivity.
+ Qed.
+
+ (* Being a square is decidable: 0 is the square of 0; for nonzero [x] the
+ question is transferred through [euler_criterion]. Ends with [Defined]
+ so the decision procedure stays transparent; the proof-only part of the
+ zero case is wrapped in [abstract]. *)
+ Global Instance Decidable_square (x:F q) : Decidable (exists y, y*y = x).
+ Proof.
+ destruct (dec (x = 0)).
+ { left. abstract (exists 0; subst; apply Ring.mul_0_l). }
+ { eapply Decidable_iff_to_impl; [eapply euler_criterion; assumption | exact _]. }
+ Defined.
+ End NumberTheory.
+
+ (* Square roots in [F q] when [q] is prime and [q mod 4 = 3]. *)
+ Section SquareRootsPrime3Mod4.
+ Context {q:positive} {prime_q: prime q} {q_3mod4 : q mod 4 = 3}.
+
+ (* Registers the field structure so the [field] tactic works on [F q]
+ inside this section. *)
+ Add Field _field2 : (Algebra.Field.field_theory_for_stdlib_tactic(T:=F q))
+ (morphism (F.ring_morph q),
+ constants [F.is_constant],
+ div (F.morph_div_theory q),
+ power_tac (F.power_theory q) [F.is_pow_constant]).
+
+ (* Candidate square root: [a ^ (q / 4 + 1)] with integer division.
+ For [q = 4k + 3] the exponent is [k + 1 = (q + 1) / 4]. *)
+ Definition sqrt_3mod4 (a : F q) : F q := a ^ Z.to_N (q / 4 + 1).
+
+ Global Instance Proper_sqrt_3mod4 : Proper (eq ==> eq ) sqrt_3mod4.
+ Proof using Type. repeat intro; subst; reflexivity. Qed.
+
+ (* [q] cannot be 2: substituting [q = 2] into [q mod 4 = 3] gives a
+ contradiction, and primality gives [2 <= q]. *)
+ Lemma two_lt_q_3mod4 : 2 < q.
+ Proof using Type*.
+ pose proof (prime_ge_2 q _) as two_le_q.
+ destruct (Zle_lt_or_eq _ _ two_le_q) as [H|H]; [exact H|].
+ rewrite <-H in q_3mod4; discriminate.
+ Qed.
+ (* Lets [auto] supply the [2 < q] hypothesis needed by [euler_criterion]. *)
+ Local Hint Resolve two_lt_q_3mod4.
+
+ (* [x] is a square iff [sqrt_3mod4 x] squares to [x].
+ The proof splits on [x = 0] and then runs one big rewriting loop; the
+ nonzero case is reduced to [euler_criterion] by showing
+ [x ^ (q/2) = 1 <-> x ^ (q/2) * x ^ 1 = x] and matching exponents. *)
+ Lemma sqrt_3mod4_correct (x:F q) :
+ ((exists y, y*y = x) <-> (sqrt_3mod4 x)*(sqrt_3mod4 x) = x)%F.
+ Proof using Type*.
+ cbv [sqrt_3mod4]; intros.
+ destruct (F.eq_dec x 0);
+ repeat match goal with
+ | |- _ => progress subst
+ | |- _ => progress rewrite ?F.pow_0_l, <-?F.pow_add_r
+ | |- _ => progress rewrite <-?Z2N.inj_0, <-?Z2N.inj_add by zero_bounds
+ | |- _ => rewrite <-@euler_criterion by auto
+ | |- ?x ^ (?f _) = ?a <-> ?x ^ (?f _) = ?a => do 3 f_equiv; [ ]
+ | |- _ => rewrite !Zmod_odd in *; repeat (break_match; break_match_hyps); omega
+ | |- _ => rewrite Z.rem_mul_r in * by omega
+ | |- (exists x, _) <-> ?B => assert B by field; solve [intuition eauto]
+ | |- (?x ^ Z.to_N ?a = 1) <-> _ =>
+ transitivity (x ^ Z.to_N a * x ^ Z.to_N 1 = x);
+ [ rewrite F.pow_1_r, Algebra.Field.mul_cancel_l_iff by auto; reflexivity | ]
+ | |- (_ <> _)%N => rewrite Z2N.inj_iff by zero_bounds
+ | |- (?a <> 0)%Z => assert (0 < a) by zero_bounds; omega
+ | |- (_ = _)%Z => replace 4 with (2 * 2)%Z in * by ring;
+ rewrite <-Z.div_div by zero_bounds;
+ rewrite Z.add_diag, Z.mul_add_distr_l, Z.mul_div_eq by omega
+ end.
+ Qed.
+ End SquareRootsPrime3Mod4.
+
+ (* Square roots in [F q] when [q] is prime and [q mod 8 = 5], given a
+ square root of [-1]. *)
+ Section SquareRootsPrime5Mod8.
+ Context {q:positive} {prime_q: prime q} {q_5mod8 : q mod 8 = 5}.
+ Local Open Scope F_scope.
+ (* Registers the field structure so the [field] and [ring] tactics work on
+ [F q] inside this section. *)
+ Add Field _field3 : (Algebra.Field.field_theory_for_stdlib_tactic(T:=F q))
+ (morphism (F.ring_morph q),
+ constants [F.is_constant],
+ div (F.morph_div_theory q),
+ power_tac (F.power_theory q) [F.is_pow_constant]).
+
+ (* Any nonsquare element raised to (q-1)/4 would work for sqrt_minus1
+ (real implementations use 2 ^ ((q-1)/4)). Here it is a section
+ parameter together with the proof that it squares to [-1]. *)
+ Context (sqrt_minus1 : F q) (sqrt_minus1_valid : sqrt_minus1 * sqrt_minus1 = F.opp 1).
+
+ (* [q] cannot be 2: substituting [q = 2] contradicts [q mod 8 = 5]. *)
+ Lemma two_lt_q_5mod8 : 2 < q.
+ Proof using prime_q q_5mod8.
+ pose proof (prime_ge_2 q _) as two_le_q.
+ destruct (Zle_lt_or_eq _ _ two_le_q) as [H|H]; [exact H|].
+ rewrite <-H in *. discriminate.
+ Qed.
+ Local Hint Resolve two_lt_q_5mod8.
+
+ (* Candidate square root: [b = a ^ (q / 8 + 1)] if [b ^ 2 = a], and
+ [sqrt_minus1 * b] otherwise (whose square is [- (b * b)], see
+ [mul_square_sqrt_minus1]). *)
+ Definition sqrt_5mod8 (a : F q) : F q :=
+ let b := a ^ Z.to_N (q / 8 + 1) in
+ if dec (b ^ 2 = a)
+ then b
+ else sqrt_minus1 * b.
+
+ Global Instance Proper_sqrt_5mod8 : Proper (eq ==> eq ) sqrt_5mod8.
+ Proof using Type. repeat intro; subst; reflexivity. Qed.
+
+ (* If [x] is a square then [b ^ 4 = x ^ 2] for [b = x ^ (q / 8 + 1)].
+ The exponent [(q / 8 + 1) * 4] is rewritten to [q / 2 + 2], and
+ [x ^ (q / 2) = 1] by [euler_criterion]. *)
+ Lemma eq_b4_a2 (x : F q) (Hex:exists y, y*y = x) :
+ ((x ^ Z.to_N (q / 8 + 1)) ^ 2) ^ 2 = x ^ 2.
+ Proof using prime_q q_5mod8.
+ pose proof two_lt_q_5mod8.
+ assert (0 <= q/8)%Z by (apply Z.div_le_lower_bound; rewrite ?Z.mul_0_r; omega).
+ assert (Z.to_N (q / 8 + 1) <> 0%N) by
+ (intro Hbad; change (0%N) with (Z.to_N 0%Z) in Hbad; rewrite Z2N.inj_iff in Hbad; omega).
+ destruct (dec (x = 0)); [subst; rewrite !F.pow_0_l by (trivial || lazy_decide); reflexivity|].
+ rewrite !F.pow_pow_l.
+
+ replace (Z.to_N (q / 8 + 1) * (2*2))%N with (Z.to_N (q / 2 + 2))%N.
+ Focus 2. { (* this is a boring but gnarly proof :/ *)
+ (* Exponent arithmetic: after doubling both sides, the divisions are
+ turned into [q - q mod _] and [q mod 8 = 5] closes the goal. *)
+ change (2*2)%N with (Z.to_N 4).
+ rewrite <- Z2N.inj_mul by zero_bounds.
+ apply Z2N.inj_iff; try zero_bounds.
+ rewrite <- Z.mul_cancel_l with (p := 2) by omega.
+ ring_simplify.
+ rewrite Z.mul_div_eq by omega.
+ rewrite Z.mul_div_eq by omega.
+ rewrite (Zmod_div_mod 2 8 q) by
+ (try omega; apply Zmod_divide; omega || auto).
+ rewrite q_5mod8.
+ replace (5 mod 2)%Z with 1%Z by auto.
+ ring.
+ } Unfocus.
+
+ rewrite Z2N.inj_add, F.pow_add_r by zero_bounds.
+ replace (x ^ Z.to_N (q / 2)) with (F.of_Z q 1) by
+ (symmetry; apply @euler_criterion; eauto).
+ change (Z.to_N 2) with 2%N; ring.
+ Qed.
+
+ (* Multiplying by [sqrt_minus1] negates the square. *)
+ Lemma mul_square_sqrt_minus1 : forall x, sqrt_minus1 * x * (sqrt_minus1 * x) = F.opp (x * x).
+ Proof using prime_q sqrt_minus1_valid.
+ intros.
+ transitivity (F.opp 1 * (x * x)); [ | field].
+ rewrite <-sqrt_minus1_valid.
+ field.
+ Qed.
+
+ (* Converse of [eq_b4_a2] for nonzero [x]: from [b ^ 4 = x ^ 2] either
+ [b * b = x] or [b * b = - x]; in the second case [sqrt_minus1 * b] is
+ a square root of [x]. *)
+ Lemma eq_b4_a2_iff (x : F q) : x <> 0 ->
+ ((exists y, y*y = x) <-> ((x ^ Z.to_N (q / 8 + 1)) ^ 2) ^ 2 = x ^ 2).
+ Proof using Type*.
+ split; try apply eq_b4_a2.
+ intro Hyy.
+ rewrite !@F.pow_2_r in *.
+ destruct (Field.only_two_square_roots_choice _ x (x * x) Hyy eq_refl); clear Hyy;
+ [ eexists; eassumption | ].
+ match goal with H : ?a * ?a = F.opp _ |- _ => exists (sqrt_minus1 * a);
+ rewrite mul_square_sqrt_minus1; rewrite H end.
+ field.
+ Qed.
+
+ (* [x] is a square iff [sqrt_5mod8 x] squares to [x]. *)
+ Lemma sqrt_5mod8_correct : forall x,
+ ((exists y, y*y = x) <-> (sqrt_5mod8 x)*(sqrt_5mod8 x) = x).
+ Proof using Type*.
+ cbv [sqrt_5mod8]; intros.
+ destruct (F.eq_dec x 0).
+ {
+ (* [x = 0]: every power involved is 0, so both sides hold. *)
+ repeat match goal with
+ | |- _ => progress subst
+ | |- _ => progress rewrite ?F.pow_0_l
+ | |- (_ <> _)%N => rewrite <-Z2N.inj_0, Z2N.inj_iff by zero_bounds
+ | |- (?a <> 0)%Z => assert (0 < a) by zero_bounds; omega
+ | |- _ => congruence
+ end.
+ break_match;
+ match goal with |- _ <-> ?G => assert G by field end; intuition eauto.
+ } {
+ (* [x <> 0]: restate being a square as [b ^ 4 = x ^ 2], then split on
+ the [dec (b ^ 2 = x)] test inside [sqrt_5mod8]. *)
+ rewrite eq_b4_a2_iff by auto.
+ rewrite !@F.pow_2_r in *.
+ break_match.
+ intuition (f_equal; eauto).
+ split; intro A. {
+ destruct (Field.only_two_square_roots_choice _ x (x * x) A eq_refl) as [B | B];
+ clear A; try congruence.
+ rewrite mul_square_sqrt_minus1, B; field.
+ } {
+ rewrite mul_square_sqrt_minus1 in A.
+ transitivity (F.opp x * F.opp x); [ | field ].
+ f_equal; rewrite <-A at 3; field.
+ }
+ }
+ Qed.
+ End SquareRootsPrime5Mod8.
+
+ (* Transfers the ring and field structure of [F q] to another carrier [H]
+ along a map [phi' : H -> F q] that commutes with the operations. *)
+ Module Iso.
+ Section IsomorphicRings.
+ Context {q:positive} {prime_q:prime q} {two_lt_q:2 < Z.pos q}.
+ (* [phi'] is a left inverse of [phi], reflects and preserves equality
+ ([phi'_iff]) and commutes with zero, one, opp, add, sub and mul.
+ [pow] is exponentiation on [H] with exponents in [P], specified as
+ iterated [mul] via [pow_is_scalarmult]. *)
+ Context
+ {H : Type} {eq : H -> H -> Prop} {zero one : H}
+ {opp : H -> H} {add sub mul : H -> H -> H}
+ {phi : F q -> H} {phi' : H -> F q}
+ {phi'_phi : forall A : F q, Logic.eq (phi' (phi A)) A}
+ {phi'_iff : forall a b : H, iff (Logic.eq (phi' a) (phi' b)) (eq a b)}
+ {phi'_zero : Logic.eq (phi' zero) F.zero} {phi'_one : Logic.eq (phi' one) F.one}
+ {phi'_opp : forall a : H, Logic.eq (phi' (opp a)) (F.opp (phi' a))}
+ {phi'_add : forall a b : H, Logic.eq (phi' (add a b)) (F.add (phi' a) (phi' b))}
+ {phi'_sub : forall a b : H, Logic.eq (phi' (sub a b)) (F.sub (phi' a) (phi' b))}
+ {phi'_mul : forall a b : H, Logic.eq (phi' (mul a b)) (F.mul (phi' a) (phi' b))}
+ {P:Type} {pow : H -> P -> H} {NtoP:N->P}
+ {pow_is_scalarmult:ScalarMult.is_scalarmult(G:=H)(eq:=eq)(add:=mul)(zero:=one)(mul:=fun (n:nat) (k:H) => pow k (NtoP (N.of_nat n)))}.
+ (* Inverse on [H] by exponentiation with [q - 2] (cf. [Fq_inv_fermat]);
+ division multiplies by that inverse, with the inverse on the left. *)
+ Definition inv (x:H) := pow x (NtoP (Z.to_N (q - 2)%Z)).
+ Definition div x y := mul (inv y) x.
+
+ (* [H] is a ring, and [phi] and [phi'] are ring homomorphisms. *)
+ Lemma ring :
+ @Algebra.Hierarchy.ring H eq zero one opp add sub mul
+ /\ @Ring.is_homomorphism (F q) Logic.eq F.one F.add F.mul H eq one add mul phi
+ /\ @Ring.is_homomorphism H eq one add mul (F q) Logic.eq F.one F.add F.mul phi'.
+ Proof using phi'_add phi'_iff phi'_mul phi'_one phi'_opp phi'_phi phi'_sub phi'_zero. eapply @Ring.ring_by_isomorphism; assumption || exact _. Qed.
+ (* Make the three components of [ring] available to typeclass
+ resolution within this section. *)
+ Local Instance _iso_ring : Algebra.Hierarchy.ring := proj1 ring.
+ Local Instance _iso_hom1 : Ring.is_homomorphism := proj1 (proj2 ring).
+ Local Instance _iso_hom2 : Ring.is_homomorphism := proj2 (proj2 ring).
+
+ (* [phi'] commutes with [inv]: both sides are a power with exponent
+ [q - 2], and [phi'] commutes with iterated multiplication. *)
+ Let inv_proof : forall a : H, phi' (inv a) = F.inv (phi' a).
+ Proof.
+ intros.
+ cbv [inv]. rewrite (Fq_inv_fermat(q:=q)(two_lt_q:=two_lt_q)).
+ rewrite <-Z_nat_N at 1 2.
+ rewrite (ScalarMult.homomorphism_scalarmult(phi:=phi')(MUL_is_scalarmult:=pow_is_scalarmult)(mul_is_scalarmult:=F.pow_is_scalarmult)).
+ reflexivity.
+ assumption.
+ Qed.
+
+ (* [phi'] commutes with [div] (stated on the unfolded body of [div]). *)
+ Let div_proof : forall a b : H, phi' (mul (inv b) a) = phi' a / phi' b.
+ Proof.
+ intros.
+ rewrite phi'_mul, inv_proof, Algebra.Hierarchy.field_div_definition, Algebra.Hierarchy.commutative.
+ reflexivity.
+ Qed.
+
+ (* [H] with [inv] and [div] as above is a field, and [phi], [phi'] are
+ ring homomorphisms. *)
+ Lemma field_and_iso :
+ @Algebra.Hierarchy.field H eq zero one opp add sub mul inv div
+ /\ @Ring.is_homomorphism (F q) Logic.eq F.one F.add F.mul H eq one add mul phi
+ /\ @Ring.is_homomorphism H eq one add mul (F q) Logic.eq F.one F.add F.mul phi'.
+ Proof using Type*. eapply @Field.field_and_homomorphism_from_redundant_representation;
+ assumption || exact _ || exact inv_proof || exact div_proof. Qed.
+ End IsomorphicRings.
+ End Iso.
+End F.
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
new file mode 100644
index 000000000..cb37fb1f9
--- /dev/null
+++ b/src/Arithmetic/Saturated.v
@@ -0,0 +1,285 @@
+Require Import Coq.Init.Nat.
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Algebra.Nsatz.
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
+Require Import Crypto.Util.Tuple Crypto.Util.ListUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+(***
+
+Arithmetic on bignums that handles carry bits; this is useful for
+saturated limbs. Compatible with mixed-radix bases.
+
+ ***)
+
+Module Columns.
+ Section Columns.
+ (* [weight i] is the weight of limb [i]: the first weight is 1, no weight
+ is zero, and each weight divides the next. *)
+ Context {weight : nat->Z}
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ (* [add_get_carry s x y] adds [x] and [y] and splits the result at [s]:
+ it returns a pair [(low, carry)] with [low = x + y - s * carry]
+ (this is all that [add_get_carry_correct] requires) *)
+ {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {add_get_carry_correct : forall s x y,
+ fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)}
+ .
+
+ (* Value of a columns representation: each column (a list of integers) is
+ summed, and the tuple of sums is evaluated positionally with [weight]. *)
+ Definition eval {n} (x : (list Z)^n) : Z :=
+ B.Positional.eval weight (Tuple.map sum x).
+
+ (* Same as [eval], but the first column gets weight [weight offset]
+ (the weight function is shifted by [offset]). *)
+ Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z :=
+ B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x).
+
+ (* With offset 0, [eval_from] is [eval]. *)
+ Lemma eval_from_0 {n} x : @eval_from n 0 x = eval x.
+ Proof using Type. cbv [eval_from eval]. auto using B.Positional.eval_wt_equiv. Qed.
+
+ (* Peels off the first column: it contributes [weight i * sum (hd inp)],
+ and the remaining columns are evaluated from offset [S i]. *)
+ Lemma eval_from_S {n}: forall i (inp : (list Z)^(S n)),
+ eval_from i inp = eval_from (S i) (tl inp) + weight i * sum (hd inp).
+ Proof using Type.
+ intros; cbv [eval_from].
+ replace inp with (append (hd inp) (tl inp))
+ by (simpl in *; destruct n; destruct inp; reflexivity).
+ rewrite map_append, B.Positional.eval_step, hd_append, tl_append.
+ autorewrite with natsimplify; ring_simplify; rewrite Group.cancel_left.
+ apply B.Positional.eval_wt_equiv; intros; f_equal; omega.
+ Qed.
+
+ (* Sums a list of integers (one column) using carry bits.
+ [n] is the limb index; additions are split at [weight (S n) / weight n].
+ Output : the pair (carry, sum), passed to the continuation [f].
+ An empty column yields (0, 0); a one-element column is passed through
+ unchanged with carry 0. Otherwise the tail is compacted first, the head
+ is added to the tail's sum with [add_get_carry], and the resulting carry
+ is accumulated onto the tail's carry.
+ *)
+ Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) :=
+ match digit with
+ | nil => f (0, 0)
+ | x :: nil => f (0, x)
+ | x :: tl =>
+ compact_digit_cps n tl (fun rec =>
+ dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
+ dlet carry' := (fst rec + snd sum_carry)%RT in
+ f (carry', fst sum_carry))
+ end.
+
+ (* Direct-style version (continuation [id]), and the lemma that lets
+ [autorewrite with uncps] replace the CPS form by it. *)
+ Definition compact_digit n digit := compact_digit_cps n digit id.
+ Lemma compact_digit_id n digit: forall {T} f,
+ @compact_digit_cps n digit T f = f (compact_digit n digit).
+ Proof using Type.
+ induction digit; intros; cbv [compact_digit]; [reflexivity|];
+ simpl compact_digit_cps; break_match; [reflexivity|].
+ rewrite !IHdigit; reflexivity.
+ Qed.
+ Hint Opaque compact_digit : uncps.
+ Hint Rewrite compact_digit_id : uncps.
+
+ (* One step of compaction: the incoming [carry] is prepended to the column
+ [digit], and the result is compacted at limb [index]. The continuation
+ receives the pair (carry out, sum). *)
+ Definition compact_step_cps (index:nat) (carry:Z) (digit: list Z)
+ {T} (f:Z * Z->T) :=
+ compact_digit_cps index (carry::digit) f.
+
+ (* Direct-style version and its [uncps] rewrite lemma. *)
+ Definition compact_step i c d := compact_step_cps i c d id.
+ Lemma compact_step_id i c d T f :
+ @compact_step_cps i c d T f = f (compact_step i c d).
+ Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
+ Hint Opaque compact_step : uncps.
+ Hint Rewrite compact_step_id : uncps.
+
+ (* Compacts every column with [mapi_with_cps], starting with carry 0.
+ The continuation receives (final carry, tuple of per-limb sums);
+ see [eval_compact] for the resulting value equation. *)
+ Definition compact_cps {n} (xs : (list Z)^n) {T} (f:Z * Z^n->T) :=
+ mapi_with_cps compact_step_cps 0 xs f.
+
+ (* Direct-style version and the lemma relating it to the CPS form. *)
+ Definition compact {n} xs := @compact_cps n xs _ id.
+ Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs).
+ Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
+
+ (* Correctness of compacting one column: the returned sum equals the true
+ sum of the column minus the split value times the returned carry. *)
+ Lemma compact_digit_correct i (xs : list Z) :
+ snd (compact_digit i xs) = sum xs - (weight (S i) / weight i) * (fst (compact_digit i xs)).
+ Proof using add_get_carry_correct weight_0.
+ induction xs; cbv [compact_digit]; simpl compact_digit_cps;
+ cbv [Let_In];
+ (* Normalization loop: unfold [add_get_carry] via its specification,
+ simplify sums and pairs, case-split on list shapes, and finish with
+ [ring_simplify] / [nsatz]. *)
+ repeat match goal with
+ | _ => rewrite add_get_carry_correct
+ | _ => progress (rewrite ?sum_cons, ?sum_nil in * )
+ | _ => progress (autorewrite with uncps push_id in * )
+ | _ => progress (autorewrite with cancel_pair in * )
+ | _ => progress break_match; try discriminate
+ | _ => progress ring_simplify
+ | _ => reflexivity
+ | _ => nsatz
+ end.
+ Qed.
+
+ (* Invariant relating input columns [inp] and output limbs [out], both
+ starting at limb index [i]: the value of [out] plus the outgoing carry
+ [rem] at weight [weight (i + n)] equals the value of [inp] plus the
+ incoming carry [starter] at weight [weight i]. *)
+ Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) :=
+ B.Positional.eval_from weight i out + weight (i + n) * (rem)
+ = eval_from i inp + weight i*starter.
+
+ (* Inductive step: if the invariant holds for the tail columns (with the
+ carry produced by compacting the head), it holds for the whole input
+ once the head's compacted sum is prepended to the output.
+ NOTE(review): [fst_fst_compact_step] is not defined in the visible part
+ of this file; confirm it exists, otherwise that branch never fires. *)
+ Lemma compact_invariant_holds n i starter rem inp out :
+ compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out ->
+ compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out).
+ Proof using Type*.
+ cbv [compact_invariant B.Positional.eval_from]; intros.
+ repeat match goal with
+ | _ => rewrite B.Positional.eval_step
+ | _ => rewrite eval_from_S
+ | _ => rewrite sum_cons
+ | _ => rewrite weight_multiples
+ | _ => rewrite Nat.add_succ_l in *
+ | _ => rewrite Nat.add_succ_r in *
+ | _ => (rewrite fst_fst_compact_step in * )
+ | _ => progress ring_simplify
+ | _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero
+ | _ => cbv [compact_step_cps] in *;
+ autorewrite with uncps push_id;
+ rewrite compact_digit_correct
+ | _ => progress (autorewrite with natsimplify in * )
+ end.
+ rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega).
+ nsatz.
+ Qed.
+
+ (* Base case: with no columns the outgoing carry is the incoming carry. *)
+ Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt.
+ Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed.
+
+ (* The invariant holds for the result of running [compact_step_cps] over
+ all columns from index 0 with initial carry [start]. *)
+ Lemma compact_invariant_end {n} start (input : (list Z)^n):
+ compact_invariant n 0%nat start (fst (mapi_with_cps compact_step_cps start input id)) input (snd (mapi_with_cps compact_step_cps start input id)).
+ Proof using Type*.
+ autorewrite with uncps push_id.
+ apply (mapi_with_invariant _ compact_invariant
+ compact_invariant_holds compact_invariant_base).
+ Qed.
+
+ (* Main correctness statement for [compact]: the output limbs plus the
+ final carry at weight [weight n] evaluate to the input's value. *)
+ Lemma eval_compact {n} (xs : tuple (list Z) n) :
+ B.Positional.eval weight (snd (compact xs)) + (weight n * fst (compact xs)) = eval xs.
+ Proof using Type*.
+ pose proof (compact_invariant_end 0 xs) as Hinv.
+ cbv [compact_invariant] in Hinv.
+ simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv.
+ rewrite eval_from_0, B.Positional.eval_from_0 in Hinv; apply Hinv.
+ Qed.
+
+ (* Conses [x] onto the [i]-th column of [t], by running
+ [update_nth_cps i (cons x)] on the underlying list via [on_tuple_cps]
+ (with [nil] as the default element). *)
+ Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n)
+ {T} (f:(list Z)^n->T) :=
+ @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f.
+
+ (* Direct-style version and its [uncps] rewrite lemma. *)
+ Definition cons_to_nth {n} i x t := @cons_to_nth_cps n i x t _ id.
+ Lemma cons_to_nth_id {n} i x t T f :
+ @cons_to_nth_cps n i x t T f = f (cons_to_nth i x t).
+ Proof using Type.
+ cbv [cons_to_nth_cps cons_to_nth].
+ (* [on_tuple_cps_correct] needs to know that the update preserves length. *)
+ assert (forall xs : list (list Z), length xs = n ->
+ length (update_nth_cps i (cons x) xs id) = n) as Hlen.
+ { intros. autorewrite with uncps push_id distr_length. assumption. }
+ rewrite !on_tuple_cps_correct with (H:=Hlen)
+ by (intros; autorewrite with uncps push_id; reflexivity). reflexivity.
+ Qed.
+ Hint Opaque cons_to_nth : uncps.
+ Hint Rewrite @cons_to_nth_id : uncps.
+
+ (* On lists: summing columns after consing [x] onto column [i] is the same
+ as adding [x] to the [i]-th column sum. *)
+ Lemma map_sum_update_nth l : forall i x,
+ List.map sum (update_nth i (cons x) l) =
+ update_nth i (Z.add x) (List.map sum l).
+ Proof using Type.
+ induction l; intros; destruct i; simpl; rewrite ?IHl; reflexivity.
+ Qed.
+
+ (* Tuple version of the previous lemma, phrased with
+ [B.Positional.add_to_nth]. *)
+ Lemma cons_to_nth_add_to_nth n i x t :
+ map sum (@cons_to_nth n i x t) = B.Positional.add_to_nth i x (map sum t).
+ Proof using weight.
+ cbv [B.Positional.add_to_nth B.Positional.add_to_nth_cps cons_to_nth cons_to_nth_cps on_tuple_cps].
+ induction n; [simpl; rewrite !update_nth_cps_correct; reflexivity|].
+ specialize (IHn (tl t)). autorewrite with uncps push_id in *.
+ apply to_list_ext. rewrite <-!map_to_list.
+ erewrite !from_list_default_eq, !to_list_from_list.
+ rewrite map_sum_update_nth. reflexivity.
+ (* The two shelved goals are length side conditions. *)
+ Unshelve.
+ distr_length.
+ distr_length.
+ Qed.
+
+ (* Consing [x] onto column [i] adds [weight i * x] to the value, provided
+ [i] is a valid column index. *)
+ Lemma eval_cons_to_nth n i x t : (i < n)%nat ->
+ eval (@cons_to_nth n i x t) = weight i * x + eval t.
+ Proof using Type.
+ cbv [eval]; intros. rewrite cons_to_nth_add_to_nth.
+ auto using B.Positional.eval_add_to_nth.
+ Qed.
+ Hint Rewrite eval_cons_to_nth using omega : push_basesystem_eval.
+
+ (* The columns representation with [n] empty columns. *)
+ Definition nils n : (list Z)^n := Tuple.repeat nil n.
+
+ (* Summing each empty column gives the all-zeros positional tuple. *)
+ Lemma map_sum_nils n : map sum (nils n) = B.Positional.zeros n.
+ Proof using Type.
+ cbv [nils B.Positional.zeros]; induction n; [reflexivity|].
+ change (repeat nil (S n)) with (@nil Z :: repeat nil n).
+ rewrite map_repeat, sum_nil. reflexivity.
+ Qed.
+
+ (* Empty columns evaluate to 0; registered as an evaluation rewrite rule. *)
+ Lemma eval_nils n : eval (nils n) = 0.
+ Proof using Type. cbv [eval]. rewrite map_sum_nils, B.Positional.eval_zeros. reflexivity. Qed. Hint Rewrite eval_nils : push_basesystem_eval.
+
+ (* Builds an [n]-column representation from an associational list [p]:
+ folding from the right over [p], starting from [nils n], each term is
+ passed to [B.Positional.place_cps] (with maximum index [pred n]), and
+ the second component of its result is consed onto the column named by
+ the first component.
+ NOTE(review): the exact contract of [place_cps] is outside this view;
+ the proof below relies on [weight_place_cps] and [place_cps_in_range]. *)
+ Definition from_associational_cps n (p:list B.limb)
+ {T} (f:(list Z)^n -> T) :=
+ fold_right_cps
+ (fun t st =>
+ B.Positional.place_cps weight t (pred n)
+ (fun p=> cons_to_nth_cps (fst p) (snd p) st id))
+ (nils n) p f.
+
+ (* Direct-style version and its [uncps] rewrite lemma. *)
+ Definition from_associational n p := from_associational_cps n p id.
+ Lemma from_associational_id n p T f :
+ @from_associational_cps n p T f = f (from_associational n p).
+ Proof using Type.
+ cbv [from_associational_cps from_associational].
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque from_associational : uncps.
+ Hint Rewrite from_associational_id : uncps.
+
+ (* Conversion preserves the value, provided there is at least one column
+ ([n <> 0] is used to rewrite [S (pred n)] to [n]). *)
+ Lemma eval_from_associational n p (n_nonzero:n<>0%nat):
+ eval (from_associational n p) = B.Associational.eval p.
+ Proof using weight_0 weight_nonzero.
+ cbv [from_associational_cps from_associational]; induction p;
+ autorewrite with uncps push_id push_basesystem_eval; [reflexivity|].
+ pose proof (B.Positional.weight_place_cps weight weight_0 weight_nonzero a (pred n)).
+ pose proof (B.Positional.place_cps_in_range weight a (pred n)).
+ rewrite Nat.succ_pred in * by assumption. simpl.
+ autorewrite with uncps push_id push_basesystem_eval in *.
+ rewrite eval_cons_to_nth by omega. nsatz.
+ Qed.
+
+ (* Multiplication into columns: both [n]-limb inputs are converted to
+ associational form, multiplied with [B.Associational.mul_cps], and the
+ product terms are distributed over [m] columns. No carrying is done
+ here; the continuation receives the uncompacted columns. *)
+ Definition mul_cps {n m} (p q : Z^n) {T} (f : (list Z)^m->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => B.Associational.mul_cps P Q
+ (fun PQ => from_associational_cps m PQ f))).
+
+ (* Addition into columns: the associational forms of the two inputs are
+ concatenated and distributed over [n] columns, again without carrying. *)
+ Definition add_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => from_associational_cps n (P++Q) f)).
+
+ End Columns.
+End Columns.
+
+(*
+(* Just some pretty-printing *)
+Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing).
+Local Notation "snd~ a" := (let (_,y) := a in y) (at level 40, only printing).
+
+(* Simple example : base 10, multiply two bignums and compact them *)
+Definition base10 i := Eval compute in 10^(Z.of_nat i).
+Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 b0 b1 b2 =>
+ Columns.mul_cps (weight := base10) (n:=3) (a2,a1,a0) (b2,b1,b0) (fun ab => Columns.compact (n:=5) (add_get_carry:=adc) (weight:=base10) ab)).
+
+(* More complex example : base 2^56, 8 limbs *)
+Definition base2pow56 i := Eval compute in 2^(56*Z.of_nat i).
+Time Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 a3 a4 a5 a6 a7 b0 b1 b2 b3 b4 b5 b6 b7 =>
+ Columns.mul_cps (weight := base2pow56) (n:=8) (a7,a6,a5,a4,a3,a2,a1,a0) (b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=15) (add_get_carry:=adc) (weight:=base2pow56) ab)). (* Finished transaction in 151.392 secs *)
+
+(* Mixed-radix example : base 2^25.5, 10 limbs *)
+Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1) / 2)).
+Time Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 =>
+ Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *)
+*)