(* path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
   blob: 981680b4a93d936e7a48dd3bbd11cb89a36b4ef7 *)
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseRep.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.BaseSystem Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Coq.Lists.List.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.QArith.QArith Coq.QArith.Qround.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z.

(* Computed versions of some functions.

   Each [_opt] constant below is defined as the fully normalized
   ([Eval compute]) body of the corresponding library definition.  The
   proofs further down use [change] to swap the original constants for
   these versions, so that the witnesses extracted from the [_sig]
   definitions refer to these precomputed constants instead of the
   original library functions. *)

Definition Z_add_opt := Eval compute in Z.add.
Definition Z_sub_opt := Eval compute in Z.sub.
Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by.

Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
Definition map_opt {A B} := Eval compute in @map A B.
Definition base_from_limb_widths_opt := Eval compute in base_from_limb_widths.

(* [Let_In x f] applies [f] to [x] through an explicit [let] binding.
   Because it is a named constant, a proof can [change] a term into
   [Let_In e F] (as [carry_opt_cps_sig] does, with [F] obtained by
   abstracting [e] via [pattern]) so that the repeated subterm [e] shows
   up let-bound in the extracted definition. *)
Definition Let_In {A P} (x : A) (f : forall y : A, P y)
  := let y := x in f y.

(* Some automation that comes in handy when constructing base parameters
   and optimized definitions. *)

(* [opt_step]: for a goal [_ = match e with nil => _ | _ => _ end], force
   the left-hand side to be a [match] on the same list [e] (with fresh
   holes for its branches), then [destruct e].  This leaves one goal per
   branch (nil case, cons case) that can be optimized separately. *)
Ltac opt_step :=
  match goal with
  | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
    => refine (_ : match e with nil => _ | _ => _ end = _);
       destruct e
  end.

(* [brute_force_indices lw]: discharge a goal that quantifies over indices
   into the concrete limb-width list [lw].  It introduces everything,
   unfolds [sum_firstn] and [lw], and then repeatedly:
   - simplifies,
   - tries [reflexivity],
   - turns a hypothesis [S a < S b] into [a < b],
   - case-splits an index variable [x] that occurs as [x + _ < _] or
     [x < _] in a hypothesis,
   - falls back to [omega].
   Used by [construct_params] for the index-wise side conditions. *)
Ltac brute_force_indices limb_widths := intros; unfold sum_firstn, limb_widths; simpl in *;
  repeat match goal with
  | _ => progress simpl in *
  | _ => reflexivity
  | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H
  | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x
  | [H : (?x < _)%nat |- _ ] => is_var x; destruct x
  | _ => omega
  end.


(* [limb_widths_from_len len k]: split [k] bits into [len] limb widths.
   For j = 1 .. len (here j = len - i + 1 while the loop counter i counts
   down from len), let x_j = ceil(k * j / len), computed as the floor
   quotient plus 1 whenever the division is not exact.  The result is the
   list of differences x_j - x_(j-1) with x_0 = 0.  The widths therefore
   telescope to x_len, which equals k when len > 0.  For len = 0 the
   result is [nil], and no division is performed.  NOTE(review): the
   ceiling reading assumes k >= 0; confirm callers never pass negative k.
   The body is normalized once with [Eval compute]. *)
Definition limb_widths_from_len len k := Eval compute in
  (fix loop i prev :=
    match i with
    | O => nil
    | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
                        then (k * Z.of_nat (len - i + 1)) / Z.of_nat len
                        else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in
      x - prev:: (loop i' x)
    end) len 0.

(* [construct_params prime_modulus len k]: build a
   [PseudoMersenneBaseParams] instance whose limb widths are
   [limb_widths_from_len len k], normalized with [cbv].  The side goals
   of [Build_PseudoMersenneBaseParams] are closed, in order, by:
   1. a per-limb bound checked with [omega] on every element
      (presumably positivity of each width -- TODO confirm field),
   2. [cbv; congruence] (presumably that the width list is non-empty),
   3. [brute_force_indices] on the concrete widths,
   4. the supplied primality proof [prime_modulus],
   5. [brute_force_indices] again.
   Each is wrapped in [abstract] so its proof term is kept as a separate
   lemma. *)
Ltac construct_params prime_modulus len k :=
  let lw := fresh "lw" in set (lw := limb_widths_from_len len k);
  cbv in lw;
  eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
  [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
  | abstract (unfold limb_widths; cbv; congruence)
  | abstract brute_force_indices lw
  | abstract apply prime_modulus
  | abstract brute_force_indices lw].

(* From the limb widths w0, w1, w2, ... of [prm], build the digit list
   [2^(w0+1) - 2 * c; 2^(w1+1) - 2; 2^(w2+1) - 2; ...], or the empty list
   when there are no limbs.  NOTE(review): going by its name, this is
   presumably a limb-wise encoding of 2 * modulus intended as a
   subtraction coefficient -- confirm against its users. *)
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
  let pow2_succ (w : Z) := 2 ^ (w + 1) in
  match limb_widths with
  | [] => []
  | w0 :: ws => (pow2_succ w0 - 2 * c) :: map (fun w => pow2_succ w - 2) ws
  end.

(* [subst_precondition]: if the context has [H : P] and [H' : P -> _],
   specialize [H'] with [H] and drop [H]. *)
Ltac subst_precondition := match goal with
  | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
end.

(* [kill_precondition H]: discharge the first premise of hypothesis [H].
   [forward H] (from VerdiTactics) presumably asks for a proof of that
   premise as a new goal and adds it to the context -- confirm.  That goal
   is closed by [eq_refl] or by clearing the context, computing, splitting
   disjunctions and [intuition], all under [abstract].  The proven premise
   is then consumed by [subst_precondition]. *)
Ltac kill_precondition H := 
  forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
  subst_precondition.

(* [compute_formula]: close a goal [rep _ result] from a hypothesis about
   the same [result].
   - If the hypothesis still has one or two premises, discharge them with
     [kill_precondition] and retry.
   - Once it has exactly the goal's shape, hide the modulus [M], the
     parameters [P] and [result] behind [set] variables so that [cbv]
     leaves them alone.  Then fully evaluate the rest of the hypothesis,
     tidy it with [Z.mul_1_r], [Z.add_0_l] and the associativity rewrites,
     and close the goal with it. *)
Ltac compute_formula :=
  match goal with
  | [H : _ -> _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula
  | [H : _ -> PseudoMersenneBaseRep.rep _ ?result |- PseudoMersenneBaseRep.rep _ ?result] => kill_precondition H; compute_formula
  | [H : @PseudoMersenneBaseRep.rep ?M ?P _ ?result |- @PseudoMersenneBaseRep.rep ?M ?P _ ?result] =>
    let m := fresh "m" in set (m := M) in H at 1; change M with m at 1;
    let p := fresh "p" in set (p := P) in H at 1; change P with p at 1;
    let r := fresh "r" in set (r := result) in H |- *;
    cbv -[m p r PseudoMersenneBaseRep.rep] in H;
    repeat rewrite ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_assoc, ?Z.mul_assoc in H;
    exact H
  end.

(* Optimized carry operations.  Each [_sig] definition is a proof whose
   witness is an optimized term; the witness is then extracted with
   [proj1_sig] and the equation with [proj2_sig]. *)
Section Carries.
  Context `{prm : PseudoMersenneBaseParams}
    (* allows caller to precompute k and c; only [c_]/[c_subst] are
       actually referenced in this section *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).

  (* A digit list equal to [carry i b] whenever [i] indexes a limb.  The
     proof unfolds [carry] and its helpers, swaps in the precomputed
     [_opt] functions, and replaces [c] by [c_]. *)
  Definition carry_opt_sig
             (i : nat) (b : digits)
    : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
  Proof.
    eexists ; intros.
    cbv [carry].
    (* after this rewrite and the unfolding below, the right-hand side is
       expected to be an [if], as the [lazymatch] further down requires *)
    rewrite <- pull_app_if_sumbool.
    cbv beta delta
      [carry carry_and_reduce carry_simple add_to_nth log_cap
       pow2_mod Z.ones Z.pred base
       PseudoMersenneBaseParams.limb_widths].
    change @nth_default with @nth_default_opt in *.
    change @set_nth with @set_nth_opt in *.
    (* Go through an [if] on the same condition whose branches are fresh
       evars [x] and [y]. *)
    lazymatch goal with
    | [ |- _ = (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
    end.
    (* Goal 2 ([if] with evar branches = original [if]) is solved first:
       after zeta-reduction, case analysis and [reflexivity] instantiate
       [x] and [y] with the let-free branches. *)
    2:cbv zeta.
    2:break_if; reflexivity.

    (* Remaining goal: swap in precomputed helpers, replace [c] by [c_],
       rewrite the index test with [beq_nat_eq_nat_dec], and take the
       result as the witness. *)
    change @nth_default with @nth_default_opt.
    rewrite c_subst.
    change @set_nth with @set_nth_opt.
    change @map with @map_opt.
    rewrite <- @beq_nat_eq_nat_dec.
    change base_from_limb_widths with base_from_limb_widths_opt.
    reflexivity.
  Defined.

  (* Optimized single carry at index [i]: the witness of [carry_opt_sig]. *)
  Definition carry_opt i b
    := Eval cbv beta iota delta [proj1_sig carry_opt_sig] in proj1_sig (carry_opt_sig i b).

  (* [carry_opt] agrees with [carry] for in-range indices. *)
  Definition carry_opt_correct i b : (i < length limb_widths)%nat -> carry_opt i b = carry i b := proj2_sig (carry_opt_sig i b).

  (* [carry_sequence] with every [carry] replaced by [carry_opt]; valid
     when every index in [is] is below [length base]. *)
  Definition carry_sequence_opt_sig (is : list nat) (us : digits)
    : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
  Proof.
    eexists. intros H.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt us is).
    Focus 2.
    { induction is; [ reflexivity | ].
      simpl; rewrite IHis, carry_opt_correct.
      - reflexivity.
      (* turn the [length base] bound into the form [carry_opt_correct]
         needs *)
      - rewrite base_length in H.
        apply H; apply in_eq.
      - intros. apply H. right. auto.
      }
    Unfocus.
    reflexivity.
  Defined.

  (* Optimized carry sequence: the witness of [carry_sequence_opt_sig]. *)
  Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
                                          proj1_sig (carry_sequence_opt_sig is us).

  (* [carry_sequence_opt] agrees with [carry_sequence] for in-range
     indices. *)
  Definition carry_sequence_opt_correct is us
    : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_sig is us).

  (* Continuation-passing carry: a term equal to [f (carry i b)] for an
     in-range [i].  The digit [nth_default_opt 0 b i] is bound once with
     [Let_In], and [Z.shiftl], [-1] and selected [Z.add] occurrences are
     swapped for their [_opt] versions. *)
  Definition carry_opt_cps_sig
             {T}
             (i : nat)
             (f : digits -> T)
             (b : digits)
    : { d : T |  (i < length base)%nat -> d = f (carry i b) }.
  Proof.
    eexists. intros H.
    rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
    cbv beta iota delta [carry_opt].
    (* abstract every occurrence of the [i]-th digit behind [Let_In] *)
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (nth_default_opt 0%Z b i) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (nth_default_opt 0%Z b i) RHSf).
    change Z.shiftl with Z_shiftl_opt.
    change (-1) with (Z_opp_opt 1).
    (* NOTE(review): these occurrence numbers depend on the exact term
       produced by the steps above and break if that term changes *)
    change Z.add with Z_add_opt at 8 12 20 24.
    reflexivity.
  Defined.

  (* Optimized CPS carry: the witness of [carry_opt_cps_sig]. *)
  Definition carry_opt_cps {T} i f b
    := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).

  (* [carry_opt_cps i f b] equals [f (carry i b)] for an in-range [i]. *)
  Definition carry_opt_cps_correct {T} i f b :
    (i < length base)%nat ->
    @carry_opt_cps T i f b = f (carry i b)
    := proj2_sig (carry_opt_cps_sig i f b).

  (* [carry_sequence] as a chain of [carry_opt_cps] continuations, built by
     folding over [rev is] with [id] as the final continuation. *)
  Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
    : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
  Proof.
    eexists.
    cbv [carry_sequence].
    (* NOTE(review): there is no explicit [intros] here; the precondition
       [H] that is cleared by name below is presumably introduced by
       [transitivity].  An explicit [intros H] would be less fragile. *)
    transitivity (fold_right carry_opt_cps id (List.rev is) us).
    Focus 2.
    { 
      assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
        subst. intros. rewrite <- in_rev in *. auto. }
      remember (rev is) as ris eqn:Heq.
      rewrite <- (rev_involutive is), <- Heq.
      clear H Heq is.
      rewrite fold_left_rev_right.
      revert us; induction ris; [ reflexivity | ]; intros.
      { simpl.
        rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
        rewrite carry_opt_cps_correct; [reflexivity|].
        apply Hr; left; reflexivity.
        } }
    Unfocus.
    reflexivity.
  Defined.

  (* Optimized CPS carry sequence: the witness of
     [carry_sequence_opt_cps_sig]. *)
  Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
                                              proj1_sig (carry_sequence_opt_cps_sig is us).

  (* [carry_sequence_opt_cps] agrees with [carry_sequence] for in-range
     indices. *)
  Definition carry_sequence_opt_cps_correct is us
    : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_cps_sig is us).


  (* Carrying with [carry_sequence_opt_cps] preserves the represented
     field element, given in-range indices and a full-length input. *)
  Lemma carry_sequence_opt_cps_rep
       : forall (is : list nat) (us : list Z) (x : F modulus),
         (forall i : nat, In i is -> i < length base)%nat ->
         length us = length base ->
         rep us x -> rep (carry_sequence_opt_cps is us) x.
  Proof.
    intros.
    rewrite carry_sequence_opt_cps_correct by assumption.
    apply carry_sequence_rep; assumption.
  Qed.

End Carries.

(* Optimized addition. *)
Section Addition.
  (* [sc] is not referenced in this section. *)
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.

  (* The witness is [add us vs] with [BaseSystem.add] unfolded by [cbv]. *)
  Definition add_opt_sig (us vs : T) : { b : digits | b = add us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add].
    reflexivity.
  Defined.

  (* Optimized addition: the witness of [add_opt_sig]. *)
  Definition add_opt (us vs : T) : digits
    := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs).

  (* [add_opt] is equal to [add]. *)
  Definition add_opt_correct us vs
    : add_opt us vs = add us vs
    := proj2_sig (add_opt_sig us vs).

  (* If [u] represents [x] and [v] represents [y], then [add_opt u v]
     represents [x + y]. *)
  Lemma add_opt_rep: forall (u v : T) (x y : F modulus),
    PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
    PseudoMersenneBaseRep.rep (add_opt u v) (x + y)%F.
  Proof.
    intros.
    rewrite add_opt_correct.
    auto using PseudoMersenneBaseRep.add_rep.
  Qed.

End Addition.

(* Optimized subtraction, using the coefficient supplied by the
   [SubtractionCoefficient] instance [sc]. *)
Section Subtraction.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.

  (* The witness is [sub coeff coeff_mod us vs] with [BaseSystem.add],
     [ModularBaseSystem.sub] and [BaseSystem.sub] unfolded by [cbv]. *)
  Definition sub_opt_sig (us vs : T) : { b : digits | b = sub coeff coeff_mod us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
    reflexivity.
  Defined.

  (* Optimized subtraction: the witness of [sub_opt_sig]. *)
  Definition sub_opt (us vs : T) : digits
    := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).

  (* [sub_opt] is equal to [sub coeff coeff_mod]. *)
  Definition sub_opt_correct us vs
    : sub_opt us vs = sub coeff coeff_mod us vs
    := proj2_sig (sub_opt_sig us vs).

  (* If [u] represents [x] and [v] represents [y], then [sub_opt u v]
     represents [x - y]. *)
  Lemma sub_opt_rep: forall (u v : T) (x y : F modulus),
    PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
    PseudoMersenneBaseRep.rep (sub_opt u v) (x - y)%F.
  Proof.
    intros.
    rewrite sub_opt_correct.
    change (sub coeff coeff_mod) with PseudoMersenneBaseRep.sub.
    apply PseudoMersenneBaseRep.sub_rep; auto using coeff_length.
  Qed.

End Subtraction.

(* Optimized multiplication, built from optimized unrollings of the
   [BaseSystem] helpers [mul_bi'] and [mul']. *)
Section Multiplication.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).

  (* One unrolling of [mul_bi'].  The recursive call is the parameter
     [mul_bi'], so the body can be optimized once (see
     [mul_bi'_opt_step_sig]) and the knot tied afterwards ([mul_bi'_opt]). *)
  Definition mul_bi'_step
             (mul_bi' : nat -> digits -> list Z -> list Z)
             (i : nat) (vsr : digits) (bs : list Z)
    : list Z
    := match vsr with
       | [] => []
       | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs
       end.

  (* A term equal to [mul_bi'_step mul_bi' i vsr bs] with [crosscoef],
     [ext_base] and [base] unfolded.  [Z.div], the second [Z.mul]
     occurrence and [nth_default] are replaced by their [_opt] versions. *)
  Definition mul_bi'_opt_step_sig
             (mul_bi' : nat -> digits -> list Z -> list Z)
             (i : nat) (vsr : digits) (bs : list Z)
    : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }.
  Proof.
    eexists.
    cbv [mul_bi'_step].
    (* one goal for the nil case of [vsr], one for the cons case *)
    opt_step.
    { reflexivity. }
    { cbv [crosscoef ext_base base].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Defined.

  (* The witness of [mul_bi'_opt_step_sig]. *)
  Definition mul_bi'_opt_step
             (mul_bi' : nat -> digits -> list Z -> list Z)
             (i : nat) (vsr : digits) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in
        proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs).

  (* Optimized [mul_bi']: [mul_bi'_opt_step] with its recursive call tied
     back to itself; structurally recursive on [vsr]. *)
  Fixpoint mul_bi'_opt
           (i : nat) (vsr : digits) (bs : list Z) {struct vsr}
    : list Z
    := mul_bi'_opt_step mul_bi'_opt i vsr bs.

  (* [mul_bi'_opt] agrees with [mul_bi'] (which takes [bs] first). *)
  Definition mul_bi'_opt_correct
             (i : nat) (vsr : digits) (bs : list Z)
    : mul_bi'_opt i vsr bs = mul_bi' bs i vsr.
  Proof.
    revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
    { reflexivity. }
    { simpl mul_bi'.
      rewrite <- IHvsr; clear IHvsr.
      unfold mul_bi'_opt, mul_bi'_opt_step.
      apply f_equal2; [ | reflexivity ].
      cbv [crosscoef ext_base base].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Qed.

  (* One unrolling of [mul'].  For a head digit [u], scale
     [mul_bi bs (length usr') vs] by [u] with [mul_each], then [add] the
     result of the recursive call [mul'] on the tail. *)
  Definition mul'_step
             (mul' : digits -> digits -> list Z -> digits)
             (usr vs : digits) (bs : list Z)
    : digits
    := match usr with
       | [] => []
       | u :: usr' => add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs)
       end.

  (* Scaling a list that starts with [zeros n] leaves the zero prefix
     unchanged.  This is proved rather than assumed, because
     [mul'_opt_step_sig] and [mul'_opt_correct] rewrite with it. *)
  Lemma map_zeros : forall a n l,
    map (Z.mul a) (zeros n ++ l) = zeros n ++ map (Z.mul a) l.
  Proof.
    intros a n l.
    rewrite List.map_app; f_equal.
    induction n as [|n IHn]; [ reflexivity | ].
    (* [zeros] is transparent (it is unfolded by [cbv] below), and
       [zeros (S n)] reduces to [0 :: zeros n] *)
    change (zeros (S n)) with (0 :: zeros n).
    simpl map.
    rewrite IHn, Z.mul_0_r; reflexivity.
  Qed.

  (* A term equal to [mul'_step mul' usr vs bs].  In the cons case,
     [mul_bi'] is replaced by [mul_bi'_opt], the zero prefix is moved out
     of the scaling with [map_zeros], and [map] and [zeros] are unfolded
     to their computed forms. *)
  Definition mul'_opt_step_sig
             (mul' : digits -> digits -> list Z -> digits)
             (usr vs : digits) (bs : list Z)
    : { d : digits | d = mul'_step mul' usr vs bs }.
  Proof.
    eexists.
    cbv [mul'_step].
    (* one goal for the nil case of [usr], one for the cons case *)
    opt_step.
    { reflexivity. }
    { cbv [mul_each mul_bi].
      rewrite <- mul_bi'_opt_correct.
      rewrite map_zeros.
      change @map with @map_opt.
      cbv [zeros].
      reflexivity. }
  Defined.

  (* The witness of [mul'_opt_step_sig]. *)
  Definition mul'_opt_step
             (mul' : digits -> digits -> list Z -> digits)
             (usr vs : digits) (bs : list Z)
    : digits
    := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs).

  (* Optimized [mul']: [mul'_opt_step] with its recursive call tied back
     to itself. *)
  Fixpoint mul'_opt
           (usr vs : digits) (bs : list Z)
    : digits
    := mul'_opt_step mul'_opt usr vs bs.

  (* [mul'_opt] agrees with [mul'] (which takes [bs] first). *)
  Definition mul'_opt_correct
           (usr vs : digits) (bs : list Z)
    : mul'_opt usr vs bs = mul' bs usr vs.
  Proof.
    revert vs; induction usr as [|usr usrs IHusr]; intros.
    { reflexivity. }
    { simpl.
      rewrite <- IHusr; clear IHusr.
      apply f_equal2; [ | reflexivity ].
      cbv [mul_each mul_bi].
      rewrite map_zeros.
      rewrite <- mul_bi'_opt_correct.
      reflexivity. }
  Qed.

  (* Optimized multiplication.  The proof unfolds [mul] and its helpers,
     uses [mul'_opt] for [mul'], rewrites the shifts with [map_shiftl]
     (which needs [k_nonneg]), substitutes [c_] and [k_] for [c] and [k],
     and swaps in the precomputed [_opt] functions. *)
  Definition mul_opt_sig (us vs : T) : { b : digits | b = mul us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce].
    rewrite <- mul'_opt_correct.
    cbv [base PseudoMersenneBaseParams.limb_widths].
    rewrite map_shiftl by apply k_nonneg.
    rewrite c_subst.
    rewrite k_subst.
    change @map with @map_opt.
    change base_from_limb_widths with base_from_limb_widths_opt.
    change @Z_shiftl_by with @Z_shiftl_by_opt.
    reflexivity.
  Defined.

  (* Optimized multiplication: the witness of [mul_opt_sig]. *)
  Definition mul_opt (us vs : T) : digits
    := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).

  (* [mul_opt] is equal to [mul]. *)
  Definition mul_opt_correct us vs
    : mul_opt us vs = mul us vs
    := proj2_sig (mul_opt_sig us vs).

  (* If [u] represents [x] and [v] represents [y], then [mul_opt u v]
     represents [x * y]. *)
  Lemma mul_opt_rep:
    forall (u v : T) (x y : F modulus), PseudoMersenneBaseRep.rep u x -> PseudoMersenneBaseRep.rep v y ->
    PseudoMersenneBaseRep.rep (mul_opt u v) (x * y)%F.
  Proof.
    intros.
    rewrite mul_opt_correct.
    change mul with PseudoMersenneBaseRep.mul.
    auto using PseudoMersenneBaseRep.mul_rep.
  Qed.

  (* Multiply with [mul_opt], then carry at the indices [is] with
     [carry_sequence_opt_cps] specialized to [c_]. *)
  Definition carry_mul_opt
             (is : list nat)
             (us vs : list Z)
             : list Z
    := carry_sequence_opt_cps c_ is (mul_opt us vs).

  (* [carry_mul_opt] represents the product.  The length precondition on
     [mul_opt us vs] is what [carry_sequence_opt_cps_rep] requires. *)
  Lemma carry_mul_opt_correct
    : forall (is : list nat) (us vs : list Z) (x  y: F modulus),
      PseudoMersenneBaseRep.rep us x -> PseudoMersenneBaseRep.rep vs y ->
      (forall i : nat, In i is -> i < length base)%nat ->
      length (mul_opt us vs) = length base ->
      PseudoMersenneBaseRep.rep (carry_mul_opt is us vs) (x*y)%F.
  Proof.
    intros is us vs x y; intros.
    change (carry_mul_opt _ _ _) with (carry_sequence_opt_cps c_ is (mul_opt us vs)).
    apply carry_sequence_opt_cps_rep, mul_opt_rep; auto.
  Qed.
End Multiplication.