aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
blob: 436d309c7be563dbc22e66cb3d686934a78f85da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.QArith.QArith Coq.QArith.Qround.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z.

(* A limb vector [coeff], with one entry per element of [limb_widths], that
   decodes to zero in the field modulo [m]; i.e. a limb representation of a
   multiple of the modulus.  It is the extra argument passed to [sub] (see
   [sub_opt_sig] in the Subtraction section).  Presumably subtraction adds it
   so that limbs do not go negative -- TODO confirm against
   ModularBaseSystem.sub. *)
Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := {
  (* The limb vector itself. *)
  coeff : tuple Z (length limb_widths);
  (* It represents zero: decoding it yields [0%F]. *)
  coeff_mod: decode coeff = 0%F
}.

(* Computed versions of some functions.

   Each [*_opt] constant below is defined with [Eval compute], so its body is
   the fully reduced form of the corresponding library function.  Because it
   is convertible with the original, the [*_sig] proofs later in this file can
   swap one for the other with [change] (for example
   [change @nth_default with @nth_default_opt]) without altering meaning. *)

(* Integer operations on [Z]. *)
Definition Z_add_opt := Eval compute in Z.add.
Definition Z_sub_opt := Eval compute in Z.sub.
Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
Definition Z_ones_opt := Eval compute in Z.ones.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.

(* List, tuple and base-construction helpers. *)
Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
Definition update_nth_opt {A} := Eval compute in @update_nth A.
Definition map_opt {A B} := Eval compute in @map A B.
Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain.
Definition length_opt := Eval compute in length.
Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths.
Definition minus_opt := Eval compute in minus.
Definition max_ones_opt := Eval compute in @max_ones.
Definition from_list_default_opt {A} := Eval compute in (@from_list_default A).

(* [Let_In x f] reduces to [f x], but with [x] bound by an explicit [let].
   Later proofs in this file use [change (... = Let_In a RHSf)] to name a
   subterm of a generated definition this way.  Since [Let_In x f] is
   convertible with [f x], the change does not affect meaning.  Presumably the
   point is to keep the bound subterm shared when the optimized term is
   inspected or extracted -- TODO confirm. *)
Definition Let_In {A P} (x : A) (f : forall y : A, P y)
  := let y := x in f y.

(* Some automation that comes in handy when constructing base parameters *)

(* [opt_step] handles a goal of the form [?lhs = match e with nil => _ | _ => _ end]
   whose left-hand side is still an evar, as it is right after [eexists].
   It refines [?lhs] into a [match] on the same list [e], with a fresh evar
   in each branch, and then destructs [e].  This leaves one goal per list
   constructor, so each branch can be optimized on its own while the
   resulting term keeps the original [match] shape. *)
Ltac opt_step :=
  match goal with
  | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
    => refine (_ : match e with nil => _ | _ => _ end = _);
       destruct e
  end.

(* [brute_force_indices limb_widths] discharges side goals about limb indices
   for a concrete list of limb widths.  It unfolds [sum_firstn] and the given
   [limb_widths] constant and evaluates [length].  It then repeatedly:
   - simplifies;
   - strips matching successors from [S _ < S _] hypotheses;
   - case-splits any variable index [x] that appears in a hypothesis of the
     form [x + _ < _] or [x < _], so that bounded indices become numerals;
   - tries [omega].
   Used by [construct_params]. *)
Ltac brute_force_indices limb_widths :=
  intros; unfold sum_firstn, limb_widths;  cbv [length limb_widths] in *;
  repeat match goal with
  | _ => progress simpl in *
  | [H : (0 + _ < _)%nat |- _ ] => simpl in H
  | [H : (S _ + _ < S _)%nat |- _ ] => simpl in H
  | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H
  | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x
  | [H : (?x < _)%nat |- _ ] => is_var x; destruct x
  | _ => omega
  end.


(* [limb_widths_from_len len k] computes [len] limb widths that sum to [k]
   (when [len > 0]).

   The loop counts [i] down from [len] and sets [j = len - i + 1], so [j]
   runs over 1 .. len.  At each step it computes the cumulative width
   [x = ceil (k * j / len)]; the [if] implements the ceiling by adding 1
   when the division is inexact.  It then emits [x - prev], the difference
   from the previous cumulative width.

   These differences telescope to [ceil (k * len / len) = k].  For
   non-negative [k], each width is [k / len] rounded either down or up.
   [Eval compute] normalizes the definition when it is declared. *)
Definition limb_widths_from_len len k := Eval compute in
  (fix loop i prev :=
    match i with
    | O => nil
    | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
                        then (k * Z.of_nat (len - i + 1)) / Z.of_nat len
                        else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in
      x - prev:: (loop i' x)
    end) len 0.

(* [construct_params prime_modulus len k] builds a [PseudoMersenneBaseParams]
   value whose limb widths are [limb_widths_from_len len k], evaluated to a
   concrete list [lw].  The obligations left by [Build_PseudoMersenneBaseParams]
   are discharged, in order, by the six tactics in the bracket below, each
   inside [abstract].

   - The fourth obligation is closed by the supplied [prime_modulus]
     (presumably a primality proof for the modulus -- TODO confirm against
     PseudoMersenneBaseParams).
   - The others are closed by computation, by [omega] over every width, or
     by [brute_force_indices]. *)
Ltac construct_params prime_modulus len k :=
  let lw := fresh "lw" in set (lw := limb_widths_from_len len k);
  cbv in lw;
  eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
  [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
  | abstract (cbv; congruence)
  | abstract brute_force_indices lw
  | abstract apply prime_modulus
  | abstract (cbv; congruence)
  | abstract brute_force_indices lw].

(* [construct_mul2modulus prm] returns one limb per entry of [limb_widths]:
   - the first limb, of width [w0], is [2 ^ (w0 + 1) - 2 * c];
   - every later limb, of width [w], is [2 ^ (w + 1) - 2].
   An empty width list yields the empty list.

   NOTE(review): assume the positional base is 1, 2^w0, 2^(w0+w1), ..., the
   widths sum to [k], and the modulus is [2 ^ k - c].  Then these limbs
   telescope to [2 * (2 ^ k - c)], i.e. twice the modulus -- hence the name.
   Confirm against Pow2Base.base_from_limb_widths. *)
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
  match limb_widths with
  | [] => []
  | w0 :: rest_widths =>
      (2 ^ (w0 + 1) - 2 * c)
        :: map (fun w => 2 ^ (w + 1) - 2) rest_widths
  end.

(* [compute_preconditions] closes goals that become decidable once fully
   evaluated.  It runs [cbv] on the goal and introduces hypotheses.  It then
   case-splits every disjunctive hypothesis (substituting equations with
   [subst]); the left alternative of each split must be refuted by
   [congruence].  The remaining goal must be solved by [congruence] or
   [omega]. *)
Ltac compute_preconditions :=
  cbv; intros; repeat match goal with H : _ \/ _ |- _  =>
    destruct H; subst; [ congruence | ] end; (congruence || omega).

(* [subst_precondition] finds hypotheses [H : P] and [H' : P -> _].  It
   specializes [H'] with [H] in place and then clears [H]. *)
Ltac subst_precondition := match goal with
  | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
end.

(* [kill_precondition H] proves the premise of hypothesis [H] and uses it to
   specialize [H].

   [forward H] (from VerdiTactics) is expected to leave the premise as a
   separate goal.  That goal is solved inside [abstract], either directly by
   [exact eq_refl] or, failing that, by clearing the context, evaluating with
   [cbv], splitting disjunctive hypotheses ([break_or_hyp]), and running
   [intuition].  [subst_precondition] then feeds the proved premise to [H]. *)
Ltac kill_precondition H :=
  forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
  subst_precondition.

(* Optimized carry operations.

   Each [X_opt] (or [X_opt_cps]) below is the first projection of a sigma
   term [X_opt_sig].  The proof of [X_opt_sig] rewrites the reference
   definition into the form that is stored in [X_opt].  [X_opt_correct] is
   the matching second projection: it states equality with the reference
   definition, under the listed side conditions.

   The [_cps] variants take a continuation [f] and return [f] applied to the
   result. *)
Section Carries.
  Context `{prm : PseudoMersenneBaseParams}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* [carry_gen_opt fc fi i us] agrees with [carry_gen limb_widths fc fi i us]
     whenever [fi (S (fi i))] is a valid index into [us].  The proof:
     - unfolds [carry_gen], [carry_single] and [Z.pow2_mod];
     - rewrites [add_to_nth] into a [set_nth];
     - swaps in the computed [nth_default_opt], [set_nth_opt] and [Z_ones_opt];
     - rewrites with [beq_nat_eq_nat_dec], presumably replacing an
       [eq_nat_dec] test by [beq_nat]. *)
  Definition carry_gen_opt_sig fc fi i us
    : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat ->
                     d = carry_gen limb_widths fc fi i us}.
  Proof.
    eexists; intros.
    cbv beta iota delta [carry_gen carry_single Z.pow2_mod].
    rewrite add_to_nth_set_nth.
    change @nth_default with @nth_default_opt in *.
    change @set_nth with @set_nth_opt in *.
    change Z.ones with Z_ones_opt.
    rewrite set_nth_nth_default by assumption.
    rewrite <- @beq_nat_eq_nat_dec.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_gen_opt_sig]. *)
  Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in
                                                   proj1_sig (carry_gen_opt_sig fc fi i us).

  (* Correctness of [carry_gen_opt] under the index precondition. *)
  Definition carry_gen_opt_correct fc fi i us
    : (0 <= fi (S (fi i)) < length us)%nat ->
      carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us
    := proj2_sig (carry_gen_opt_sig fc fi i us).

  (* [carry_opt i b] agrees with [carry i b] when [b] has one entry per limb
     and [i] is a valid limb index.  The proof:
     - unfolds [carry], [carry_and_reduce] and [carry_simple];
     - exposes the resulting [if] with [pull_app_if_sumbool];
     - proves each branch equal to a [carry_gen_opt] call (the Focus 2 part;
       the generated indices are shown in range with [Nat.mod_bound_pos]);
     - substitutes [c_] for [c] and rewrites the [if] test with
       [beq_nat_eq_nat_dec]. *)
  Definition carry_opt_sig
             (i : nat) (b : list Z)
    : { d : list Z | (length b = length limb_widths)
                     -> (i < length limb_widths)%nat
                     -> d = carry i b }.
  Proof.
    eexists ; intros.
    cbv [carry].
    rewrite <-pull_app_if_sumbool.
    cbv beta delta
      [carry carry_and_reduce carry_simple].
    lazymatch goal with
    | [ |- _ = (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y
    end.
    Focus 2. {
      cbv zeta.
      break_if; rewrite <-carry_gen_opt_correct by (omega ||
          (replace (length b) with (length limb_widths) by congruence;
           apply Nat.mod_bound_pos; omega)); reflexivity.
    } Unfocus.
    rewrite c_subst.
    rewrite <- @beq_nat_eq_nat_dec.
    cbv [carry_gen_opt].
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_opt_sig]. *)
  Definition carry_opt is us := Eval cbv [proj1_sig carry_opt_sig] in
                                          proj1_sig (carry_opt_sig is us).

  (* Correctness of [carry_opt] for a full-length input and a valid index. *)
  Definition carry_opt_correct i us
    : length us = length limb_widths
      -> (i < length limb_widths)%nat
      -> carry_opt i us = carry i us
    := proj2_sig (carry_opt_sig i us).

  (* [carry_sequence_opt is us] agrees with [carry_sequence is us] when [us]
     has one entry per limb and every index in [is] is a valid limb index.
     The proof shows, by induction on [is], that [carry_sequence] equals
     [fold_right carry_opt us is]. *)
  Definition carry_sequence_opt_sig (is : list nat) (us : list Z)
    : { b : list Z | (length us = length limb_widths)
                     -> (forall i, In i is -> i < length limb_widths)%nat
                     -> b = carry_sequence is us }.
  Proof.
    eexists. intros H.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt us is).
    Focus 2.
    { induction is; [ reflexivity | ].
      simpl; rewrite IHis, carry_opt_correct.
      - reflexivity.
      - fold (carry_sequence is us). auto using length_carry_sequence.
      - auto using in_eq.
      - intros. auto using in_cons.
      }
    Unfocus.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_sequence_opt_sig]. *)
  Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
                                          proj1_sig (carry_sequence_opt_sig is us).

  (* Correctness of [carry_sequence_opt] under its two preconditions. *)
  Definition carry_sequence_opt_correct is us
    : (length us = length limb_widths)
      -> (forall i, In i is -> i < length limb_widths)%nat
      -> carry_sequence_opt is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_sig is us).

  (* CPS form of [carry_gen_opt]: the result is [f] applied to the carried
     list.  After rewriting into [carry_gen_opt], the proof finds a subterm of
     the form [a & Z_ones_opt _].  It abstracts every occurrence of [a] in the
     right-hand side and binds [a] once through [Let_In]. *)
  Definition carry_gen_opt_cps_sig
             {T} fc fi
             (i : nat)
             (f : list Z -> T)
             (b : list Z)
    : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }.
  Proof.
    eexists. intros H.
    rewrite <-carry_gen_opt_correct by assumption.
    cbv beta iota delta [carry_gen_opt].
    match goal with |- appcontext[?a & Z_ones_opt _] => 
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (a) RHSf) end.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_gen_opt_cps_sig]. *)
  Definition carry_gen_opt_cps {T} fc fi i f b
    := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in
                                 proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b).

  (* Correctness of [carry_gen_opt_cps] under the index precondition. *)
  Definition carry_gen_opt_cps_correct {T} fc fi i f b :
   (0 <= fi (S (fi i)) < length b)%nat ->
    @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b)
    := proj2_sig (carry_gen_opt_cps_sig fc fi i f b).

  (* CPS form of [carry_opt].  It has the same preconditions as
     [carry_opt_correct].  Each branch of the [if] exposed in [carry] is
     rewritten into a [carry_gen_opt_cps] call. *)
  Definition carry_opt_cps_sig
             {T}
             (i : nat)
             (f : list Z -> T)
             (b : list Z)
    : { d : T | (length b = length limb_widths)
                 -> (i < length limb_widths)%nat
                 -> d = f (carry i b) }.
  Proof.
    eexists. intros.
    cbv beta delta
      [carry carry_and_reduce carry_simple].
    rewrite <-pull_app_if_sumbool.
    lazymatch goal with
    | [ |- _ = ?f (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y
    end.
    Focus 2. {
      cbv zeta.
      break_if; rewrite <-carry_gen_opt_cps_correct by (omega ||
          (replace (length b) with (length limb_widths) by congruence;
           apply Nat.mod_bound_pos; omega)); reflexivity.
    } Unfocus.
    rewrite c_subst.
    rewrite <- @beq_nat_eq_nat_dec.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_opt_cps_sig]. *)
  Definition carry_opt_cps {T} i f b
    := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).

  (* Correctness of [carry_opt_cps] for a full-length input and a valid index. *)
  Definition carry_opt_cps_correct {T} i f b :
    (length b = length limb_widths)
    -> (i < length limb_widths)%nat
    -> @carry_opt_cps T i f b = f (carry i b)
    := proj2_sig (carry_opt_cps_sig i f b).

  (* CPS form of [carry_sequence_opt].  The term is
     [fold_right carry_opt_cps f (List.rev is) us]: the continuations are
     chained so that [f] receives [carry_sequence is us].  The proof goes
     through [fold_left_rev_right].

     NOTE(review): the hypotheses [H] and [H0] referenced in the proof are the
     two preconditions.  There is no explicit [intros]; they are presumably
     introduced by [transitivity], which performs [intros] first.  Renaming
     or reordering tactics here may break those names. *)
  Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z)
     (f : list Z -> T)
    : { b : T | (length us = length limb_widths)
                -> (forall i, In i is -> i < length limb_widths)%nat
                -> b = f (carry_sequence is us) }.
  Proof.
    eexists.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt_cps f (List.rev is) us).
    Focus 2.
    {
      assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. {
        subst. intros. rewrite <- in_rev in *. auto. }
      remember (rev is) as ris eqn:Heq.
      rewrite <- (rev_involutive is), <- Heq in H0 |- *.
      clear H0 Heq is.
      rewrite fold_left_rev_right.
      revert H. revert us; induction ris; [ reflexivity | ]; intros.
      { simpl.
        rewrite <- IHris; clear IHris;
          [|intros; apply Hr; right; assumption|auto using length_carry].
        rewrite carry_opt_cps_correct; [reflexivity|congruence|].
        apply Hr; left; reflexivity.
        } }
    Unfocus.
    cbv [carry_opt_cps].
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_sequence_opt_cps_sig]. *)
  Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) :=
    Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
      proj1_sig (carry_sequence_opt_cps_sig is us f).

  (* Correctness of [carry_sequence_opt_cps] under its two preconditions. *)
  Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T)
    : (length us = length limb_widths)
      -> (forall i, In i is -> i < length limb_widths)%nat
      -> carry_sequence_opt_cps is us f = f (carry_sequence is us)
    := proj2_sig (carry_sequence_opt_cps_sig is us f).

  (* Every index produced by [Pow2Base.full_carry_chain limb_widths] is a
     valid limb index. *)
  Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) ->
    (i < length limb_widths)%nat.
  Proof.
    unfold Pow2Base.full_carry_chain; intros.
    apply Pow2BaseProofs.make_chain_lt; auto.
  Qed.

  (* [carry_full_opt us] agrees with [carry_full us] on limb tuples.  The
     proof:
     - unfolds the tuple and list versions of [carry_full];
     - rewrites the conversion back to a tuple as [from_list_default] with
       default 0;
     - folds the carry sequence into [carry_sequence_opt_cps];
     - swaps in [full_carry_chain_opt]. *)
  Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
  Proof.
    eexists.
    cbv [carry_full ModularBaseSystemList.carry_full].
    rewrite <-from_list_default_eq with (d := 0).
    rewrite <-carry_sequence_opt_cps_correct by (rewrite ?length_to_list; auto; apply full_carry_chain_bounds).
    change @Pow2Base.full_carry_chain with full_carry_chain_opt.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_full_opt_sig]. *)
  Definition carry_full_opt (us : digits) : digits
    := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us).

  (* Correctness of [carry_full_opt]. *)
  Definition carry_full_opt_correct us : carry_full_opt us = carry_full us :=
    proj2_sig (carry_full_opt_sig us).

  (* CPS form of [carry_full_opt]: the result is [f (carry_full us)].  The
     continuation [f] is composed with the final list-to-tuple conversion,
     and the composite is passed to [carry_sequence_opt_cps]. *)
  Definition carry_full_opt_cps_sig
             {T}
             (f : digits -> T)
             (us : digits)
    : { d : T | d = f (carry_full us) }.
  Proof.
    eexists.
    rewrite <- carry_full_opt_correct.
    cbv beta iota delta [carry_full_opt].
    rewrite carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
    match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
      change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
    rewrite <-carry_sequence_opt_cps_correct by (apply length_to_list || apply full_carry_chain_bounds).
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_full_opt_cps_sig]. *)
  Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T
    := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us).

  (* Correctness of [carry_full_opt_cps]. *)
  Definition carry_full_opt_cps_correct {T} us (f : digits -> T) :
    carry_full_opt_cps f us = f (carry_full us) :=
    proj2_sig (carry_full_opt_cps_sig f us).

End Carries.

(* Optimized addition on limb tuples.  [add_opt us vs] agrees with
   [add us vs].  The proof only unfolds [BaseSystem.add] before closing with
   [reflexivity], so [add_opt] is whatever term that unfolding exposes. *)
Section Addition.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Sigma term carrying the optimized addition and its correctness proof. *)
  Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add].
    reflexivity.
  Defined.

  (* The optimized term, read off from [add_opt_sig]. *)
  Definition add_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs).

  (* Correctness of [add_opt]. *)
  Definition add_opt_correct us vs
    : add_opt us vs = add us vs
    := proj2_sig (add_opt_sig us vs).
End Addition.

(* Optimized subtraction on limb tuples.  [sub_opt us vs] agrees with
   [sub coeff us vs], where [coeff] comes from the [SubtractionCoefficient]
   instance [sc].  The proof unfolds [BaseSystem.add],
   [ModularBaseSystem.sub] and [BaseSystem.sub]. *)
Section Subtraction.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Sigma term carrying the optimized subtraction and its correctness proof. *)
  Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
    reflexivity.
  Defined.

  (* The optimized term, read off from [sub_opt_sig]. *)
  Definition sub_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).

  (* Correctness of [sub_opt]. *)
  Definition sub_opt_correct us vs
    : sub_opt us vs = sub coeff us vs
    := proj2_sig (sub_opt_sig us vs).
End Subtraction.

(* Optimized multiplication.

   [mul_bi'] and [mul'] are re-expressed as [Fixpoint]s ([mul_bi'_opt],
   [mul'_opt]) whose one-step bodies are generated by the [*_opt_step_sig]
   proofs.  These feed [mul_opt], and finally [carry_mul_opt], which carries
   the product via [carry_full_opt_cps]. *)
Section Multiplication.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Step function for [mul_bi'], with the recursive call abstracted as the
     argument [mul_bi'].  For [v :: vsr'] the head is
     [v * crosscoef bs i (length vsr')] and the tail is [mul_bi' i vsr' bs].
     Its fixpoint [mul_bi'_opt] is proved equal to [mul_bi' bs i] in
     [mul_bi'_opt_correct]. *)
  Definition mul_bi'_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := match vsr with
       | [] => []
       | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs
       end.

  (* Optimized form of [mul_bi'_step].  [opt_step] splits on [vsr].  In the
     cons case, [crosscoef] is unfolded, and [Z.div], the second [Z.mul] and
     [nth_default] are replaced by their computed versions. *)
  Definition mul_bi'_opt_step_sig
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }.
  Proof.
    eexists.
    cbv [mul_bi'_step].
    opt_step.
    { reflexivity. }
    { cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Defined.

  (* The optimized step, read off from [mul_bi'_opt_step_sig]. *)
  Definition mul_bi'_opt_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in
        proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs).

  (* Fixpoint of [mul_bi'_opt_step], by structural recursion on [vsr]. *)
  Fixpoint mul_bi'_opt
           (i : nat) (vsr : list Z) (bs : list Z) {struct vsr}
    : list Z
    := mul_bi'_opt_step mul_bi'_opt i vsr bs.

  (* [mul_bi'_opt] agrees with [mul_bi'].  Note the different argument order:
     [mul_bi' bs i vsr].  The proof is by induction on [vsr] and repeats the
     rewrites used in [mul_bi'_opt_step_sig]. *)
  Definition mul_bi'_opt_correct
             (i : nat) (vsr : list Z) (bs : list Z)
    : mul_bi'_opt i vsr bs = mul_bi' bs i vsr.
  Proof.
    revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
    { reflexivity. }
    { simpl mul_bi'.
      rewrite <- IHvsr; clear IHvsr.
      unfold mul_bi'_opt, mul_bi'_opt_step.
      apply f_equal2; [ | reflexivity ].
      cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Qed.

  (* Step function for [mul'], with the recursive call abstracted as the
     argument [mul'].  For [u :: usr'] it adds [u] times
     [mul_bi bs (length usr') vs] (via [mul_each]) to the recursive result
     [mul' usr' vs bs]. *)
  Definition mul'_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := match usr with
       | [] => []
       | u :: usr' => BaseSystem.add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs)
       end.

  (* Multiplying [zeros n ++ l] entrywise by [a] keeps the [n] leading zeros
     and scales only [l]. *)
  Lemma map_zeros : forall a n l,
    map (Z.mul a) (zeros n ++ l) = zeros n ++ map (Z.mul a) l.
  Proof.
    induction n; simpl; [ reflexivity | intros; apply f_equal2; [ omega | congruence ] ].
  Qed.

  (* Optimized form of [mul'_step].  [opt_step] splits on [usr], using the
     same shared tactic as [mul_bi'_opt_step_sig].  In the cons case:
     - [mul_bi'] is replaced by [mul_bi'_opt];
     - the scaling [map] is pushed past the zero prefix with [map_zeros];
     - [map] is replaced by [map_opt]. *)
  Definition mul'_opt_step_sig
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : { d : list Z | d = mul'_step mul' usr vs bs }.
  Proof.
    eexists.
    cbv [mul'_step].
    opt_step.
    { reflexivity. }
    { cbv [mul_each mul_bi].
      rewrite <- mul_bi'_opt_correct.
      rewrite map_zeros.
      change @map with @map_opt.
      cbv [zeros].
      reflexivity. }
  Defined.

  (* The optimized step, read off from [mul'_opt_step_sig]. *)
  Definition mul'_opt_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs).

  (* Fixpoint of [mul'_opt_step] (structural recursion on [usr]). *)
  Fixpoint mul'_opt
           (usr vs : list Z) (bs : list Z)
    : list Z
    := mul'_opt_step mul'_opt usr vs bs.

  (* [mul'_opt] agrees with [mul'] (argument order [mul' bs usr vs]), by
     induction on [usr]. *)
  Definition mul'_opt_correct
           (usr vs : list Z) (bs : list Z)
    : mul'_opt usr vs bs = mul' bs usr vs.
  Proof.
    revert vs; induction usr as [|usr usrs IHusr]; intros.
    { reflexivity. }
    { simpl.
      rewrite <- IHusr; clear IHusr.
      apply f_equal2; [ | reflexivity ].
      cbv [mul_each mul_bi].
      rewrite map_zeros.
      rewrite <- mul_bi'_opt_correct.
      cbv [zeros].
      reflexivity. }
  Qed.

  (* [mul_opt us vs] agrees with [mul us vs].  The proof:
     - unfolds the multiplication and reduction pipeline;
     - rewrites the extended base with [ext_base_alt];
     - replaces [mul'] by [mul'_opt];
     - rewrites the reduction's [map] with [Z.map_shiftl];
     - substitutes the caller-supplied [c_] and [k_];
     - swaps in computed helpers. *)
  Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }.
  Proof.
    eexists.
    cbv [mul ModularBaseSystemList.mul BaseSystem.mul mul_each mul_bi mul_bi' zeros reduce].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change (@from_list_default Z) with (@from_list_default_opt Z).
    apply f_equal.
    rewrite ext_base_alt by auto using limb_widths_pos with zarith.
    rewrite <- mul'_opt_correct.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    rewrite Z.map_shiftl by apply k_nonneg.
    rewrite c_subst.
    fold k; rewrite k_subst.
    change @map with @map_opt.
    change @Z.shiftl_by with @Z_shiftl_by_opt.
    reflexivity.
  Defined.

  (* The optimized term, read off from [mul_opt_sig]. *)
  Definition mul_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).

  (* Correctness of [mul_opt]. *)
  Definition mul_opt_correct us vs
    : mul_opt us vs = mul us vs
    := proj2_sig (mul_opt_sig us vs).

  (* CPS form of multiply-then-carry: the result is [f (carry_mul us vs)].
     The existential goals left by the [erewrite]s are discharged after
     [Grab Existential Variables]; from the tactics used there, they appear
     to be length facts about the product list. *)
  Definition carry_mul_opt_sig {T} (f:digits -> T)
    (us vs : digits) : { x | x = f (carry_mul us vs) }.
  Proof.
    eexists.
    cbv [carry_mul].
    erewrite <-carry_full_opt_cps_correct by eauto.
    erewrite <-mul_opt_correct.
    cbv [carry_full_opt_cps mul_opt].
    erewrite from_list_default_eq.
    rewrite to_list_from_list.
    reflexivity.
    Grab Existential Variables.
    rewrite mul'_opt_correct.
    distr_length.
    assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega).
    rewrite Min.min_l; rewrite !length_to_list; break_match; try omega.
    rewrite Max.max_l; omega.
  Defined.

  (* The optimized term, read off from [carry_mul_opt_sig]. *)
  Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T
    := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig f us vs).

  (* Correctness of [carry_mul_opt_cps]. *)
  Definition carry_mul_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
    : carry_mul_opt_cps f us vs = f (carry_mul us vs)
    := proj2_sig (carry_mul_opt_sig f us vs).

  (* Direct-style multiply-then-carry: the CPS form with [id] as the
     continuation. *)
  Definition carry_mul_opt := carry_mul_opt_cps id.

  (* Correctness of [carry_mul_opt]. *)
  Definition carry_mul_opt_correct (us vs : digits)
    : carry_mul_opt us vs = carry_mul us vs :=
    carry_mul_opt_cps_correct id us vs.

End Multiplication.

(* Hypotheses bundled for the canonicalization code: the [Canonicalization]
   section takes a [freezePreconditions prm int_width] as a context variable.
   [int_width] presumably is the machine word width available for each limb
   -- TODO confirm.  [log_cap i] is the width of limb [i]. *)
Section with_base.
  Context {modulus} (prm : PseudoMersenneBaseParams modulus).
  Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
  Local Notation log_cap i := (nth_default 0 limb_widths i).

  Record freezePreconditions int_width :=
    mkFreezePreconditions {
        (* There is more than one limb. *)
        lt_1_length_base : (1 < length base)%nat;
        int_width_pos : 0 < int_width;
        (* Every limb width fits in [int_width]. *)
        int_width_compat : forall w, In w limb_widths -> w <= int_width;
        c_pos : 0 < c;
        (* [c] times the all-ones value on the bits of [int_width] not used by
           the last limb stays below [2 ^ (width of limb 0)]. *)
        c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < 2 ^ log_cap 0;
        c_reduce2 : c < 2 ^ log_cap 0 - c;
        two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus
      }.
End with_base.
(* Make every field of [freezePreconditions] available to [auto] and [eauto]
   (locally to this file). *)
Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos
    c_reduce1 c_reduce2 two_pow_k_le_2modulus.

(* Optimized canonicalization ([freeze]).  The context mirrors the other
   sections: caller-supplied [k_] and [c_] with their defining equations.  It
   adds the [freezePreconditions] record for word width [int_width]. *)
Section Canonicalization.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
    {int_width} (preconditions : freezePreconditions prm int_width).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Computed version of [Pow2Base.encodeZ]. *)
  Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ.

  (* [modulus_digits] with [Pow2Base.encodeZ] replaced by [encodeZ_opt]. *)
  Definition modulus_digits_opt_sig :
    { b : list Z | b = modulus_digits }.
  Proof.
    eexists.
    cbv beta iota delta [modulus_digits].
    change Pow2Base.encodeZ with encodeZ_opt.
    reflexivity.
  Defined.

  (* The optimized term, read off from [modulus_digits_opt_sig]. *)
  Definition modulus_digits_opt : list Z
    := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig).

  (* Correctness of [modulus_digits_opt]. *)
  Definition modulus_digits_opt_correct
    : modulus_digits_opt = modulus_digits
    := proj2_sig (modulus_digits_opt_sig).

  (* Three successive full carries in CPS: the result is
     [f (carry_full (carry_full (carry_full us)))].  Each carry is a
     [carry_full_opt_cps c_] call whose remaining precondition is discharged
     by [assumption], presumably with [c_subst]. *)
  Definition carry_full_3_opt_cps_sig
             {T} (f : digits -> T)
             (us : digits)
    : { d : T | d = f (carry_full (carry_full (carry_full us))) }.
  Proof.
    eexists.
    transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us).
    Focus 2. {
      rewrite !carry_full_opt_cps_correct by assumption; reflexivity.
    }
    Unfocus.
    reflexivity.
  Defined.

  (* The optimized term, read off from [carry_full_3_opt_cps_sig]. *)
  Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T
    := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us).

  (* Correctness of [carry_full_3_opt_cps]. *)
  Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us :
    carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) :=
    proj2_sig (carry_full_3_opt_cps_sig f us).

  (* [freeze_opt us] agrees with [freeze us].  The proof:
     - unfolds [freeze] and [conditional_subtract_modulus];
     - binds the list form of the thrice-carried input with [Let_In];
     - abstracts the rest of the computation into a continuation [RHSf] and
       rewrites the three carries into [carry_full_3_opt_cps RHSf];
     - unfolds [ge_modulus]/[ge_modulus'];
     - swaps in computed helpers and [modulus_digits_opt].
     NOTE(review): in the second [let] chain, [LHS] is bound but unused. *)
  Definition freeze_opt_sig (us : digits) :
    { b : digits | b = freeze us }.
  Proof.
    eexists.
    cbv [freeze conditional_subtract_modulus].
    rewrite <-from_list_default_eq with (d := 0%Z).
    change (@from_list_default Z) with (@from_list_default_opt Z).
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (to_list (length limb_widths)  (carry_full (carry_full (carry_full us)))) RHSf).
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in
    rewrite <-carry_full_3_opt_cps_correct with (f := RHSf).
    cbv beta iota delta [ge_modulus ge_modulus'].
    change length with length_opt.
    change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt).
    change @max_ones with max_ones_opt.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    change minus with minus_opt.
    change @map with @map_opt.
    change Z.sub with Z_sub_opt at 1.
    rewrite <-modulus_digits_opt_correct.
    reflexivity.
  Defined.

  (* The optimized term, read off from [freeze_opt_sig]. *)
  Definition freeze_opt (us : digits) : digits
    := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us).

  (* Correctness of [freeze_opt]. *)
  Definition freeze_opt_correct us
    : freeze_opt us = freeze us
    := proj2_sig (freeze_opt_sig us).
  (* NOTE(review): the lemmas below about [freeze_opt] are commented out in
     this revision.  They rely on [pre_carry_bounds], [freeze_canonical] and
     [freeze_preserves_rep]. *)
(*
  Lemma freeze_opt_canonical: forall us vs x,
    @pre_carry_bounds _ _ int_width us -> rep us x ->
    @pre_carry_bounds _ _ int_width vs -> rep vs x ->
    freeze_opt us = freeze_opt vs.
  Proof.
    intros.
    rewrite !freeze_opt_correct.
    eapply freeze_canonical with (B := int_width); eauto.
  Qed.

  Lemma freeze_opt_preserves_rep : forall us x, rep us x ->
    rep (freeze_opt us) x.
  Proof.
    intros.
    rewrite freeze_opt_correct.
    eapply freeze_preserves_rep; eauto.
  Qed.

  Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x ->
    @pre_carry_bounds _ _ int_width us ->
    @pre_carry_bounds _ _ int_width vs ->
    (rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs).
  Proof.
    split; eauto using freeze_opt_canonical.
    auto using freeze_opt_preserves_rep.
  Qed.
*)
End Canonicalization.