diff options
author | 2017-04-06 22:53:07 -0400 | |
---|---|---|
committer | 2017-04-06 22:53:07 -0400 | |
commit | c9fc5a3cdf1f5ea2d104c150c30d1b1a6ac64239 (patch) | |
tree | db7187f6984acff324ca468e7b33d9285806a1eb /src/Arithmetic/MontgomeryReduction | |
parent | 21198245dab432d3c0ba2bb8a02254e7d0594382 (diff) |
rename-everything
Diffstat (limited to 'src/Arithmetic/MontgomeryReduction')
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/Definition.v | 179 | ||||
-rw-r--r-- | src/Arithmetic/MontgomeryReduction/Proofs.v | 296 |
2 files changed, 475 insertions, 0 deletions
diff --git a/src/Arithmetic/MontgomeryReduction/Definition.v b/src/Arithmetic/MontgomeryReduction/Definition.v new file mode 100644 index 000000000..78d3c037f --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/Definition.v @@ -0,0 +1,179 @@ +(*** Montgomery Multiplication *) +(** This file implements Montgomery Form, Montgomery Reduction, and + Montgomery Multiplication on [Z]. We follow Wikipedia. *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. + +Local Open Scope Z_scope. +Delimit Scope montgomery_naive_scope with montgomery_naive. +Delimit Scope montgomery_scope with montgomery. +Definition montgomeryZ := Z. +Bind Scope montgomery_scope with montgomeryZ. + +Section montgomery. + Context (N : Z) + (R : Z) + (R' : Z). (* R' is R⁻¹ mod N *) + Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope. + Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope. + (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Montgomery_modular_multiplication>: *) + (** In modular arithmetic computation, Montgomery modular + multiplication, more commonly referred to as Montgomery + multiplication, is a method for performing fast modular + multiplication, introduced in 1985 by the American mathematician + Peter L. Montgomery. *) + (** Given two integers [a] and [b], the classical modular + multiplication algorithm computes [ab mod N]. Montgomery + multiplication works by transforming [a] and [b] into a special + representation known as Montgomery form. For a modulus [N], the + Montgomery form of [a] is defined to be [aR mod N] for some + constant [R] depending only on [N] and the underlying computer + architecture. If [aR mod N] and [bR mod N] are the Montgomery + forms of [a] and [b], then their Montgomery product is [abR mod + N]. Montgomery multiplication is a fast algorithm to compute the + Montgomery product. Transforming the result out of Montgomery + form yields the classical modular product [ab mod N]. *) + + Definition to_montgomery_naive (x : Z) : montgomeryZ := x * R. + Definition from_montgomery_naive (x : montgomeryZ) : Z := x * R'. + + (** * Modular arithmetic and Montgomery form *) + Section general. + (** Addition and subtraction in Montgomery form are the same as + ordinary modular addition and subtraction because of the + distributive law: *) + (** [aR + bR = (a+b)R], *) + (** [aR - bR = (a-b)R]. *) + (** This is a consequence of the fact that, because [gcd(R, N) = + 1], multiplication by [R] is an isomorphism on the additive + group [ℤ/Nℤ]. *) + + Definition add : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR + bR. + Definition sub : montgomeryZ -> montgomeryZ -> montgomeryZ := fun aR bR => aR - bR. + + (** Multiplication in Montgomery form, however, is seemingly more + complicated. The usual product of [aR] and [bR] does not + represent the product of [a] and [b] because it has an extra + factor of R: *) + (** [(aR mod N)(bR mod N) mod N = (abR)R mod N]. *) + (** Computing products in Montgomery form requires removing the + extra factor of [R]. While division by [R] is cheap, the + intermediate product [(aR mod N)(bR mod N)] is not divisible + by [R] because the modulo operation has destroyed that + property. *) + (** Removing the extra factor of R can be done by multiplying by + an integer [R′] such that [RR' ≡ 1 (mod N)], that is, by an + [R′] whose residue class is the modular inverse of [R] mod + [N]. Then, working modulo [N], *) + (** [(aR mod N)(bR mod N)R' ≡ (aR)(bR)R⁻¹ ≡ (ab)R (mod N)]. *) + + Definition mul_naive : montgomeryZ -> montgomeryZ -> montgomeryZ + := fun aR bR => aR * bR * R'. + End general. + + (** * The REDC algorithm *) + Section redc. + (** While the above algorithm is correct, it is slower than + multiplication in the standard representation because of the + need to multiply by [R′] and divide by [N]. Montgomery + reduction, also known as REDC, is an algorithm that + simultaneously computes the product by [R′] and reduces modulo + [N] more quickly than the naive method. The speed is because + all computations are done using only reduction and divisions + with respect to [R], not [N]: *) + (** +<< +function REDC is + input: Integers R and N with gcd(R, N) = 1, + Integer N′ in [0, R − 1] such that NN′ ≡ −1 mod R, + Integer T in the range [0, RN − 1] + output: Integer S in the range [0, N − 1] such that S ≡ TR⁻¹ mod N + + m ← ((T mod R)N′) mod R + t ← (T + mN) / R + if t ≥ N then + return t − N + else + return t + end if +end function +>> *) + Context (N' : Z). (* N' is (-N⁻¹) mod R *) + Section redc. + Context (T : Z). + + Let m := ((T mod R) * N') mod R. + Let t := (T + m * N) / R. + Definition prereduce : montgomeryZ := t. + + Definition partial_reduce : montgomeryZ + := if R <=? t then + prereduce - N + else + prereduce. + + Definition partial_reduce_alt : montgomeryZ + := let v0 := (T + m * N) in + let v := (v0 mod (R * R)) / R in + if R * R <=? v0 then + (v - N) mod R + else + v. + + Definition reduce : montgomeryZ + := if N <=? t then + prereduce - N + else + prereduce. + + Definition reduce_via_partial : montgomeryZ + := if N <=? partial_reduce then + partial_reduce - N + else + partial_reduce. + + Definition reduce_via_partial_alt : montgomeryZ + := if N <=? partial_reduce then + partial_reduce - N + else + partial_reduce. + End redc. + + (** * Arithmetic in Montgomery form *) + Section arithmetic. + (** Many operations of interest modulo [N] can be expressed + equally well in Montgomery form. Addition, subtraction, + negation, comparison for equality, multiplication by an + integer not in Montgomery form, and greatest common divisors + with [N] may all be done with the standard algorithms. *) + (** When [R > N], most other arithmetic operations can be + expressed in terms of REDC. This assumption implies that the + product of two representatives [mod N] is less than [RN], + the exact hypothesis necessary for REDC to generate correct + output. In particular, the product of [aR mod N] and [bR mod + N] is [REDC((aR mod N)(bR mod N))]. The combined operation + of multiplication and REDC is often called Montgomery + multiplication. *) + Definition mul : montgomeryZ -> montgomeryZ -> montgomeryZ + := fun aR bR => reduce (aR * bR). + + (** Conversion into Montgomery form is done by computing + [REDC((a mod N)(R² mod N))]. Conversion out of Montgomery + form is done by computing [REDC(aR mod N)]. The modular + inverse of [aR mod N] is [REDC((aR mod N)⁻¹(R³ mod + N))]. Modular exponentiation can be done using + exponentiation by squaring by initializing the initial + product to the Montgomery representation of 1, that is, to + [R mod N], and by replacing the multiply and square steps by + Montgomery multiplies. *) + Definition to_montgomery (a : Z) : montgomeryZ := reduce (a * (R*R mod N)). + Definition from_montgomery (aR : montgomeryZ) : Z := reduce aR. + End arithmetic. + End redc. +End montgomery. + +Infix "+" := add : montgomery_scope. +Infix "-" := sub : montgomery_scope. +Infix "*" := mul_naive : montgomery_naive_scope. +Infix "*" := mul : montgomery_scope. diff --git a/src/Arithmetic/MontgomeryReduction/Proofs.v b/src/Arithmetic/MontgomeryReduction/Proofs.v new file mode 100644 index 000000000..d5de00213 --- /dev/null +++ b/src/Arithmetic/MontgomeryReduction/Proofs.v @@ -0,0 +1,296 @@ +(*** Montgomery Multiplication *) +(** This file implements the proofs for Montgomery Form, Montgomery + Reduction, and Montgomery Multiplication on [Z]. We follow + Wikipedia. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.Structures.Equalities. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SimplifyRepeatedIfs. +Require Import Crypto.Util.Notations. + +Declare Module Nop : Nop. +Module Import ImportEquivModuloInstances := Z.EquivModuloInstances Nop. + +Local Existing Instance eq_Reflexive. (* speed up setoid_rewrite as per https://coq.inria.fr/bugs/show_bug.cgi?id=4978 *) + +Local Open Scope Z_scope. + +Section montgomery. + Context (N : Z) + (N_reasonable : N <> 0) + (R : Z) + (R_good : Z.gcd N R = 1). + Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope. + Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope. + Context (R' : Z) + (R'_good : R * R' ≡ 1). + + Lemma R'_good' : R' * R ≡ 1. + Proof using R'_good. rewrite <- R'_good; apply f_equal2; lia. Qed. + + Local Notation to_montgomery_naive := (to_montgomery_naive R) (only parsing). + Local Notation from_montgomery_naive := (from_montgomery_naive R') (only parsing). + + Lemma to_from_montgomery_naive x : to_montgomery_naive (from_montgomery_naive x) ≡ x. + Proof using R'_good. + unfold to_montgomery_naive, from_montgomery_naive. + rewrite <- Z.mul_assoc, R'_good'. + autorewrite with zsimplify; reflexivity. + Qed. + Lemma from_to_montgomery_naive x : from_montgomery_naive (to_montgomery_naive x) ≡ x. + Proof using R'_good. + unfold to_montgomery_naive, from_montgomery_naive. + rewrite <- Z.mul_assoc, R'_good. + autorewrite with zsimplify; reflexivity. + Qed. + + (** * Modular arithmetic and Montgomery form *) + Section general. + Local Infix "+" := add : montgomery_scope. + Local Infix "-" := sub : montgomery_scope. + Local Infix "*" := (mul_naive R') : montgomery_scope. + + Lemma add_correct_naive x y : from_montgomery_naive (x + y) = from_montgomery_naive x + from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, add; lia. Qed. + Lemma add_correct_naive_to x y : to_montgomery_naive (x + y) = (to_montgomery_naive x + to_montgomery_naive y)%montgomery. + Proof using Type. unfold to_montgomery_naive, add; autorewrite with push_Zmul; reflexivity. Qed. + Lemma sub_correct_naive x y : from_montgomery_naive (x - y) = from_montgomery_naive x - from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, sub; lia. Qed. + Lemma sub_correct_naive_to x y : to_montgomery_naive (x - y) = (to_montgomery_naive x - to_montgomery_naive y)%montgomery. + Proof using Type. unfold to_montgomery_naive, sub; autorewrite with push_Zmul; reflexivity. Qed. + + Theorem mul_correct_naive x y : from_montgomery_naive (x * y) = from_montgomery_naive x * from_montgomery_naive y. + Proof using Type. unfold from_montgomery_naive, mul_naive; lia. Qed. + Theorem mul_correct_naive_to x y : to_montgomery_naive (x * y) ≡ (to_montgomery_naive x * to_montgomery_naive y)%montgomery. + Proof using R'_good. + unfold to_montgomery_naive, mul_naive. + rewrite <- !Z.mul_assoc, R'_good. + autorewrite with zsimplify; apply (f_equal2 Z.modulo); lia. + Qed. + End general. + + (** * The REDC algorithm *) + Section redc. + Context (N' : Z) + (N'_in_range : 0 <= N' < R) + (N'_good : N * N' ≡ᵣ -1). + + Lemma N'_good' : N' * N ≡ᵣ -1. + Proof using N'_good. rewrite <- N'_good; apply f_equal2; lia. Qed. + + Lemma N'_good'_alt x : (((x mod R) * (N' mod R)) mod R) * (N mod R) ≡ᵣ x * -1. + Proof using N'_good. + rewrite <- N'_good', Z.mul_assoc. + unfold Z.equiv_modulo; push_Zmod. + reflexivity. + Qed. + + Section redc. + Context (T : Z). + + Local Notation m := (((T mod R) * N') mod R). + Local Notation prereduce := (prereduce N R N'). + + Local Ltac t_fin_correct := + unfold Z.equiv_modulo; push_Zmod; autorewrite with zsimplify; reflexivity. + + Lemma prereduce_correct : prereduce T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + transitivity ((T + m * N) * R'). + { unfold prereduce. + autorewrite with zstrip_div; push_Zmod. + rewrite N'_good'_alt. + autorewrite with zsimplify pull_Zmod. + reflexivity. } + t_fin_correct. + Qed. + + Lemma reduce_correct : reduce N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold reduce. + break_match; rewrite prereduce_correct; t_fin_correct. + Qed. + + Lemma partial_reduce_correct : partial_reduce N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold partial_reduce. + break_match; rewrite prereduce_correct; t_fin_correct. + Qed. + + Lemma reduce_via_partial_correct : reduce_via_partial N R N' T ≡ T * R'. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold reduce_via_partial. + break_match; rewrite partial_reduce_correct; t_fin_correct. + Qed. + + Let m_small : 0 <= m < R. Proof. auto with zarith. Qed. + + Section generic. + Lemma prereduce_in_range_gen B + : 0 <= N + -> 0 <= T <= R * B + -> 0 <= prereduce T < B + N. + Proof using N_reasonable m_small. unfold prereduce; auto with zarith nia. Qed. + End generic. + + Section N_very_small. + Context (N_very_small : 0 <= 4 * N < R). + + Lemma prereduce_in_range_very_small + : 0 <= T <= (2 * N - 1) * (2 * N - 1) + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable N_very_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed. + End N_very_small. + + Section N_small. + Context (N_small : 0 <= 2 * N < R). + + Lemma prereduce_in_range_small + : 0 <= T <= (2 * N - 1) * (N - 1) + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed. + + Lemma prereduce_in_range_small_fully_reduced + : 0 <= T <= 2 * N + -> 0 <= prereduce T <= N. + Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen 1); nia. Qed. + End N_small. + + Section N_small_enough. + Context (N_small_enough : 0 <= N < R). + + Lemma prereduce_in_range_small_enough + : 0 <= T <= R * R + -> 0 <= prereduce T < R + N. + Proof using N_reasonable N_small_enough m_small. pose proof (prereduce_in_range_gen R); nia. Qed. + + Lemma reduce_in_range_R + : 0 <= T <= R * R + -> 0 <= reduce N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma partial_reduce_in_range_R + : 0 <= T <= R * R + -> 0 <= partial_reduce N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma reduce_via_partial_in_range_R + : 0 <= T <= R * R + -> 0 <= reduce_via_partial N R N' T < R. + Proof using N_reasonable N_small_enough m_small. + intro H; pose proof (prereduce_in_range_small_enough H). + unfold reduce_via_partial, partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + End N_small_enough. + + Section unconstrained. + Lemma prereduce_in_range + : 0 <= T <= R * N + -> 0 <= prereduce T < 2 * N. + Proof using N_reasonable m_small. pose proof (prereduce_in_range_gen N); nia. Qed. + + Lemma reduce_in_range + : 0 <= T <= R * N + -> 0 <= reduce N R N' T < N. + Proof using N_reasonable m_small. + intro H; pose proof (prereduce_in_range H). + unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia. + Qed. + + Lemma partial_reduce_in_range + : 0 <= T <= R * N + -> Z.min 0 (R - N) <= partial_reduce N R N' T < 2 * N. + Proof using N_reasonable m_small. + intro H; pose proof (prereduce_in_range H). + unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; + apply Z.min_case_strong; nia. + Qed. + + Lemma reduce_via_partial_in_range + : 0 <= T <= R * N + -> Z.min 0 (R - N) <= reduce_via_partial N R N' T < N. + Proof using N_reasonable m_small. + intro H; pose proof (partial_reduce_in_range H). + unfold reduce_via_partial in *; break_match; Z.ltb_to_lt; lia. + Qed. + End unconstrained. + + Section alt. + Context (N_in_range : 0 <= N < R) + (T_representable : 0 <= T < R * R). + Lemma partial_reduce_alt_eq : partial_reduce_alt N R N' T = partial_reduce N R N' T. + Proof using N_in_range N_reasonable T_representable m_small. + assert (0 <= T + m * N < 2 * (R * R)) by nia. + assert (0 <= T + m * N < R * (R + N)) by nia. + assert (0 <= (T + m * N) / R < R + N) by auto with zarith. + assert ((T + m * N) / R - N < R) by lia. + assert (R * R <= T + m * N -> R <= (T + m * N) / R) by auto with zarith. + assert (T + m * N < R * R -> (T + m * N) / R < R) by auto with zarith. + assert (H' : (T + m * N) mod (R * R) = if R * R <=? T + m * N then T + m * N - R * R else T + m * N) + by (break_match; Z.ltb_to_lt; autorewrite with zsimplify; lia). + unfold partial_reduce, partial_reduce_alt, prereduce. + rewrite H'; clear H'. + simplify_repeated_ifs. + set (m' := m) in *. + autorewrite with zsimplify; push_Zmod; autorewrite with zsimplify; pull_Zmod. + break_match; Z.ltb_to_lt; autorewrite with zsimplify; try reflexivity; lia. + Qed. + End alt. + End redc. + + (** * Arithmetic in Montgomery form *) + Section arithmetic. + Local Infix "*" := (mul N R N') : montgomery_scope. + + Local Notation to_montgomery := (to_montgomery N R N'). + Local Notation from_montgomery := (from_montgomery N R N'). + Lemma to_from_montgomery a : to_montgomery (from_montgomery a) ≡ a. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, from_montgomery. + transitivity ((a * 1) * 1); [ | apply f_equal2; lia ]. + rewrite <- !R'_good, !reduce_correct. + unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + apply f_equal2; lia. + Qed. + Lemma from_to_montgomery a : from_montgomery (to_montgomery a) ≡ a. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, from_montgomery. + rewrite !reduce_correct. + transitivity (a * ((R * (R * R' mod N) * R') mod N)). + { unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + apply f_equal2; lia. } + { repeat first [ rewrite R'_good + | reflexivity + | push_Zmod; pull_Zmod; progress autorewrite with zsimplify + | progress unfold Z.equiv_modulo ]. } + Qed. + + Theorem mul_correct x y : from_montgomery (x * y) ≡ from_montgomery x * from_montgomery y. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold from_montgomery, mul. + rewrite !reduce_correct; apply f_equal2; lia. + Qed. + Theorem mul_correct_to x y : to_montgomery (x * y) ≡ (to_montgomery x * to_montgomery y)%montgomery. + Proof using N'_good N'_in_range N_reasonable R'_good. + unfold to_montgomery, mul. + rewrite !reduce_correct. + transitivity (x * y * R * 1 * 1 * 1); + [ rewrite <- R'_good at 1 + | rewrite <- R'_good at 1 2 3 ]; + autorewrite with zsimplify; + unfold Z.equiv_modulo; push_Zmod; pull_Zmod. + { apply f_equal2; lia. } + { apply f_equal2; lia. } + Qed. + End arithmetic. + End redc. +End montgomery. + +Module Import LocalizeEquivModuloInstances := Z.RemoveEquivModuloInstances Nop. |