aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
blob: d1a6f62285252998ae654a1bd44d5329a0c34c34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.QArith.QArith Coq.QArith.Qround.
Require Import Crypto.Tactics.VerdiTactics.
Local Open Scope Z.

(* A digit vector [coeff], one entry per limb, together with a proof
   [coeff_mod] that it decodes to [0] in the field.  An instance is required
   by the subtraction section below, which passes [coeff] to [sub].
   NOTE(review): presumably [coeff] is added during subtraction so that limbs
   stay non-negative without changing the decoded value -- confirm against
   ModularBaseSystem.sub. *)
Class SubtractionCoefficient (m : Z) (prm : PseudoMersenneBaseParams m) := {
  coeff : tuple Z (length limb_widths);
  coeff_mod: decode coeff = 0%F
}.

(* Computed versions of some functions.
   Each [foo_opt] is [foo] with its body fully evaluated by [Eval compute].
   It is convertible with [foo] but is a distinct constant.  The proof
   scripts below use [change foo with foo_opt] to replace occurrences of the
   library function by its computed copy inside the optimized terms.
   NOTE(review): the motivation (controlling what later [cbv] calls unfold
   and what the final terms look like) is inferred from how these are
   used -- confirm. *)

Definition Z_add_opt := Eval compute in Z.add.
Definition Z_sub_opt := Eval compute in Z.sub.
Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.

(* Computed copies of list and base-construction helpers. *)
Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
Definition update_nth_opt {A} := Eval compute in @update_nth A.
Definition map_opt {A B} := Eval compute in @map A B.
Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain.
Definition length_opt := Eval compute in length.
Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths.
Definition minus_opt := Eval compute in minus.
Definition max_ones_opt := Eval compute in @max_ones.
Definition from_list_default_opt {A} := Eval compute in (@from_list_default A).

(* An explicit, dependent let-binding: [Let_In x f] is by definition
   [let y := x in f y].  Proofs below [change] a goal into the form
   [Let_In t (fun y => ...)] so that a shared subterm [t] (for example a
   limb read with [nth_default_opt], or the list of carried digits) appears
   once as a binder instead of being duplicated. *)
Definition Let_In {A P} (x : A) (f : forall y : A, P y)
  := let y := x in f y.

(* Some automation that comes in handy when constructing base parameters *)
(* [opt_step]: for a goal [?lhs = match e with nil => _ | _ => _ end], the
   [refine] forces the left-hand side to become a match on the same
   scrutinee [e] with fresh holes in each branch.  Then [destruct e] yields
   one goal per list constructor, in which each branch's hole can be solved
   separately.  Used to build optimized one-step unrollings of recursive list
   functions. *)
Ltac opt_step :=
  match goal with
  | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
    => refine (_ : match e with nil => _ | _ => _ end = _);
       destruct e
  end.

(* [brute_force_indices lw]: proves goals about natural-number indices
   bounded in terms of [length lw].  It unfolds [sum_firstn] and [lw], and
   computes [length lw].  It then repeatedly simplifies, strips matching
   successors from both sides of [<] hypotheses, and destructs any variable
   index appearing in an [x < _] or [x + _ < _] hypothesis.  Each resulting
   concrete case is closed with [omega].
   Note that the Ltac argument is named [limb_widths], so the [unfold] and
   [cbv] act on the list passed in.
   NOTE(review): presumably [lw] must be a concrete list (as produced by
   [construct_params]) for the case split to terminate. *)
Ltac brute_force_indices limb_widths :=
  intros; unfold sum_firstn, limb_widths;  cbv [length limb_widths] in *;
  repeat match goal with
  | _ => progress simpl in *
  | [H : (0 + _ < _)%nat |- _ ] => simpl in H
  | [H : (S _ + _ < S _)%nat |- _ ] => simpl in H
  | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H
  | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x
  | [H : (?x < _)%nat |- _ ] => is_var x; destruct x
  | _ => omega
  end.


(* [limb_widths_from_len len k]: splits [k] bits into [len] limbs of nearly
   equal width.  For j = 1 .. len, let x_j be the ceiling of k * j / len.
   The if-expression computes this: add one to the floor quotient when the
   division is inexact.  The result is the list of differences
   x_j - x_(j-1), starting from x_0 = 0.  The list telescopes, so the widths
   sum to x_len = k.  Here [i] counts down from [len], so [len - i + 1] is
   the current j and [prev] is x_(j-1).  [len = 0] yields [nil]. *)
Definition limb_widths_from_len len k := Eval compute in
  (fix loop i prev :=
    match i with
    | O => nil
    | S i' => let x := (if (Z.eq_dec ((k * Z.of_nat (len - i + 1)) mod (Z.of_nat len)) 0)
                        then (k * Z.of_nat (len - i + 1)) / Z.of_nat len
                        else (k * Z.of_nat (len - i + 1)) / Z.of_nat len + 1)in
      x - prev:: (loop i' x)
    end) len 0.

(* [construct_params prime_modulus len k]: builds a [PseudoMersenneBaseParams]
   instance whose limb widths are [limb_widths_from_len len k].  The widths
   are first computed to a concrete list [lw].  The six side conditions of
   [Build_PseudoMersenneBaseParams] are then closed, in order, by the listed
   tactics, each wrapped in [abstract]:
   - the first walks a [fold_right and True] statement about the elements
     of [lw] with [omega];
   - the second and fifth are closed by computation;
   - the third and sixth are closed with [brute_force_indices];
   - the fourth is closed by [prime_modulus] itself.
   NOTE(review): presumably [prime_modulus] is a primality proof for the
   modulus -- confirm against the record definition. *)
Ltac construct_params prime_modulus len k :=
  let lw := fresh "lw" in set (lw := limb_widths_from_len len k);
  cbv in lw;
  eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
  [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
  | abstract (cbv; congruence)
  | abstract brute_force_indices lw
  | abstract apply prime_modulus
  | abstract (cbv; congruence)
  | abstract brute_force_indices lw].

(* [construct_mul2modulus prm]: a digit vector with one entry per limb width.
   The first digit is [2 ^ (w0 + 1) - 2 * c]; every later digit is
   [2 ^ (w + 1) - 2], i.e. each digit is about twice its limb capacity.
   NOTE(review): if modulus = 2 ^ k - c, with k the sum of the limb widths,
   the weighted sum of these digits telescopes to 2 ^ (k + 1) - 2 * c, which
   is twice the modulus (hence the name).  Presumably it is intended as the
   [coeff] of a [SubtractionCoefficient] -- confirm. *)
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
  match limb_widths with
  | nil => nil
  | x :: tail =>
      2 ^ (x + 1) - (2 * c) :: map (fun w => 2 ^ (w + 1) - 2) tail
  end.

(* [compute_preconditions]: evaluates the goal with [cbv] and introduces
   hypotheses.  Each disjunctive hypothesis [A \/ B] is destructed and
   substituted, and the first case must be closed by [congruence].  The
   remaining goal is then closed by [congruence] or [omega]. *)
Ltac compute_preconditions :=
  cbv; intros; repeat match goal with H : _ \/ _ |- _  =>
    destruct H; subst; [ congruence | ] end; (congruence || omega).

(* [subst_precondition]: if the context contains both [H : P] and
   [H' : P -> _], specializes [H'] with [H] and then clears [H]. *)
Ltac subst_precondition := match goal with
  | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
end.

(* [kill_precondition H]: discharges the premise of [H : P -> Q].
   [forward H] (from VerdiTactics) splits off the premise [P] as a separate
   goal; presumably it leaves a proof of [P] in the context of the main goal.
   That premise goal is closed inside [abstract]: by [eq_refl] if possible,
   otherwise by clearing the context, computing, splitting disjunctive
   hypotheses and calling [intuition].  Finally [subst_precondition] feeds
   the proof of [P] into [H]. *)
Ltac kill_precondition H :=
  forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
  subst_precondition.

(* Optimized carry operations.  Each [foo_opt] below is built the same way.
   A sigma type [{ d | side conditions -> d = foo ... }] is proved by
   unfolding [foo] and [change]-ing library functions into their [_opt]
   copies, and the final [reflexivity] instantiates [d].  [foo_opt] is the
   [proj1_sig] of that proof reduced with [cbv], and [foo_opt_correct] is
   the matching [proj2_sig]. *)
Section Carries.
  Context `{prm : PseudoMersenneBaseParams}
    (* allows caller to precompute k and c: [k_subst] and [c_subst] witness
       that the caller-supplied [k_] and [c_] equal [k] and [c].
       NOTE(review): only [c_subst] is used in this section. *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* An optimized term equal to [carry i b] whenever [i] is a valid limb
     index. *)
  Definition carry_opt_sig
             (i : nat) (b : digits)
    : { d : digits | (i < length limb_widths)%nat -> d = carry i b }.
  Proof.
    eexists ; intros.
    (* Unfold [carry] down to the [Pow2Base] carry primitives, working on
       lists via [from_list_default]. *)
    cbv [carry ModularBaseSystemList.carry].
    rewrite <-from_list_default_eq with (d := 0%Z).
    rewrite <-pull_app_if_sumbool.
    cbv beta delta
      [carry carry_and_reduce Pow2Base.carry_gen Pow2Base.carry_and_reduce_single Pow2Base.carry_simple
       Z.pow2_mod Z.ones Z.pred
       PseudoMersenneBaseParams.limb_widths].
    rewrite !add_to_nth_set_nth.
    change @Pow2Base.base_from_limb_widths with @base_from_limb_widths_opt.
    change @nth_default with @nth_default_opt in *.
    change @set_nth with @set_nth_opt in *.
    (* Split on the [if] that chooses between the two carry variants, so
       that each branch gets its own evar and is optimized separately. *)
    lazymatch goal with
    | [ |- _ = _ (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:digits); evar (y:digits); transitivity (if br then x else y); subst x; subst y
    end.
    (* The second goal relates the split form back to the original; it is
       closed by case analysis on the test. *)
    2:cbv zeta.
    2:break_if; reflexivity.

    change @from_list_default with @from_list_default_opt.
    change @nth_default with @nth_default_opt.
    rewrite c_subst.
    change @set_nth with @set_nth_opt.
    change @map with @map_opt.
    (* Express the index test with [beq_nat] rather than [eq_nat_dec]. *)
    rewrite <- @beq_nat_eq_nat_dec.
    reflexivity.
  Defined.

  (* The optimized carry term read off [carry_opt_sig]. *)
  Definition carry_opt is us := Eval cbv [proj1_sig carry_opt_sig] in
                                          proj1_sig (carry_opt_sig is us).

  (* [carry_opt] agrees with [carry] on valid limb indices. *)
  Definition carry_opt_correct i us
    : (i < length limb_widths)%nat -> carry_opt i us = carry i us
    := proj2_sig (carry_opt_sig i us).

  (* [carry_sequence is us] rewritten as a right fold of [carry_opt] over
     [is]; valid when every index in [is] is below [length base]. *)
  Definition carry_sequence_opt_sig (is : list nat) (us : digits)
    : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
  Proof.
    eexists. intros H.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt us is).
    (* Prove the fold equation first, by induction on [is]; the remaining
       goal is then closed by [reflexivity], which instantiates [b]. *)
    Focus 2.
    { induction is; [ reflexivity | ].
      simpl; rewrite IHis, carry_opt_correct.
      - reflexivity.
      - rewrite base_length in H.
        apply H; apply in_eq.
      - intros. apply H. right. auto.
      }
    Unfocus.
    reflexivity.
  Defined.

  Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
                                          proj1_sig (carry_sequence_opt_sig is us).

  Definition carry_sequence_opt_correct is us
    : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_sig is us).

  (* Continuation-passing form of [carry_opt]: computes [f (carry i b)],
     with the limb [nth_default_opt 0%Z (to_list _ b) i] bound once
     through [Let_In]. *)
  Definition carry_opt_cps_sig
             {T}
             (i : nat)
             (f : digits -> T)
             (b : digits)
    : { d : T |  (i < length base)%nat -> d = f (carry i b) }.
  Proof.
    eexists. intros H.
    rewrite <- carry_opt_correct by (rewrite base_length in H; assumption).
    cbv beta iota delta [carry_opt].
    (* Abstract the limb read out of the right-hand side and rebind it with
       [Let_In]. *)
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (nth_default_opt 0%Z (to_list _ b) i) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (nth_default_opt 0%Z (to_list _ b) i) RHSf).
    change Z.shiftl with Z_shiftl_opt.
    change (-1) with (Z_opp_opt 1).
    (* NOTE(review): these occurrence numbers depend on the exact shape of
       the unfolded [carry_opt] term; revisit them if [carry] changes. *)
    change Z.add with Z_add_opt at 5 9 17 21.
    reflexivity.
  Defined.

  Definition carry_opt_cps {T} i f b
    := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).

  Definition carry_opt_cps_correct {T} i f b :
    (i < length base)%nat ->
    @carry_opt_cps T i f b = f (carry i b)
    := proj2_sig (carry_opt_cps_sig i f b).

  (* CPS form of [carry_sequence] for an arbitrary continuation [f]: a right
     fold of [carry_opt_cps] over [List.rev is] starting from [f], applied
     to [us]. *)
  Definition carry_sequence_opt_cps2_sig {T} (is : list nat) (us : digits)
     (f : digits -> T)
    : { b : T | (forall i, In i is -> i < length base)%nat -> b = f (carry_sequence is us) }.
  Proof.
    eexists.
    cbv [carry_sequence].
    (* [transitivity] introduces the premise first; that is the hypothesis
       [H] cleared below. *)
    transitivity (fold_right carry_opt_cps f (List.rev is) us).
    Focus 2.
    {
      assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
        subst. intros. rewrite <- in_rev in *. auto. }
      (* Generalize [rev is] to a fresh list [ris] and induct on it. *)
      remember (rev is) as ris eqn:Heq.
      rewrite <- (rev_involutive is), <- Heq.
      clear H Heq is.
      rewrite fold_left_rev_right.
      revert us; induction ris; [ reflexivity | ]; intros.
      { simpl.
        rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
        rewrite carry_opt_cps_correct; [reflexivity|].
        apply Hr; left; reflexivity.
        } }
    Unfocus.
    reflexivity.
  Defined.

  Definition carry_sequence_opt_cps2 {T} is us (f : digits -> T) :=
    Eval cbv [proj1_sig carry_sequence_opt_cps2_sig] in
      proj1_sig (carry_sequence_opt_cps2_sig is us f).

  Definition carry_sequence_opt_cps2_correct {T} is us (f : digits -> T)
    : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps2 is us f = f (carry_sequence is us)
    := proj2_sig (carry_sequence_opt_cps2_sig is us f).

  (* The same construction as [carry_sequence_opt_cps2_sig], specialized to
     the identity continuation [id]. *)
  Definition carry_sequence_opt_cps_sig (is : list nat) (us : digits)
    : { b : digits | (forall i, In i is -> i < length base)%nat -> b = carry_sequence is us }.
  Proof.
    eexists.
    cbv [carry_sequence].
    (* As above, [H] is the premise introduced by [transitivity]. *)
    transitivity (fold_right carry_opt_cps id (List.rev is) us).
    Focus 2.
    {
      assert (forall i, In i (rev is) -> i < length base)%nat as Hr. {
        subst. intros. rewrite <- in_rev in *. auto. }
      remember (rev is) as ris eqn:Heq.
      rewrite <- (rev_involutive is), <- Heq.
      clear H Heq is.
      rewrite fold_left_rev_right.
      revert us; induction ris; [ reflexivity | ]; intros.
      { simpl.
        rewrite <- IHris; clear IHris; [|intros; apply Hr; right; assumption].
        rewrite carry_opt_cps_correct; [reflexivity|].
        apply Hr; left; reflexivity.
        } }
    Unfocus.
    reflexivity.
  Defined.

  Definition carry_sequence_opt_cps is us := Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
                                              proj1_sig (carry_sequence_opt_cps_sig is us).

  Definition carry_sequence_opt_cps_correct is us
    : (forall i, In i is -> i < length base)%nat -> carry_sequence_opt_cps is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_cps_sig is us).


  (* [carry_sequence_opt_cps] preserves [rep]; this follows from
     [carry_sequence_opt_cps_correct] and [carry_sequence_rep]. *)
  Lemma carry_sequence_opt_cps_rep
       : forall (is : list nat) (us : digits) (x : F modulus),
         (forall i : nat, In i is -> i < length base)%nat ->
         rep us x -> rep (carry_sequence_opt_cps is us) x.
  Proof.
    intros.
    rewrite carry_sequence_opt_cps_correct by assumption.
    auto using carry_sequence_rep.
  Qed.

  (* Every index in [Pow2Base.full_carry_chain limb_widths] is a valid limb
     index. *)
  Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) -> (i < length base)%nat.
  Proof.
    unfold Pow2Base.full_carry_chain; rewrite <-base_length; intros.
    apply Pow2BaseProofs.make_chain_lt; auto.
  Qed.

  (* [carry_full] with the full carry chain replaced by
     [full_carry_chain_opt] and the carry sequence by
     [carry_sequence_opt_cps]. *)
  Definition carry_full_opt_sig (us : digits) : { b : digits | b = carry_full us }.
  Proof.
    eexists.
    cbv [carry_full].
    change @Pow2Base.full_carry_chain with full_carry_chain_opt.
    rewrite <-carry_sequence_opt_cps_correct by (auto; apply full_carry_chain_bounds).
    reflexivity.
  Defined.

  Definition carry_full_opt (us : digits) : digits
    := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us).

  Definition carry_full_opt_correct us : carry_full_opt us = carry_full us :=
    proj2_sig (carry_full_opt_sig us).

  (* CPS form of [carry_full]: [f (carry_full us)] expressed through
     [carry_sequence_opt_cps2] with continuation [f]. *)
  Definition carry_full_opt_cps_sig
             {T}
             (f : digits -> T)
             (us : digits)
    : { d : T | d = f (carry_full us) }.
  Proof.
    eexists.
    rewrite <- carry_full_opt_correct.
    cbv beta iota delta [carry_full_opt].
    rewrite carry_sequence_opt_cps_correct by apply full_carry_chain_bounds.
    rewrite <-carry_sequence_opt_cps2_correct by apply full_carry_chain_bounds.
    reflexivity.
  Defined.

  Definition carry_full_opt_cps {T} (f : digits -> T) (us : digits) : T
    := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us).

  Definition carry_full_opt_cps_correct {T} us (f : digits -> T) :
    carry_full_opt_cps f us = f (carry_full us) :=
    proj2_sig (carry_full_opt_cps_sig f us).

End Carries.

(* Optimized addition: [add_opt] is [add] with [BaseSystem.add] unfolded,
   and [add_opt_correct] proves the two equal. *)
Section Addition.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
  Local Notation digits := (tuple Z (length limb_widths)).

  Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add].
    reflexivity.
  Defined.

  Definition add_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs).

  Definition add_opt_correct us vs
    : add_opt us vs = add us vs
    := proj2_sig (add_opt_sig us vs).
End Addition.

(* Optimized subtraction: [sub_opt] is [sub coeff], with [coeff] taken from
   the [SubtractionCoefficient] instance, and with [BaseSystem.add],
   [ModularBaseSystem.sub] and [BaseSystem.sub] unfolded.
   [sub_opt_correct] proves the two equal. *)
Section Subtraction.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}.
  Local Notation digits := (tuple Z (length limb_widths)).

  Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
    reflexivity.
  Defined.

  Definition sub_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).

  Definition sub_opt_correct us vs
    : sub_opt us vs = sub coeff us vs
    := proj2_sig (sub_opt_sig us vs).
End Subtraction.

(* Optimized multiplication.  The recursive helpers [mul_bi'] and [mul'] are
   re-expressed as fixpoints over optimized one-step unrollings ([_step]).
   [mul] is then rewritten in terms of them, and the result is combined with
   the CPS full carry. *)
Section Multiplication.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
    (* allows caller to precompute k and c: [k_subst] and [c_subst] witness
       that the caller-supplied [k_] and [c_] equal [k] and [c] *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* One unrolling of [mul_bi'], with the recursive call abstracted as the
     argument [mul_bi'].  The head [v] of [vsr] is scaled by
     [crosscoef bs i (length vsr')], where [vsr'] is the tail. *)
  Definition mul_bi'_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := match vsr with
       | [] => []
       | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs
       end.

  (* [mul_bi'_step] with [crosscoef] unfolded and [Z.div], the second
     [Z.mul] occurrence and [nth_default] replaced by their [_opt] copies. *)
  Definition mul_bi'_opt_step_sig
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }.
  Proof.
    eexists.
    cbv [mul_bi'_step].
    opt_step.
    { reflexivity. }
    { cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Defined.

  Definition mul_bi'_opt_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in
        proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs).

  (* Tie the knot.  [mul_bi'_opt_step] unfolds to a match on [vsr] whose
     recursive call is on the tail, which is presumably what lets the
     [{struct vsr}] guard check succeed. *)
  Fixpoint mul_bi'_opt
           (i : nat) (vsr : list Z) (bs : list Z) {struct vsr}
    : list Z
    := mul_bi'_opt_step mul_bi'_opt i vsr bs.

  (* Note the argument order: [mul_bi'] takes [bs] first. *)
  Definition mul_bi'_opt_correct
             (i : nat) (vsr : list Z) (bs : list Z)
    : mul_bi'_opt i vsr bs = mul_bi' bs i vsr.
  Proof.
    revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
    { reflexivity. }
    { simpl mul_bi'.
      rewrite <- IHvsr; clear IHvsr.
      unfold mul_bi'_opt, mul_bi'_opt_step.
      apply f_equal2; [ | reflexivity ].
      cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Qed.

  (* One unrolling of [mul'], with the recursive call abstracted as the
     argument [mul'].  For [u :: usr'] it adds
     [mul_each u (mul_bi bs (length usr') vs)] to the recursive result. *)
  Definition mul'_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := match usr with
       | [] => []
       | u :: usr' => BaseSystem.add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs)
       end.

  (* Scaling by [a] commutes with prepending [zeros n]. *)
  Lemma map_zeros : forall a n l,
    map (Z.mul a) (zeros n ++ l) = zeros n ++ map (Z.mul a) l.
  Proof.
    induction n; simpl; [ reflexivity | intros; apply f_equal2; [ omega | congruence ] ].
  Qed.

  (* [mul'_step] with [mul_each]/[mul_bi] unfolded, [mul_bi'] replaced by
     [mul_bi'_opt], [map] replaced by [map_opt], and [zeros] unfolded. *)
  Definition mul'_opt_step_sig
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : { d : list Z | d = mul'_step mul' usr vs bs }.
  Proof.
    eexists.
    cbv [mul'_step].
    (* Same case split on the list as in [mul_bi'_opt_step_sig]. *)
    opt_step.
    { reflexivity. }
    { cbv [mul_each mul_bi].
      rewrite <- mul_bi'_opt_correct.
      rewrite map_zeros.
      change @map with @map_opt.
      cbv [zeros].
      reflexivity. }
  Defined.

  Definition mul'_opt_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs).

  Fixpoint mul'_opt
           (usr vs : list Z) (bs : list Z)
    : list Z
    := mul'_opt_step mul'_opt usr vs bs.

  (* Note the argument order: [mul'] takes [bs] first. *)
  Definition mul'_opt_correct
           (usr vs : list Z) (bs : list Z)
    : mul'_opt usr vs bs = mul' bs usr vs.
  Proof.
    revert vs; induction usr as [|usr usrs IHusr]; intros.
    { reflexivity. }
    { simpl.
      rewrite <- IHusr; clear IHusr.
      apply f_equal2; [ | reflexivity ].
      cbv [mul_each mul_bi].
      rewrite map_zeros.
      rewrite <- mul_bi'_opt_correct.
      cbv [zeros].
      reflexivity. }
  Qed.

  (* [mul us vs] fully unfolded.  The extended base is rewritten with
     [ext_base_alt] and the product is expressed via [mul'_opt].  The map
     over the base is rewritten with [Z.map_shiftl], whose side condition is
     [k_nonneg].  [c] and [k] are replaced by the caller's [c_] and [k_], and
     the list helpers by their [_opt] copies. *)
  Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }.
  Proof.
    eexists.
    cbv [mul ModularBaseSystemList.mul BaseSystem.mul mul_each mul_bi mul_bi' zeros reduce].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change (@from_list_default Z) with (@from_list_default_opt Z).
    apply f_equal.
    rewrite ext_base_alt by auto using limb_widths_pos with zarith.
    rewrite <- mul'_opt_correct.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    rewrite Z.map_shiftl by apply k_nonneg.
    rewrite c_subst.
    fold k; rewrite k_subst.
    change @map with @map_opt.
    change @Z.shiftl_by with @Z_shiftl_by_opt.
    reflexivity.
  Defined.

  Definition mul_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).

  Definition mul_opt_correct us vs
    : mul_opt us vs = mul us vs
    := proj2_sig (mul_opt_sig us vs).

  (* CPS form of [carry_mul]: [f (carry_mul us vs)] expressed as the
     optimized full carry, with continuation [f], applied to [mul_opt us vs]. *)
  Definition carry_mul_opt_sig {T} (f:digits -> T)
    (us vs : digits) : { x | x = f (carry_mul us vs) }.
  Proof.
    eexists.
    cbv [carry_mul].
    erewrite <-carry_full_opt_cps_correct by eauto.
    erewrite <-mul_opt_correct.
    reflexivity.
  Defined.

  Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T
    := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig f us vs).

  Definition carry_mul_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
    : carry_mul_opt_cps f us vs = f (carry_mul us vs)
    := proj2_sig (carry_mul_opt_sig f us vs).
End Multiplication.

(* Side conditions on the base parameters and on an integer width
   [int_width].  They are required as a hypothesis by Section
   Canonicalization below, which defines [freeze_opt]. *)
Section with_base.
  Context {modulus} (prm : PseudoMersenneBaseParams modulus).
  Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
  (* [log_cap i] is the width of limb [i], defaulting to 0 when out of range. *)
  Local Notation log_cap i := (nth_default 0 limb_widths i).

  Record freezePreconditions int_width :=
    mkFreezePreconditions {
        (* at least two limbs *)
        lt_1_length_base : (1 < length base)%nat;
        int_width_pos : 0 < int_width;
        (* every limb width is at most [int_width] *)
        int_width_compat : forall w, In w limb_widths -> w <= int_width;
        c_pos : 0 < c;
        (* [c] times an all-ones mask of [int_width] minus the last limb's
           width stays below the capacity of limb 0 *)
        c_reduce1 : c * (Z.ones (int_width - log_cap (pred (length base)))) < 2 ^ log_cap 0;
        c_reduce2 : c < 2 ^ log_cap 0 - c;
        two_pow_k_le_2modulus : 2 ^ k <= 2 * modulus
      }.
End with_base.
(* Make the [freezePreconditions] field projections available to
   [auto]/[eauto] in this file. *)
Local Hint Resolve lt_1_length_base int_width_pos int_width_compat c_pos
    c_reduce1 c_reduce2 two_pow_k_le_2modulus.

(* Optimized canonicalization: [freeze_opt] is [freeze] with three CPS full
   carries and the modulus comparison unfolded into [_opt] helpers. *)
Section Canonicalization.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient modulus prm}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
    (* NOTE(review): [preconditions] is not referenced by name by the active
       definitions in this section. *)
    {int_width} (preconditions : freezePreconditions prm int_width).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Fully computed copy of [Pow2Base.encodeZ]. *)
  Definition encodeZ_opt := Eval compute in Pow2Base.encodeZ.

  (* [modulus_digits] with [Pow2Base.encodeZ] replaced by [encodeZ_opt]. *)
  Definition modulus_digits_opt_sig :
    { b : list Z | b = modulus_digits }.
  Proof.
    eexists.
    cbv beta iota delta [modulus_digits].
    change Pow2Base.encodeZ with encodeZ_opt.
    reflexivity.
  Defined.

  Definition modulus_digits_opt : list Z
    := Eval cbv [proj1_sig modulus_digits_opt_sig] in proj1_sig (modulus_digits_opt_sig).

  Definition modulus_digits_opt_correct
    : modulus_digits_opt = modulus_digits
    := proj2_sig (modulus_digits_opt_sig).

  (* Three nested full carries in continuation-passing style:
     [f (carry_full (carry_full (carry_full us)))]. *)
  Definition carry_full_3_opt_cps_sig
             {T} (f : digits -> T)
             (us : digits)
    : { d : T | d = f (carry_full (carry_full (carry_full us))) }.
  Proof.
    eexists.
    transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt_cps c_ f)) us).
    Focus 2. {
      (* [c_subst] discharges the side conditions by [assumption]. *)
      rewrite !carry_full_opt_cps_correct by assumption; reflexivity.
    }
    Unfocus.
    reflexivity.
  Defined.

  Definition carry_full_3_opt_cps {T} (f : digits -> T) (us : digits) : T
    := Eval cbv [proj1_sig carry_full_3_opt_cps_sig] in proj1_sig (carry_full_3_opt_cps_sig f us).

  Definition carry_full_3_opt_cps_correct {T} (f : digits -> T) us :
    carry_full_3_opt_cps f us = f (carry_full (carry_full (carry_full us))) :=
    proj2_sig (carry_full_3_opt_cps_sig f us).

  (* [freeze us] with [conditional_subtract_modulus] and [ge_modulus]
     unfolded.  The list of triple-carried digits is bound once via
     [Let_In], the three carries go through [carry_full_3_opt_cps], and the
     helpers are replaced by their [_opt] copies. *)
  Definition freeze_opt_sig (us : digits) :
    { b : digits | b = freeze us }.
  Proof.
    eexists.
    cbv [freeze conditional_subtract_modulus].
    rewrite <-from_list_default_eq with (d := 0%Z).
    change (@from_list_default Z) with (@from_list_default_opt Z).
    (* Bind the list of carried digits once with [Let_In]. *)
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (to_list (length limb_widths) (carry_full (carry_full (carry_full us)))) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (to_list (length limb_widths)  (carry_full (carry_full (carry_full us)))) RHSf).
    (* Abstract the triple carry and replace it with its CPS form. *)
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (carry_full (carry_full (carry_full us))) in RHS) with ?RHSf _ => RHSf end in
    rewrite <-carry_full_3_opt_cps_correct with (f := RHSf).
    cbv beta iota delta [ge_modulus ge_modulus'].
    change length with length_opt.
    change (nth_default 0 modulus_digits) with (nth_default_opt 0 modulus_digits_opt).
    change @max_ones with max_ones_opt.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    change minus with minus_opt.
    change @map with @map_opt.
    (* NOTE(review): occurrence-dependent; revisit if [freeze] changes. *)
    change Z.sub with Z_sub_opt at 1.
    rewrite <-modulus_digits_opt_correct.
    reflexivity.
  Defined.

  Definition freeze_opt (us : digits) : digits
    := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us).

  Definition freeze_opt_correct us
    : freeze_opt us = freeze us
    := proj2_sig (freeze_opt_sig us).
  (* NOTE(review): the lemmas below are disabled (commented out).  They
     state canonicity and rep-preservation of [freeze_opt] under
     [pre_carry_bounds] for [int_width]. *)
(*
  Lemma freeze_opt_canonical: forall us vs x,
    @pre_carry_bounds _ _ int_width us -> rep us x ->
    @pre_carry_bounds _ _ int_width vs -> rep vs x ->
    freeze_opt us = freeze_opt vs.
  Proof.
    intros.
    rewrite !freeze_opt_correct.
    eapply freeze_canonical with (B := int_width); eauto.
  Qed.

  Lemma freeze_opt_preserves_rep : forall us x, rep us x ->
    rep (freeze_opt us) x.
  Proof.
    intros.
    rewrite freeze_opt_correct.
    eapply freeze_preserves_rep; eauto.
  Qed.

  Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x ->
    @pre_carry_bounds _ _ int_width us ->
    @pre_carry_bounds _ _ int_width vs ->
    (rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs).
  Proof.
    split; eauto using freeze_opt_canonical.
    auto using freeze_opt_preserves_rep.
  Qed.
*)
End Canonicalization.