aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/MontgomeryReduction
diff options
context:
space:
mode:
authorGravatar Andres Erbsen <andreser@mit.edu>2017-04-06 22:53:07 -0400
committerGravatar Andres Erbsen <andreser@mit.edu>2017-04-06 22:53:07 -0400
commitc9fc5a3cdf1f5ea2d104c150c30d1b1a6ac64239 (patch)
treedb7187f6984acff324ca468e7b33d9285806a1eb /src/Arithmetic/MontgomeryReduction
parent21198245dab432d3c0ba2bb8a02254e7d0594382 (diff)
rename-everything
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
-rw-r--r--src/Arithmetic/MontgomeryReduction/Definition.v179
-rw-r--r--src/Arithmetic/MontgomeryReduction/Proofs.v296
2 files changed, 475 insertions, 0 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/Definition.v b/src/Arithmetic/MontgomeryReduction/Definition.v
new file mode 100644
index 000000000..78d3c037f
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/Definition.v
@@ -0,0 +1,179 @@
+(*** Montgomery Multiplication *)
+(** This file implements Montgomery Form, Montgomery Reduction, and
+ Montgomery Multiplication on [Z]. We follow Wikipedia. *)
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Notations.
+
+Local Open Scope Z_scope.
+Delimit Scope montgomery_naive_scope with montgomery_naive.
+Delimit Scope montgomery_scope with montgomery.
+Definition montgomeryZ := Z.
+Bind Scope montgomery_scope with montgomeryZ.
+
+Section montgomery.
+ Context (N : Z)
+ (R : Z)
+ (R' : Z). (* R' is R⁻¹ mod N *)
+ Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
+ Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
+ (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Montgomery_modular_multiplication>: *)
+ (** In modular arithmetic computation, Montgomery modular
+ multiplication, more commonly referred to as Montgomery
+ multiplication, is a method for performing fast modular
+ multiplication, introduced in 1985 by the American mathematician
+ Peter L. Montgomery. *)
+ (** Given two integers [a] and [b], the classical modular
+ multiplication algorithm computes [ab mod N]. Montgomery
+ multiplication works by transforming [a] and [b] into a special
+ representation known as Montgomery form. For a modulus [N], the
+ Montgomery form of [a] is defined to be [aR mod N] for some
+ constant [R] depending only on [N] and the underlying computer
+ architecture. If [aR mod N] and [bR mod N] are the Montgomery
+ forms of [a] and [b], then their Montgomery product is [abR mod
+ N]. Montgomery multiplication is a fast algorithm to compute the
+ Montgomery product. Transforming the result out of Montgomery
+ form yields the classical modular product [ab mod N]. *)
+
+ Definition to_montgomery_naive (x : Z) : montgomeryZ := x * R.
+ Definition from_montgomery_naive (x : montgomeryZ) : Z := x * R'.
+
+ (** * Modular arithmetic and Montgomery form *)
+ Section general.
+ (** Addition and subtraction in Montgomery form are the same as
+ ordinary modular addition and subtraction because of the
+ distributive law: *)
+ (** [aR + bR = (a+b)R], *)
+ (** [aR - bR = (a-b)R]. *)
+ (** This is a consequence of the fact that, because [gcd(R, N) =
+ 1], multiplication by [R] is an isomorphism on the additive
+ group [ℤ/Nℤ]. *)
+
+ Definition add : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR + bR.
+ Definition sub : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR - bR.
+
+ (** Multiplication in Montgomery form, however, is seemingly more
+ complicated. The usual product of [aR] and [bR] does not
+ represent the product of [a] and [b] because it has an extra
+ factor of R: *)
+ (** [(aR mod N)(bR mod N) mod N = (abR)R mod N]. *)
+ (** Computing products in Montgomery form requires removing the
+ extra factor of [R]. While division by [R] is cheap, the
+ intermediate product [(aR mod N)(bR mod N)] is not divisible
+ by [R] because the modulo operation has destroyed that
+ property. *)
+ (** Removing the extra factor of R can be done by multiplying by
+ an integer [R′] such that [RR' ≡ 1 (mod N)], that is, by an
+ [R′] whose residue class is the modular inverse of [R] mod
+ [N]. Then, working modulo [N], *)
+ (** [(aR mod N)(bR mod N)R' ≡ (aR)(bR)R⁻¹ ≡ (ab)R (mod N)]. *)
+
+ Definition mul_naive : montgomeryZ -> montgomeryZ -> montgomeryZ
+ := fun aR bR => aR * bR * R'.
+ End general.
+
+ (** * The REDC algorithm *)
+ Section redc.
+ (** While the above algorithm is correct, it is slower than
+ multiplication in the standard representation because of the
+ need to multiply by [R′] and divide by [N]. Montgomery
+ reduction, also known as REDC, is an algorithm that
+ simultaneously computes the product by [R′] and reduces modulo
+ [N] more quickly than the naive method. The speed is because
+ all computations are done using only reduction and divisions
+ with respect to [R], not [N]: *)
+ (**
+<<
+function REDC is
+ input: Integers R and N with gcd(R, N) = 1,
+ Integer N′ in [0, R − 1] such that NN′ ≡ −1 mod R,
+ Integer T in the range [0, RN − 1]
+ output: Integer S in the range [0, N − 1] such that S ≡ TR⁻¹ mod N
+
+ m ← ((T mod R)N′) mod R
+ t ← (T + mN) / R
+ if t ≥ N then
+ return t − N
+ else
+ return t
+ end if
+end function
+>> *)
+ Context (N' : Z). (* N' is (-N⁻¹) mod R *)
+ Section redc.
+ Context (T : Z).
+
+ Let m := ((T mod R) * N') mod R.
+ Let t := (T + m * N) / R.
+ Definition prereduce : montgomeryZ := t.
+
+ Definition partial_reduce : montgomeryZ
+ := if R <=? t then
+ prereduce - N
+ else
+ prereduce.
+
+ Definition partial_reduce_alt : montgomeryZ
+ := let v0 := (T + m * N) in
+ let v := (v0 mod (R * R)) / R in
+ if R * R <=? v0 then
+ (v - N) mod R
+ else
+ v.
+
+ Definition reduce : montgomeryZ
+ := if N <=? t then
+ prereduce - N
+ else
+ prereduce.
+
+ Definition reduce_via_partial : montgomeryZ
+ := if N <=? partial_reduce then
+ partial_reduce - N
+ else
+ partial_reduce.
+
+ Definition reduce_via_partial_alt : montgomeryZ
+ := if N <=? partial_reduce then
+ partial_reduce - N
+ else
+ partial_reduce.
+ End redc.
+
+ (** * Arithmetic in Montgomery form *)
+ Section arithmetic.
+ (** Many operations of interest modulo [N] can be expressed
+ equally well in Montgomery form. Addition, subtraction,
+ negation, comparison for equality, multiplication by an
+ integer not in Montgomery form, and greatest common divisors
+ with [N] may all be done with the standard algorithms. *)
+ (** When [R > N], most other arithmetic operations can be
+ expressed in terms of REDC. This assumption implies that the
+ product of two representatives [mod N] is less than [RN],
+ the exact hypothesis necessary for REDC to generate correct
+ output. In particular, the product of [aR mod N] and [bR mod
+ N] is [REDC((aR mod N)(bR mod N))]. The combined operation
+ of multiplication and REDC is often called Montgomery
+ multiplication. *)
+ Definition mul : montgomeryZ -> montgomeryZ -> montgomeryZ
+ := fun aR bR => reduce (aR * bR).
+
+ (** Conversion into Montgomery form is done by computing
+ [REDC((a mod N)(R² mod N))]. Conversion out of Montgomery
+ form is done by computing [REDC(aR mod N)]. The modular
+ inverse of [aR mod N] is [REDC((aR mod N)⁻¹(R³ mod
+ N))]. Modular exponentiation can be done using
+ exponentiation by squaring by initializing the initial
+ product to the Montgomery representation of 1, that is, to
+ [R mod N], and by replacing the multiply and square steps by
+ Montgomery multiplies. *)
+ Definition to_montgomery (a : Z) : montgomeryZ := reduce (a * (R*R mod N)).
+ Definition from_montgomery (aR : montgomeryZ) : Z := reduce aR.
+ End arithmetic.
+ End redc.
+End montgomery.
+
+Infix "+" := add : montgomery_scope.
+Infix "-" := sub : montgomery_scope.
+Infix "*" := mul_naive : montgomery_naive_scope.
+Infix "*" := mul : montgomery_scope.
diff --git a/src/Arithmetic/MontgomeryReduction/Proofs.v b/src/Arithmetic/MontgomeryReduction/Proofs.v
new file mode 100644
index 000000000..d5de00213
--- /dev/null
+++ b/src/Arithmetic/MontgomeryReduction/Proofs.v
@@ -0,0 +1,296 @@
+(*** Montgomery Multiplication *)
+(** This file implements the proofs for Montgomery Form, Montgomery
+ Reduction, and Montgomery Multiplication on [Z]. We follow
+ Wikipedia. *)
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.Structures.Equalities.
+Require Import Crypto.Arithmetic.MontgomeryReduction.Definition.
+Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Require Import Crypto.Util.Tactics.SimplifyRepeatedIfs.
+Require Import Crypto.Util.Notations.
+
+Declare Module Nop : Nop.
+Module Import ImportEquivModuloInstances := Z.EquivModuloInstances Nop.
+
+Local Existing Instance eq_Reflexive. (* speed up setoid_rewrite as per https://coq.inria.fr/bugs/show_bug.cgi?id=4978 *)
+
+Local Open Scope Z_scope.
+
+Section montgomery.
+ Context (N : Z)
+ (N_reasonable : N <> 0)
+ (R : Z)
+ (R_good : Z.gcd N R = 1).
+ Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
+ Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
+ Context (R' : Z)
+ (R'_good : R * R' ≡ 1).
+
+ Lemma R'_good' : R' * R ≡ 1.
+ Proof using R'_good. rewrite <- R'_good; apply f_equal2; lia. Qed.
+
+ Local Notation to_montgomery_naive := (to_montgomery_naive R) (only parsing).
+ Local Notation from_montgomery_naive := (from_montgomery_naive R') (only parsing).
+
+ Lemma to_from_montgomery_naive x : to_montgomery_naive (from_montgomery_naive x) ≡ x.
+ Proof using R'_good.
+ unfold to_montgomery_naive, from_montgomery_naive.
+ rewrite <- Z.mul_assoc, R'_good'.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+ Lemma from_to_montgomery_naive x : from_montgomery_naive (to_montgomery_naive x) ≡ x.
+ Proof using R'_good.
+ unfold to_montgomery_naive, from_montgomery_naive.
+ rewrite <- Z.mul_assoc, R'_good.
+ autorewrite with zsimplify; reflexivity.
+ Qed.
+
+ (** * Modular arithmetic and Montgomery form *)
+ Section general.
+ Local Infix "+" := add : montgomery_scope.
+ Local Infix "-" := sub : montgomery_scope.
+ Local Infix "*" := (mul_naive R') : montgomery_scope.
+
+ Lemma add_correct_naive x y : from_montgomery_naive (x + y) = from_montgomery_naive x + from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, add; lia. Qed.
+ Lemma add_correct_naive_to x y : to_montgomery_naive (x + y) = (to_montgomery_naive x + to_montgomery_naive y)%montgomery.
+ Proof using Type. unfold to_montgomery_naive, add; autorewrite with push_Zmul; reflexivity. Qed.
+ Lemma sub_correct_naive x y : from_montgomery_naive (x - y) = from_montgomery_naive x - from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, sub; lia. Qed.
+ Lemma sub_correct_naive_to x y : to_montgomery_naive (x - y) = (to_montgomery_naive x - to_montgomery_naive y)%montgomery.
+ Proof using Type. unfold to_montgomery_naive, sub; autorewrite with push_Zmul; reflexivity. Qed.
+
+ Theorem mul_correct_naive x y : from_montgomery_naive (x * y) = from_montgomery_naive x * from_montgomery_naive y.
+ Proof using Type. unfold from_montgomery_naive, mul_naive; lia. Qed.
+ Theorem mul_correct_naive_to x y : to_montgomery_naive (x * y) ≡ (to_montgomery_naive x * to_montgomery_naive y)%montgomery.
+ Proof using R'_good.
+ unfold to_montgomery_naive, mul_naive.
+ rewrite <- !Z.mul_assoc, R'_good.
+ autorewrite with zsimplify; apply (f_equal2 Z.modulo); lia.
+ Qed.
+ End general.
+
+ (** * The REDC algorithm *)
+ Section redc.
+ Context (N' : Z)
+ (N'_in_range : 0 <= N' < R)
+ (N'_good : N * N' ≡ᵣ -1).
+
+ Lemma N'_good' : N' * N ≡ᵣ -1.
+ Proof using N'_good. rewrite <- N'_good; apply f_equal2; lia. Qed.
+
+ Lemma N'_good'_alt x : (((x mod R) * (N' mod R)) mod R) * (N mod R) ≡ᵣ x * -1.
+ Proof using N'_good.
+ rewrite <- N'_good', Z.mul_assoc.
+ unfold Z.equiv_modulo; push_Zmod.
+ reflexivity.
+ Qed.
+
+ Section redc.
+ Context (T : Z).
+
+ Local Notation m := (((T mod R) * N') mod R).
+ Local Notation prereduce := (prereduce N R N').
+
+ Local Ltac t_fin_correct :=
+ unfold Z.equiv_modulo; push_Zmod; autorewrite with zsimplify; reflexivity.
+
+ Lemma prereduce_correct : prereduce T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ transitivity ((T + m * N) * R').
+ { unfold prereduce.
+ autorewrite with zstrip_div; push_Zmod.
+ rewrite N'_good'_alt.
+ autorewrite with zsimplify pull_Zmod.
+ reflexivity. }
+ t_fin_correct.
+ Qed.
+
+ Lemma reduce_correct : reduce N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold reduce.
+ break_match; rewrite prereduce_correct; t_fin_correct.
+ Qed.
+
+ Lemma partial_reduce_correct : partial_reduce N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold partial_reduce.
+ break_match; rewrite prereduce_correct; t_fin_correct.
+ Qed.
+
+ Lemma reduce_via_partial_correct : reduce_via_partial N R N' T ≡ T * R'.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold reduce_via_partial.
+ break_match; rewrite partial_reduce_correct; t_fin_correct.
+ Qed.
+
+ Let m_small : 0 <= m < R. Proof. auto with zarith. Qed.
+
+ Section generic.
+ Lemma prereduce_in_range_gen B
+ : 0 <= N
+ -> 0 <= T <= R * B
+ -> 0 <= prereduce T < B + N.
+ Proof using N_reasonable m_small. unfold prereduce; auto with zarith nia. Qed.
+ End generic.
+
+ Section N_very_small.
+ Context (N_very_small : 0 <= 4 * N < R).
+
+ Lemma prereduce_in_range_very_small
+ : 0 <= T <= (2 * N - 1) * (2 * N - 1)
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable N_very_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+ End N_very_small.
+
+ Section N_small.
+ Context (N_small : 0 <= 2 * N < R).
+
+ Lemma prereduce_in_range_small
+ : 0 <= T <= (2 * N - 1) * (N - 1)
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+
+ Lemma prereduce_in_range_small_fully_reduced
+ : 0 <= T <= 2 * N
+ -> 0 <= prereduce T <= N.
+ Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen 1); nia. Qed.
+ End N_small.
+
+ Section N_small_enough.
+ Context (N_small_enough : 0 <= N < R).
+
+ Lemma prereduce_in_range_small_enough
+ : 0 <= T <= R * R
+ -> 0 <= prereduce T < R + N.
+ Proof using N_reasonable N_small_enough m_small. pose proof (prereduce_in_range_gen R); nia. Qed.
+
+ Lemma reduce_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= reduce N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma partial_reduce_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= partial_reduce N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma reduce_via_partial_in_range_R
+ : 0 <= T <= R * R
+ -> 0 <= reduce_via_partial N R N' T < R.
+ Proof using N_reasonable N_small_enough m_small.
+ intro H; pose proof (prereduce_in_range_small_enough H).
+ unfold reduce_via_partial, partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+ End N_small_enough.
+
+ Section unconstrained.
+ Lemma prereduce_in_range
+ : 0 <= T <= R * N
+ -> 0 <= prereduce T < 2 * N.
+ Proof using N_reasonable m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
+
+ Lemma reduce_in_range
+ : 0 <= T <= R * N
+ -> 0 <= reduce N R N' T < N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (prereduce_in_range H).
+ unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
+ Qed.
+
+ Lemma partial_reduce_in_range
+ : 0 <= T <= R * N
+ -> Z.min 0 (R - N) <= partial_reduce N R N' T < 2 * N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (prereduce_in_range H).
+ unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt;
+ apply Z.min_case_strong; nia.
+ Qed.
+
+ Lemma reduce_via_partial_in_range
+ : 0 <= T <= R * N
+ -> Z.min 0 (R - N) <= reduce_via_partial N R N' T < N.
+ Proof using N_reasonable m_small.
+ intro H; pose proof (partial_reduce_in_range H).
+ unfold reduce_via_partial in *; break_match; Z.ltb_to_lt; lia.
+ Qed.
+ End unconstrained.
+
+ Section alt.
+ Context (N_in_range : 0 <= N < R)
+ (T_representable : 0 <= T < R * R).
+ Lemma partial_reduce_alt_eq : partial_reduce_alt N R N' T = partial_reduce N R N' T.
+ Proof using N_in_range N_reasonable T_representable m_small.
+ assert (0 <= T + m * N < 2 * (R * R)) by nia.
+ assert (0 <= T + m * N < R * (R + N)) by nia.
+ assert (0 <= (T + m * N) / R < R + N) by auto with zarith.
+ assert ((T + m * N) / R - N < R) by lia.
+ assert (R * R <= T + m * N -> R <= (T + m * N) / R) by auto with zarith.
+ assert (T + m * N < R * R -> (T + m * N) / R < R) by auto with zarith.
+ assert (H' : (T + m * N) mod (R * R) = if R * R <=? T + m * N then T + m * N - R * R else T + m * N)
+ by (break_match; Z.ltb_to_lt; autorewrite with zsimplify; lia).
+ unfold partial_reduce, partial_reduce_alt, prereduce.
+ rewrite H'; clear H'.
+ simplify_repeated_ifs.
+ set (m' := m) in *.
+ autorewrite with zsimplify; push_Zmod; autorewrite with zsimplify; pull_Zmod.
+ break_match; Z.ltb_to_lt; autorewrite with zsimplify; try reflexivity; lia.
+ Qed.
+ End alt.
+ End redc.
+
+ (** * Arithmetic in Montgomery form *)
+ Section arithmetic.
+ Local Infix "*" := (mul N R N') : montgomery_scope.
+
+ Local Notation to_montgomery := (to_montgomery N R N').
+ Local Notation from_montgomery := (from_montgomery N R N').
+ Lemma to_from_montgomery a : to_montgomery (from_montgomery a) ≡ a.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, from_montgomery.
+ transitivity ((a * 1) * 1); [ | apply f_equal2; lia ].
+ rewrite <- !R'_good, !reduce_correct.
+ unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ apply f_equal2; lia.
+ Qed.
+ Lemma from_to_montgomery a : from_montgomery (to_montgomery a) ≡ a.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, from_montgomery.
+ rewrite !reduce_correct.
+ transitivity (a * ((R * (R * R' mod N) * R') mod N)).
+ { unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ apply f_equal2; lia. }
+ { repeat first [ rewrite R'_good
+ | reflexivity
+ | push_Zmod; pull_Zmod; progress autorewrite with zsimplify
+ | progress unfold Z.equiv_modulo ]. }
+ Qed.
+
+ Theorem mul_correct x y : from_montgomery (x * y) ≡ from_montgomery x * from_montgomery y.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold from_montgomery, mul.
+ rewrite !reduce_correct; apply f_equal2; lia.
+ Qed.
+ Theorem mul_correct_to x y : to_montgomery (x * y) ≡ (to_montgomery x * to_montgomery y)%montgomery.
+ Proof using N'_good N'_in_range N_reasonable R'_good.
+ unfold to_montgomery, mul.
+ rewrite !reduce_correct.
+ transitivity (x * y * R * 1 * 1 * 1);
+ [ rewrite <- R'_good at 1
+ | rewrite <- R'_good at 1 2 3 ];
+ autorewrite with zsimplify;
+ unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
+ { apply f_equal2; lia. }
+ { apply f_equal2; lia. }
+ Qed.
+ End arithmetic.
+ End redc.
+End montgomery.
+
+Module Import LocalizeEquivModuloInstances := Z.RemoveEquivModuloInstances Nop.