aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/MontgomeryReduction/Proofs.v
blob: e6be440fbfe5f47278b6cf4236708da84d59b312 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
(*** Montgomery Multiplication *)
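(** This file proves correctness properties of Montgomery form
    (conversion to and from it), Montgomery reduction (REDC), and
    Montgomery multiplication on [Z].  We follow the presentation in the
    Wikipedia article on Montgomery modular multiplication. *)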
Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.Structures.Equalities.
Require Import Crypto.Arithmetic.MontgomeryReduction.Definition.
Require Import Crypto.Util.ZUtil.EquivModulo.
Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo.
Require Import Crypto.Util.ZUtil.Tactics.LtbToLt.
Require Import Crypto.Util.Tactics.BreakMatch.
Require Import Crypto.Util.Tactics.SimplifyRepeatedIfs.
Require Import Crypto.Util.Notations.

Declare Module Nop : Nop.
Module Import ImportEquivModuloInstances := Z.EquivModuloInstances Nop.

Local Existing Instance eq_Reflexive. (* speed up setoid_rewrite as per https://coq.inria.fr/bugs/show_bug.cgi?id=4978 *)

Local Open Scope Z_scope.

Section montgomery.
  (** Global parameters: a nonzero modulus [N] and a radix [R] coprime to
      it ([Z.gcd N R = 1]). *)
  Context (N : Z)
          (N_reasonable : N <> 0)
          (R : Z)
          (R_good : Z.gcd N R = 1).
  (** [x ≡ y] is congruence modulo [N]; [x ≡ᵣ y] is congruence modulo [R]. *)
  Local Notation "x ≡ y" := (Z.equiv_modulo N x y) : type_scope.
  Local Notation "x ≡ᵣ y" := (Z.equiv_modulo R x y) : type_scope.
  (** [R'] is a multiplicative inverse of [R] modulo [N]; coprimality of
      [N] and [R] is what makes such an [R'] exist. *)
  Context (R' : Z)
          (R'_good : R * R' ≡ 1).

  (** Commuted form of [R'_good]: [R'] is also a left inverse of [R]
      modulo [N]. *)
  Lemma R'_good' : R' * R ≡ 1.
  Proof using R'_good.
    rewrite Z.mul_comm; exact R'_good.
  Qed.

  (** Naive conversions with the radix argument fixed.  The proofs below
      treat [to_montgomery_naive] as multiplication by [R] and
      [from_montgomery_naive] as multiplication by [R'] (presumably;
      confirm against MontgomeryReduction.Definition). *)
  Local Notation to_montgomery_naive := (to_montgomery_naive R) (only parsing).
  Local Notation from_montgomery_naive := (from_montgomery_naive R') (only parsing).

  (** Leaving and then re-entering naive Montgomery form is the identity
      modulo [N].  The proof reassociates to expose [R' * R] and collapses
      it to [1] with [R'_good']. *)
  Lemma to_from_montgomery_naive x : to_montgomery_naive (from_montgomery_naive x) ≡ x.
  Proof using R'_good.
    unfold to_montgomery_naive, from_montgomery_naive.
    rewrite <- Z.mul_assoc, R'_good'.
    autorewrite with zsimplify; reflexivity.
  Qed.
  (** Entering and then leaving naive Montgomery form is the identity
      modulo [N].  The proof reassociates to expose [R * R'] and collapses
      it to [1] with [R'_good]. *)
  Lemma from_to_montgomery_naive x : from_montgomery_naive (to_montgomery_naive x) ≡ x.
  Proof using R'_good.
    unfold to_montgomery_naive, from_montgomery_naive.
    rewrite <- Z.mul_assoc, R'_good.
    autorewrite with zsimplify; reflexivity.
  Qed.

  (** * Modular arithmetic and Montgomery form *)
  (** Compatibility of the Montgomery-form operations with the naive
      conversions.  Everything is an exact equality except
      [mul_correct_naive_to], which only holds modulo [N]. *)
  Section general.
    (** Montgomery-form operations, written infix in [montgomery_scope]. *)
    Local Infix "+" := add : montgomery_scope.
    Local Infix "-" := sub : montgomery_scope.
    Local Infix "*" := (mul_naive R') : montgomery_scope.

    (** [from_montgomery_naive] distributes over [add]; needs no hypotheses. *)
    Lemma add_correct_naive x y : from_montgomery_naive (x + y) = from_montgomery_naive x + from_montgomery_naive y.
    Proof using Type. unfold from_montgomery_naive, add; lia. Qed.
    (** [to_montgomery_naive] distributes over [add]; needs no hypotheses. *)
    Lemma add_correct_naive_to x y : to_montgomery_naive (x + y) = (to_montgomery_naive x + to_montgomery_naive y)%montgomery.
    Proof using Type. unfold to_montgomery_naive, add; autorewrite with push_Zmul; reflexivity. Qed.
    (** [from_montgomery_naive] distributes over [sub]; needs no hypotheses. *)
    Lemma sub_correct_naive x y : from_montgomery_naive (x - y) = from_montgomery_naive x - from_montgomery_naive y.
    Proof using Type. unfold from_montgomery_naive, sub; lia. Qed.
    (** [to_montgomery_naive] distributes over [sub]; needs no hypotheses. *)
    Lemma sub_correct_naive_to x y : to_montgomery_naive (x - y) = (to_montgomery_naive x - to_montgomery_naive y)%montgomery.
    Proof using Type. unfold to_montgomery_naive, sub; autorewrite with push_Zmul; reflexivity. Qed.

    (** Leaving Montgomery form turns [mul_naive R'] into ordinary
        multiplication, exactly. *)
    Theorem mul_correct_naive x y : from_montgomery_naive (x * y) = from_montgomery_naive x * from_montgomery_naive y.
    Proof using Type. unfold from_montgomery_naive, mul_naive; lia. Qed.
    (** Entering Montgomery form commutes with multiplication only modulo
        [N]: after reassociation a factor [R * R'] appears and is removed
        with [R'_good]. *)
    Theorem mul_correct_naive_to x y : to_montgomery_naive (x * y) ≡ (to_montgomery_naive x * to_montgomery_naive y)%montgomery.
    Proof using R'_good.
      unfold to_montgomery_naive, mul_naive.
      rewrite <- !Z.mul_assoc, R'_good.
      autorewrite with zsimplify; apply (f_equal2 Z.modulo); lia.
    Qed.
  End general.

  (** * The REDC algorithm *)
  Section redc.
    (** [N'] satisfies [N * N' ≡ -1] modulo [R] and is normalized to
        [0 <= N' < R]; the range hypothesis also forces [R > 0]. *)
    Context (N' : Z)
            (N'_in_range : 0 <= N' < R)
            (N'_good : N * N' ≡ᵣ -1).

    (** Commuted form of [N'_good]: [N' * N ≡ -1] modulo [R]. *)
    Lemma N'_good' : N' * N ≡ᵣ -1.
    Proof using N'_good.
      rewrite Z.mul_comm; exact N'_good.
    Qed.

    (** Multiplying any [x] by [N'] and then by [N] negates it modulo [R].
        The operands are stated pre-reduced modulo [R] so the lemma matches
        the shape produced by [push_Zmod] inside [prereduce_correct]. *)
    Lemma N'_good'_alt x : (((x mod R) * (N' mod R)) mod R) * (N mod R) ≡ᵣ x * -1.
    Proof using N'_good.
      rewrite <- N'_good', Z.mul_assoc.
      unfold Z.equiv_modulo; push_Zmod.
      reflexivity.
    Qed.

    Section redc.
      (** [T] is the value being reduced. *)
      Context (T : Z).

      (** REDC multiplier: [m ≡ T * N'] modulo [R], so by [N'_good] the sum
          [T + m * N] is congruent to [0] modulo [R]. *)
      Local Notation m := (((T mod R) * N') mod R).
      Local Notation prereduce := (prereduce N R N').

      (** Closes a [Z.equiv_modulo] goal: unfold it, push [mod] inward,
          simplify with the [zsimplify] rewrite database, then reflexivity. *)
      Local Ltac t_fin_correct :=
        unfold Z.equiv_modulo; push_Zmod; autorewrite with zsimplify; reflexivity.

      (** [prereduce T] represents [T * R'] modulo [N].  The proof first
          relates [prereduce T] to [(T + m * N) * R'] using [N'_good'_alt],
          then shows the [m * N] summand vanishes modulo [N]. *)
      Lemma prereduce_correct : prereduce T ≡ T * R'.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        transitivity ((T + m * N) * R').
        { unfold prereduce.
          autorewrite with zstrip_div; push_Zmod.
          rewrite N'_good'_alt.
          autorewrite with zsimplify pull_Zmod.
          reflexivity. }
        t_fin_correct.
      Qed.

      (** [reduce] is congruent to [T * R'] modulo [N]: [reduce] branches on
          a condition (presumably a final conditional subtraction of [N]),
          and every branch is reduced to [prereduce_correct]. *)
      Lemma reduce_correct : reduce N R N' T ≡ T * R'.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold reduce.
        break_match; rewrite prereduce_correct; t_fin_correct.
      Qed.

      (** [partial_reduce] is congruent to [T * R'] modulo [N]; each branch
          of its definition is reduced to [prereduce_correct]. *)
      Lemma partial_reduce_correct : partial_reduce N R N' T ≡ T * R'.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold partial_reduce.
        break_match; rewrite prereduce_correct; t_fin_correct.
      Qed.

      (** [reduce_via_partial] is congruent to [T * R'] modulo [N]; it is
          built on [partial_reduce], so each branch is reduced to
          [partial_reduce_correct]. *)
      Lemma reduce_via_partial_correct : reduce_via_partial N R N' T ≡ T * R'.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold reduce_via_partial.
        break_match; rewrite partial_reduce_correct; t_fin_correct.
      Qed.

      (** [m] is reduced modulo [R], so [0 <= m < R].  This presumably relies
          on [R > 0], which follows from [N'_in_range].  It is stated as a
          [Let] so the range lemmas below can list it in [Proof using]. *)
      Let m_small : 0 <= m < R. Proof. auto with zarith. Qed.

      (** Generic bound: for any [B], if [0 <= T <= R * B] then
          [prereduce T < B + N].  The specialized range lemmas below
          instantiate [B] with [N], [1], or [R]. *)
      Section generic.
        Lemma prereduce_in_range_gen B
        : 0 <= N
          -> 0 <= T <= R * B
          -> 0 <= prereduce T < B + N.
        Proof using N_reasonable m_small. unfold prereduce; auto with zarith nia. Qed.
      End generic.

      (** Strongest size assumption on the modulus: [4 * N < R]. *)
      Section N_very_small.
        Context (N_very_small : 0 <= 4 * N < R).

        (** Any [T] up to [(2 * N - 1) * (2 * N - 1)], for example a product
            of two values below [2 * N], prereduces to below [2 * N]. *)
        Lemma prereduce_in_range_very_small
          : 0 <= T <= (2 * N - 1) * (2 * N - 1)
            -> 0 <= prereduce T < 2 * N.
        Proof using N_reasonable N_very_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.
      End N_very_small.

      (** Size assumption on the modulus: [2 * N < R]. *)
      Section N_small.
        Context (N_small : 0 <= 2 * N < R).

        (** Any [T] up to [(2 * N - 1) * (N - 1)], for example a product of
            a value below [2 * N] and a value below [N], prereduces to below
            [2 * N]. *)
        Lemma prereduce_in_range_small
          : 0 <= T <= (2 * N - 1) * (N - 1)
            -> 0 <= prereduce T < 2 * N.
        Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen N); nia. Qed.

        (** Inputs of at most [2 * N] prereduce to at most [N]; this uses
            the generic bound with [B = 1], since [2 * N < R]. *)
        Lemma prereduce_in_range_small_fully_reduced
          : 0 <= T <= 2 * N
            -> 0 <= prereduce T <= N.
        Proof using N_reasonable N_small m_small. pose proof (prereduce_in_range_gen 1); nia. Qed.
      End N_small.

      (** Assumes only [N < R].  Inputs range up to [R * R], for example a
          product of two values below [R]. *)
      Section N_small_enough.
        Context (N_small_enough : 0 <= N < R).

        (** Prereduction of any [T <= R * R] stays below [R + N]. *)
        Lemma prereduce_in_range_small_enough
          : 0 <= T <= R * R
            -> 0 <= prereduce T < R + N.
        Proof using N_reasonable N_small_enough m_small. pose proof (prereduce_in_range_gen R); nia. Qed.

        (** The result of [reduce] on any [T <= R * R] fits in [0 <= _ < R];
            this lemma bounds it by [R] only. *)
        Lemma reduce_in_range_R
          : 0 <= T <= R * R
            -> 0 <= reduce N R N' T < R.
        Proof using N_reasonable N_small_enough m_small.
          intro H; pose proof (prereduce_in_range_small_enough H).
          unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
        Qed.

        (** Same bound for [partial_reduce]. *)
        Lemma partial_reduce_in_range_R
          : 0 <= T <= R * R
            -> 0 <= partial_reduce N R N' T < R.
        Proof using N_reasonable N_small_enough m_small.
          intro H; pose proof (prereduce_in_range_small_enough H).
          unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
        Qed.

        (** Same bound for [reduce_via_partial]. *)
        Lemma reduce_via_partial_in_range_R
          : 0 <= T <= R * R
            -> 0 <= reduce_via_partial N R N' T < R.
        Proof using N_reasonable N_small_enough m_small.
          intro H; pose proof (prereduce_in_range_small_enough H).
          unfold reduce_via_partial, partial_reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
        Qed.
      End N_small_enough.

      (** No additional size assumption on [N]; inputs range up to [R * N]. *)
      Section unconstrained.
        (** [prereduce T < 2 * N] whenever [0 <= T <= R * N]. *)
        Lemma prereduce_in_range
          : 0 <= T <= R * N
            -> 0 <= prereduce T < 2 * N.
        Proof using N_reasonable m_small. pose proof (prereduce_in_range_gen N); nia. Qed.

        (** [reduce] fully reduces such inputs: the result lies in
            [0 <= _ < N]. *)
        Lemma reduce_in_range
        : 0 <= T <= R * N
          -> 0 <= reduce N R N' T < N.
        Proof using N_reasonable m_small.
          intro H; pose proof (prereduce_in_range H).
          unfold reduce, prereduce in *; break_match; Z.ltb_to_lt; nia.
        Qed.

        (** [partial_reduce] is only bounded by [2 * N] above.  Its lower
            bound is [Z.min 0 (R - N)], so this lemma permits a negative
            result when [R < N]. *)
        Lemma partial_reduce_in_range
        : 0 <= T <= R * N
          -> Z.min 0 (R - N) <= partial_reduce N R N' T < 2 * N.
        Proof using N_reasonable m_small.
          intro H; pose proof (prereduce_in_range H).
          unfold partial_reduce, prereduce in *; break_match; Z.ltb_to_lt;
            apply Z.min_case_strong; nia.
        Qed.

        (** [reduce_via_partial] brings the upper bound down to [N], keeping
            the same lower bound as [partial_reduce_in_range]. *)
        Lemma reduce_via_partial_in_range
        : 0 <= T <= R * N
          -> Z.min 0 (R - N) <= reduce_via_partial N R N' T < N.
        Proof using N_reasonable m_small.
          intro H; pose proof (partial_reduce_in_range H).
          unfold reduce_via_partial in *; break_match; Z.ltb_to_lt; lia.
        Qed.
      End unconstrained.

      Section alt.
        (** Hypotheses for the alternative formulation: [0 <= N < R] and
            [T] fits below [R * R]. *)
        Context (N_in_range : 0 <= N < R)
                (T_representable : 0 <= T < R * R).
        (** Under these bounds [partial_reduce_alt] computes the same value
            as [partial_reduce].  The proof first records range facts about
            [T + m * N] and its quotient by [R].  It then replaces
            [(T + m * N) mod (R * R)] by an explicit conditional subtraction
            of [R * R] (hypothesis [H']) and finishes by case analysis. *)
        Lemma partial_reduce_alt_eq : partial_reduce_alt N R N' T = partial_reduce N R N' T.
        Proof using N_in_range N_reasonable T_representable m_small.
          assert (0 <= T + m * N < 2 * (R * R)) by nia.
          assert (0 <= T + m * N < R * (R + N)) by nia.
          assert (0 <= (T + m * N) / R < R + N) by auto with zarith.
          assert ((T + m * N) / R - N < R) by lia.
          assert (R * R <= T + m * N -> R <= (T + m * N) / R) by auto with zarith.
          assert (T + m * N < R * R -> (T + m * N) / R < R) by auto with zarith.
          assert (H' : (T + m * N) mod (R * R) = if R * R <=? T + m * N then T + m * N - R * R else T + m * N)
            by (break_match; Z.ltb_to_lt; autorewrite with zsimplify; lia).
          unfold partial_reduce, partial_reduce_alt, prereduce.
          rewrite H'; clear H'.
          simplify_repeated_ifs.
          set (m' := m) in *.
          autorewrite with zsimplify; push_Zmod; autorewrite with zsimplify; pull_Zmod.
          break_match; Z.ltb_to_lt; autorewrite with zsimplify; try reflexivity; lia.
        Qed.

        (** [reduce_via_partial_alt] computes the same value as
            [reduce_via_partial] under the bounds of this section.  After
            unfolding both, the two sides differ only in [partial_reduce_alt]
            versus [partial_reduce], so rewriting with [partial_reduce_alt_eq]
            closes the goal.  Inside this section [partial_reduce_alt_eq] has
            no premises (its range facts are section hypotheses), so the
            rewrite creates no side goals.  This avoids the former
            [by omega], which did nothing; [omega] is also removed in recent
            Coq. *)
        Lemma reduce_via_partial_alt_eq : reduce_via_partial_alt N R N' T = reduce_via_partial N R N' T.
        Proof using N_in_range N_reasonable T_representable m_small.
          cbv [reduce_via_partial_alt reduce_via_partial].
          rewrite partial_reduce_alt_eq; reflexivity.
        Qed.
      End alt.
    End redc.

    (** * Arithmetic in Montgomery form *)
    (** Montgomery multiplication [mul N R N'] and the REDC-based
        conversions [to_montgomery]/[from_montgomery], proved correct
        modulo [N]. *)
    Section arithmetic.
      (** Montgomery multiplication, written infix in [montgomery_scope]. *)
      Local Infix "*" := (mul N R N') : montgomery_scope.

      Local Notation to_montgomery := (to_montgomery N R N').
      Local Notation from_montgomery := (from_montgomery N R N').
      (** Leaving and re-entering Montgomery form is the identity modulo
          [N].  The proof views [a] as [a * 1 * 1], replaces each [1] by
          [R * R'] (via [R'_good]), and applies [reduce_correct]. *)
      Lemma to_from_montgomery a : to_montgomery (from_montgomery a) ≡ a.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold to_montgomery, from_montgomery.
        transitivity ((a * 1) * 1); [ | apply f_equal2; lia ].
        rewrite <- !R'_good, !reduce_correct.
        unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
        apply f_equal2; lia.
      Qed.
      (** Entering and then leaving Montgomery form is the identity modulo
          [N]; after [reduce_correct], the leftover [R]/[R'] factors are
          cancelled by repeatedly rewriting with [R'_good]. *)
      Lemma from_to_montgomery a : from_montgomery (to_montgomery a) ≡ a.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold to_montgomery, from_montgomery.
        rewrite !reduce_correct.
        transitivity (a * ((R * (R * R' mod N) * R') mod N)).
        { unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
          apply f_equal2; lia. }
        { repeat first [ rewrite R'_good
                       | reflexivity
                       | push_Zmod; pull_Zmod; progress autorewrite with zsimplify
                       | progress unfold Z.equiv_modulo ]. }
      Qed.

      (** Leaving Montgomery form turns [mul] into ordinary multiplication,
          modulo [N]. *)
      Theorem mul_correct x y : from_montgomery (x * y) ≡ from_montgomery x * from_montgomery y.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold from_montgomery, mul.
        rewrite !reduce_correct; apply f_equal2; lia.
      Qed.
      (** Entering Montgomery form commutes with [mul], modulo [N].  Both
          sides are compared against [x * y * R * 1 * 1 * 1], with the [1]s
          expanded to [R * R'] as needed. *)
      Theorem mul_correct_to x y : to_montgomery (x * y) ≡ (to_montgomery x * to_montgomery y)%montgomery.
      Proof using N'_good N'_in_range N_reasonable R'_good.
        unfold to_montgomery, mul.
        rewrite !reduce_correct.
        transitivity (x * y * R * 1 * 1 * 1);
          [ rewrite <- R'_good at 1
          | rewrite <- R'_good at 1 2 3 ];
          autorewrite with zsimplify;
          unfold Z.equiv_modulo; push_Zmod; pull_Zmod.
        { apply f_equal2; lia. }
        { apply f_equal2; lia. }
      Qed.
    End arithmetic.
  End redc.
End montgomery.

Module Import LocalizeEquivModuloInstances := Z.RemoveEquivModuloInstances Nop.