aboutsummaryrefslogtreecommitdiff
path: root/src/ModularArithmetic/ModularBaseSystemOpt.v
blob: 121e605a3740896b6af297606dfe0a4f72d3452c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
Require Import Crypto.ModularArithmetic.PrimeFieldTheorems.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams.
Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs.
Require Import Crypto.ModularArithmetic.ExtendedBaseVector.
Require Import Crypto.ModularArithmetic.Conversion.
Require Import Crypto.ModularArithmetic.Pow2Base.
Require Import Crypto.ModularArithmetic.Pow2BaseProofs.
Require Import Crypto.BaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations.
Require Import Crypto.ModularArithmetic.ModularBaseSystemList.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListProofs.
Require Import Crypto.ModularArithmetic.ModularBaseSystem.
Require Import Crypto.ModularArithmetic.ModularBaseSystemProofs.
Require Import Coq.Lists.List.
Require Import Crypto.Util.Tuple.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.AdditionChainExponentiation.
Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil Crypto.Util.NatUtil Crypto.Util.CaseUtil.
Import ListNotations.
Require Import Coq.ZArith.ZArith Coq.ZArith.Zpower Coq.ZArith.ZArith Coq.ZArith.Znumtheory.
Require Import Coq.Numbers.Natural.Peano.NPeano.
Require Import Coq.QArith.QArith Coq.QArith.Qround.
Require Import Crypto.Tactics.VerdiTactics.
Require Export Crypto.Util.FixCoqMistakes.
Local Open Scope Z.

(* Computed versions of some functions.

   Each [*_opt] constant below is the fully reduced body ([Eval compute]) of
   the function it is named after.  Proofs later in this file use [change]
   to swap occurrences of the library function for its [*_opt] copy, so the
   terms extracted from those proofs mention these unfolded bodies instead
   of the original constants. *)

Definition plus_opt := Eval compute in plus.

(* Integer ([Z]) arithmetic, bit operations and comparisons. *)
Definition Z_add_opt := Eval compute in Z.add.
Definition Z_sub_opt := Eval compute in Z.sub.
Definition Z_mul_opt := Eval compute in Z.mul.
Definition Z_div_opt := Eval compute in Z.div.
Definition Z_pow_opt := Eval compute in Z.pow.
Definition Z_opp_opt := Eval compute in Z.opp.
Definition Z_min_opt := Eval compute in Z.min.
Definition Z_ones_opt := Eval compute in Z.ones.
Definition Z_of_nat_opt := Eval compute in Z.of_nat.
Definition Z_le_dec_opt := Eval compute in Z_le_dec.
Definition Z_lt_dec_opt := Eval compute in Z_lt_dec.
Definition Z_shiftl_opt := Eval compute in Z.shiftl.
Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by.

(* List, tuple and limb-base helpers. *)
Definition nth_default_opt {A} := Eval compute in @nth_default A.
Definition set_nth_opt {A} := Eval compute in @set_nth A.
Definition update_nth_opt {A} := Eval compute in @update_nth A.
Definition map_opt {A B} := Eval compute in @List.map A B.
Definition full_carry_chain_opt := Eval compute in @Pow2Base.full_carry_chain.
Definition length_opt := Eval compute in length.
Definition base_from_limb_widths_opt := Eval compute in @Pow2Base.base_from_limb_widths.
Definition minus_opt := Eval compute in minus.
Definition from_list_default_opt {A} := Eval compute in (@from_list_default A).
Definition sum_firstn_opt {A} := Eval compute in (@sum_firstn A).
Definition zeros_opt := Eval compute in (@zeros).
Definition bit_index_opt := Eval compute in bit_index.
Definition digit_index_opt := Eval compute in digit_index.

(* Some automation that comes in handy when constructing base parameters *)

(* [opt_step] handles a goal [d = match e with nil => _ | _ => _ end] whose
   left-hand side [d] is typically still an evar (as after [eexists]).  It
   first commits [d] to be a [match] on the same list [e], with fresh evar
   branches, and then destructs [e].  Each list constructor can then be
   solved separately, and the resulting witness keeps the [match] on [e]. *)
Ltac opt_step :=
  match goal with
  | [ |- _ = match ?e with nil => _ | _ => _ end :> ?T ]
    => refine (_ : match e with nil => _ | _ => _ end = _);
       destruct e
  end.

(* One unrolling of [limb_widths_from_len].
   - [loop] is the recursive call.
   - [i] counts the limbs still to be produced, from [len] down to 1.
   - [prev] is the bit position where the previous limb ended.
   For the [j]-th limb ([j = len - i + 1]) the end position is the ceiling
   of [k * j / len]; the limb width is that position minus [prev]. *)
Definition limb_widths_from_len_step loop len k :=
  fun i prev =>
    match i with
    | O => nil
    | S i' =>
      let numer := k * Z.of_nat (len - i + 1) in
      let denom := Z.of_nat len in
      let bound := if Z.eqb (numer mod denom) 0
                   then numer / denom
                   else numer / denom + 1 in
      (bound - prev) :: loop i' bound
    end.

(* [limb_widths_from_len len k] lists [len] limb widths.  Limb [j] ends at
   bit [ceil (k * j / len)], so for [len > 0] the widths telescope to a
   total of [k] bits. *)
Definition limb_widths_from_len len k :=
  let fix go i prev := limb_widths_from_len_step go len k i prev in
  go len 0.

(* [brute_force_indices0 lw] enumerates every pair of indices [i], [j] with
   [i < length lw] and [j < length lw - i]. It returns [true] iff
   [sum_firstn lw (i + j) <= sum_firstn lw i + sum_firstn lw j] holds for all
   of them.  It is intended to be run by computation on concrete limb
   widths, replacing a symbolic proof. *)
Definition brute_force_indices0 lw : bool
  := List.fold_right
       andb true
       (List.map
          (fun i
           => List.fold_right
                andb true
                (List.map
                   (fun j
                    => sum_firstn lw (i + j) <=? sum_firstn lw i + sum_firstn lw j)
                   (seq 0 (length lw - i))))
          (seq 0 (length lw))).

(* Soundness of the check above: if [brute_force_indices0 lw] computes to
   [true], the inequality holds for every [i], [j] with [i + j < length lw].
   The proof turns the nested [andb] folds into universally quantified
   statements over [seq] membership. *)
Lemma brute_force_indices_correct0 lw
  : brute_force_indices0 lw = true -> forall i j : nat,
      (i + j < length lw)%nat -> sum_firstn lw (i + j) <= sum_firstn lw i + sum_firstn lw j.
Proof.
  unfold brute_force_indices0.
  progress repeat setoid_rewrite fold_right_andb_true_map_iff.
  setoid_rewrite in_seq.
  setoid_rewrite Z.leb_le.
  eauto with omega.
Qed.

(* [brute_force_indices1 lw] enumerates every [i] in [1 .. length lw - 1]
   and every [j] in [length lw - i .. length lw - 1], i.e. the pairs whose
   sum wraps past [length lw].  It returns [true] iff, for all of them,
   [sum_firstn lw (length lw) + sum_firstn lw (i + j - length lw)] is at most
   [sum_firstn lw i + sum_firstn lw j]. *)
Definition brute_force_indices1 lw : bool
  := List.fold_right
       andb true
       (List.map
          (fun i
           => List.fold_right
                andb true
                (List.map
                   (fun j
                    => let w_sum := sum_firstn lw in
                       sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <=? w_sum i + w_sum j)
                   (seq (length lw - i) (length lw - (length lw - i)))))
          (seq 1 (length lw - 1))).

(* Soundness of the check above for all [i, j < length lw] with
   [i + j >= length lw].  The case [i = 0] is not enumerated by the check,
   but then [j >= length lw] contradicts [j < length lw]. *)
Lemma brute_force_indices_correct1 lw
  : brute_force_indices1 lw = true -> forall i j : nat,
  (i < length lw)%nat ->
  (j < length lw)%nat ->
  (i + j >= length lw)%nat ->
  let w_sum := sum_firstn lw in
  sum_firstn lw (length lw) + w_sum (i + j - length lw)%nat <= w_sum i + w_sum j.
Proof.
  unfold brute_force_indices1.
  progress repeat setoid_rewrite fold_right_andb_true_map_iff.
  setoid_rewrite in_seq.
  setoid_rewrite Z.leb_le.
  eauto with omega.
Qed.

(* [construct_params prime_modulus len k] applies
   [Build_PseudoMersenneBaseParams] with limb widths computed eagerly (by
   [cbv]) from [limb_widths_from_len len k] and bound to a fresh name [lw].
   The six resulting side goals are discharged in order, each inside
   [abstract]:
   1. a per-element check on [lw], closed with [omega];
   2. a goal closed by computation ([cbv; congruence]);
   3. the first index inequality, via [brute_force_indices_correct0];
   4. primality, from the caller's [prime_modulus];
   5. a goal closed by computation ([cbv; congruence]);
   6. the wrap-around index inequality, via [brute_force_indices_correct1].
   In steps 3 and 6 the boolean checks are run by the VM. The check is
   deferred to the kernel with [vm_cast_no_check]. *)
Ltac construct_params prime_modulus len k :=
  let lwv := (eval cbv in (limb_widths_from_len len k)) in
  let lw := fresh "lw" in pose lwv as lw;
  eapply Build_PseudoMersenneBaseParams with (limb_widths := lw);
  [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto)
  | abstract (cbv; congruence)
  | abstract (refine (@brute_force_indices_correct0 lw _); vm_cast_no_check (eq_refl true))
  | abstract apply prime_modulus
  | abstract (cbv; congruence)
  | abstract (refine (@brute_force_indices_correct1 lw _); vm_cast_no_check (eq_refl true))].

(* [construct_mul2modulus prm] builds one value per limb of [prm]:
   - the first limb, of width [w0], gets [2 ^ (w0 + 1) - 2 * c];
   - every later limb of width [w] gets [2 ^ (w + 1) - 2];
   - with no limbs the result is [nil].
   NOTE(review): this looks like [2 * modulus] written in the limb base,
   assuming [modulus = 2 ^ k - c]; confirm against the callers. *)
Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits :=
  match limb_widths with
  | nil => nil
  | w0 :: ws =>
      (2 ^ (w0 + 1) - 2 * c) :: List.map (fun w => 2 ^ (w + 1) - 2) ws
  end.

(* [compute_preconditions] handles goals that are decidable by evaluation:
   - evaluate everything with [cbv] and introduce the hypotheses;
   - repeatedly split every disjunctive hypothesis, substitute, and close
     the first branch with [congruence];
   - finish with [congruence] or [omega]. *)
Ltac compute_preconditions :=
  cbv; intros; repeat match goal with H : _ \/ _ |- _  =>
    destruct H; subst; [ congruence | ] end; (congruence || omega).

(* If the context holds both [H : P] and [H' : P -> _], specialize [H'] with
   [H] and drop [H]. *)
Ltac subst_precondition := match goal with
  | [H : ?P, H' : ?P -> _ |- _] => specialize (H' H); clear H
end.

(* [kill_precondition H] discharges the premise of the implication [H] by
   computation and uses it to specialize [H].
   NOTE(review): this relies on [forward H] (defined elsewhere) creating a
   side goal for the premise and then leaving that premise as a hypothesis
   in the main goal.  The side goal is closed (inside [abstract]) by
   [eq_refl] or by [cbv] plus case analysis on disjunctions; after that
   [subst_precondition] performs the specialization. *)
Ltac kill_precondition H :=
  forward H; [abstract (try exact eq_refl; clear; cbv; intros; repeat break_or_hyp; intuition)|];
  subst_precondition.

(* Reduced (optimized) versions of the carry operations.  Every operation
   [X] in this section follows the same pattern:
   - [X_opt_sig] has type [{ d | preconditions -> d = X ... }].  It is proved
     starting from [eexists], so the witness [d] is whatever term the
     rewriting steps leave on the left-hand side.
   - [X_opt] extracts that witness with
     [Eval cbv [proj1_sig X_opt_sig] in proj1_sig (X_opt_sig ...)].
   - [X_opt_correct] is [proj2_sig] of the same signature.
   The [_cps] variants take a continuation [f] and prove [d = f (X ...)]. *)
Section Carries.
  Context `{prm : PseudoMersenneBaseParams}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation base := (Pow2Base.base_from_limb_widths limb_widths).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* [carry_gen_opt fc fi i us] is a reduced form of
     [carry_gen limb_widths fc fi i us].  It is valid when the index
     [fi (S (fi i))] is in range for [us].  The proof:
     - unfolds [carry_gen], [carry_single] and [Z.pow2_mod];
     - rewrites with [add_to_nth_set_nth];
     - swaps [nth_default], [set_nth] and [Z.ones] for their [_opt] copies;
     - rewrites with [set_nth_nth_default] (side condition taken from the
       range hypothesis) and with [beq_nat_eq_nat_dec]. *)
  Definition carry_gen_opt_sig fc fi i us
    : { d : list Z | (0 <= fi (S (fi i)) < length us)%nat ->
                     d = carry_gen limb_widths fc fi i us}.
  Proof.
    eexists; intros.
    cbv beta iota delta [carry_gen carry_single Z.pow2_mod].
    rewrite add_to_nth_set_nth.
    change @nth_default with @nth_default_opt in *.
    change @set_nth with @set_nth_opt in *.
    change Z.ones with Z_ones_opt.
    rewrite set_nth_nth_default by assumption.
    rewrite <- @beq_nat_eq_nat_dec.
    reflexivity.
  Defined.

  Definition carry_gen_opt fc fi i us := Eval cbv [proj1_sig carry_gen_opt_sig] in
                                                   proj1_sig (carry_gen_opt_sig fc fi i us).

  Definition carry_gen_opt_correct fc fi i us
    : (0 <= fi (S (fi i)) < length us)%nat ->
      carry_gen_opt fc fi i us = carry_gen limb_widths fc fi i us
    := proj2_sig (carry_gen_opt_sig fc fi i us).

  (* [carry_opt i b] is a reduced form of [carry i b], for [b] of length
     [length limb_widths] and [i < length limb_widths].  The proof:
     - rewrites with [pull_app_if_sumbool], so the right-hand side becomes an
       [if];
     - uses [transitivity] to introduce an [if] with the same condition and
       two evar branches;
     - solves both branches with [carry_gen_opt_correct].  Its range side
       condition is closed by [omega] or, after replacing [length b], by
       [Nat.mod_bound_pos];
     - finally replaces [c] with [c_] and unfolds [carry_gen_opt] into the
       witness. *)
  Definition carry_opt_sig
             (i : nat) (b : list Z)
    : { d : list Z | (length b = length limb_widths)
                     -> (i < length limb_widths)%nat
                     -> d = carry i b }.
  Proof.
    eexists ; intros.
    cbv [carry].
    rewrite <-pull_app_if_sumbool.
    cbv beta delta
      [carry carry_and_reduce carry_simple].
    lazymatch goal with
    | [ |- _ = (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:list Z); evar (y:list Z); transitivity (if br then x else y); subst x; subst y
    end.
    Focus 2. {
      cbv zeta.
      break_if; rewrite <-carry_gen_opt_correct by (omega ||
          (replace (length b) with (length limb_widths) by congruence;
           apply Nat.mod_bound_pos; omega)); reflexivity.
    } Unfocus.
    rewrite c_subst.
    rewrite <- @beq_nat_eq_nat_dec.
    cbv [carry_gen_opt].
    reflexivity.
  Defined.

  Definition carry_opt is us := Eval cbv [proj1_sig carry_opt_sig] in
                                          proj1_sig (carry_opt_sig is us).

  Definition carry_opt_correct i us
    : length us = length limb_widths
      -> (i < length limb_widths)%nat
      -> carry_opt i us = carry i us
    := proj2_sig (carry_opt_sig i us).

  (* The witness of [carry_sequence_opt is us] is
     [fold_right carry_opt us is].  The proof shows that it equals
     [carry_sequence is us] by induction on [is].  [length_carry_sequence]
     supplies the length precondition of [carry_opt_correct] for each
     intermediate result. *)
  Definition carry_sequence_opt_sig (is : list nat) (us : list Z)
    : { b : list Z | (length us = length limb_widths)
                     -> (forall i, In i is -> i < length limb_widths)%nat
                     -> b = carry_sequence is us }.
  Proof.
    eexists. intros H.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt us is).
    Focus 2.
    { induction is; [ reflexivity | ].
      simpl; rewrite IHis, carry_opt_correct.
      - reflexivity.
      - fold (carry_sequence is us). auto using length_carry_sequence.
      - auto using in_eq.
      - intros. auto using in_cons.
      }
    Unfocus.
    reflexivity.
  Defined.

  Definition carry_sequence_opt is us := Eval cbv [proj1_sig carry_sequence_opt_sig] in
                                          proj1_sig (carry_sequence_opt_sig is us).

  Definition carry_sequence_opt_correct is us
    : (length us = length limb_widths)
      -> (forall i, In i is -> i < length limb_widths)%nat
      -> carry_sequence_opt is us = carry_sequence is us
    := proj2_sig (carry_sequence_opt_sig is us).

  (* CPS form of [carry_gen_opt]: [carry_gen_opt_cps fc fi i f b] equals
     [f (carry_gen limb_widths fc fi i b)] under the same range condition.
     In the witness, the first subterm [a] found in a pattern
     [a &' Z_ones_opt _] is bound once with [Let_In].  The right-hand side
     is abstracted over [a] with [pattern], presumably so that value is
     computed once and shared. *)
  Definition carry_gen_opt_cps_sig
             {T} fc fi
             (i : nat)
             (f : list Z -> T)
             (b : list Z)
    : { d : T | (0 <= fi (S (fi i)) < length b)%nat -> d = f (carry_gen limb_widths fc fi i b) }.
  Proof.
    eexists. intros H.
    rewrite <-carry_gen_opt_correct by assumption.
    cbv beta iota delta [carry_gen_opt].
    match goal with |- appcontext[?a &' Z_ones_opt _] =>
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (a) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (a) RHSf) end.
    reflexivity.
  Defined.

  Definition carry_gen_opt_cps {T} fc fi i f b
    := Eval cbv beta iota delta [proj1_sig carry_gen_opt_cps_sig] in
                                 proj1_sig (@carry_gen_opt_cps_sig T fc fi i f b).

  Definition carry_gen_opt_cps_correct {T} fc fi i f b :
   (0 <= fi (S (fi i)) < length b)%nat ->
    @carry_gen_opt_cps T fc fi i f b = f (carry_gen limb_widths fc fi i b)
    := proj2_sig (carry_gen_opt_cps_sig fc fi i f b).

  (* CPS form of [carry_opt]: [carry_opt_cps i f b] equals [f (carry i b)].
     After unfolding and rewriting with [pull_app_if_sumbool], the
     right-hand side has the form [f (if br then _ else _)].  It is split
     into an [if] with evar branches, as in [carry_opt_sig]; each branch is
     solved with [carry_gen_opt_cps_correct]. *)
  Definition carry_opt_cps_sig
             {T}
             (i : nat)
             (f : list Z -> T)
             (b : list Z)
    : { d : T | (length b = length limb_widths)
                 -> (i < length limb_widths)%nat
                 -> d = f (carry i b) }.
  Proof.
    eexists. intros.
    cbv beta delta
      [carry carry_and_reduce carry_simple].
    rewrite <-pull_app_if_sumbool.
    lazymatch goal with
    | [ |- _ = ?f (if ?br then ?c else ?d) ]
      => let x := fresh "x" in let y := fresh "y" in evar (x:T); evar (y:T); transitivity (if br then x else y); subst x; subst y
    end.
    Focus 2. {
      cbv zeta.
      break_if; rewrite <-carry_gen_opt_cps_correct by (omega ||
          (replace (length b) with (length limb_widths) by congruence;
           apply Nat.mod_bound_pos; omega)); reflexivity.
    } Unfocus.
    rewrite c_subst.
    rewrite <- @beq_nat_eq_nat_dec.
    reflexivity.
  Defined.

  Definition carry_opt_cps {T} i f b
    := Eval cbv beta iota delta [proj1_sig carry_opt_cps_sig] in proj1_sig (@carry_opt_cps_sig T i f b).

  Definition carry_opt_cps_correct {T} i f b :
    (length b = length limb_widths)
    -> (i < length limb_widths)%nat
    -> @carry_opt_cps T i f b = f (carry i b)
    := proj2_sig (carry_opt_cps_sig i f b).

  (* CPS form of [carry_sequence_opt].  The witness is
     [fold_right carry_opt_cps f (List.rev is) us]: one [carry_opt_cps] per
     index, with [f] as the final continuation.  The proof converts the fold
     with [fold_left_rev_right] and inducts on the reversed index list
     [ris].  Indices in [rev is] are bounded like those in [is]. *)
  Definition carry_sequence_opt_cps_sig {T} (is : list nat) (us : list Z)
     (f : list Z -> T)
    : { b : T | (length us = length limb_widths)
                -> (forall i, In i is -> i < length limb_widths)%nat
                -> b = f (carry_sequence is us) }.
  Proof.
    eexists.
    cbv [carry_sequence].
    transitivity (fold_right carry_opt_cps f (List.rev is) us).
    Focus 2.
    {
      assert (forall i, In i (rev is) -> i < length limb_widths)%nat as Hr. {
        subst. intros. rewrite <- in_rev in *. auto. }
      remember (rev is) as ris eqn:Heq.
      rewrite <- (rev_involutive is), <- Heq in H0 |- *.
      clear H0 Heq is.
      rewrite fold_left_rev_right.
      revert H. revert us; induction ris; [ reflexivity | ]; intros.
      { simpl.
        rewrite <- IHris; clear IHris;
          [|intros; apply Hr; right; assumption|auto using length_carry].
        rewrite carry_opt_cps_correct; [reflexivity|congruence|].
        apply Hr; left; reflexivity.
        } }
    Unfocus.
    cbv [carry_opt_cps].
    reflexivity.
  Defined.

  Definition carry_sequence_opt_cps {T} is us (f : list Z -> T) :=
    Eval cbv [proj1_sig carry_sequence_opt_cps_sig] in
      proj1_sig (carry_sequence_opt_cps_sig is us f).

  Definition carry_sequence_opt_cps_correct {T} is us (f : list Z -> T)
    : (length us = length limb_widths)
      -> (forall i, In i is -> i < length limb_widths)%nat
      -> carry_sequence_opt_cps is us f = f (carry_sequence is us)
    := proj2_sig (carry_sequence_opt_cps_sig is us f).

  (* Every index in [Pow2Base.full_carry_chain limb_widths] is a valid limb
     index; this follows from [Pow2BaseProofs.make_chain_lt]. *)
  Lemma full_carry_chain_bounds : forall i, In i (Pow2Base.full_carry_chain limb_widths) ->
    (i < length limb_widths)%nat.
  Proof.
    unfold Pow2Base.full_carry_chain; intros.
    apply Pow2BaseProofs.make_chain_lt; auto.
  Qed.

  (* [carry_full_opt us] is a reduced form of [carry_full us].  The proof:
     - wraps the unfolded right-hand side in [id];
     - rewrites it into [carry_sequence_opt_cps] with continuation [id], with
       indices bounded by [full_carry_chain_bounds];
     - replaces the chain with [full_carry_chain_opt]. *)
  Definition carry_full_opt_sig (us : list Z) :
    { b : list Z | (length us = length limb_widths)
                   -> b = carry_full us }.
  Proof.
    eexists;  cbv [carry_full]; intros.
    match goal with |- ?LHS = ?RHS => change (LHS = id RHS) end.
    rewrite <-carry_sequence_opt_cps_correct with (f := id)  by (auto; apply full_carry_chain_bounds).
    change @Pow2Base.full_carry_chain with full_carry_chain_opt.
    reflexivity.
  Defined.

  Definition carry_full_opt (us : list Z) : list Z
    := Eval cbv [proj1_sig carry_full_opt_sig] in proj1_sig (carry_full_opt_sig us).

  Definition carry_full_opt_correct us
    : length us = length limb_widths
      -> carry_full_opt us = carry_full us
    := proj2_sig (carry_full_opt_sig us).

  (* CPS form of [carry_full_opt]: [carry_full_opt_cps f us] equals
     [f (carry_full us)].  The proof goes through [carry_full_opt] and then
     composes [f] with the continuation [g] already passed to
     [carry_sequence_opt_cps], so [f] is applied by that same CPS chain. *)
  Definition carry_full_opt_cps_sig
             {T}
             (f : list Z -> T)
             (us : list Z)
    : { d : T | length us = length limb_widths
                -> d = f (carry_full us) }.
  Proof.
    eexists; intros.
    rewrite <- carry_full_opt_correct by auto.
    cbv beta iota delta [carry_full_opt].
    rewrite carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds).
    match goal with |- ?LHS = ?f (?g (carry_sequence ?is ?us)) =>
      change (LHS = (fun x => f (g x)) (carry_sequence is us)) end.
    rewrite <-carry_sequence_opt_cps_correct by (auto || apply full_carry_chain_bounds).
    reflexivity.
  Defined.

  Definition carry_full_opt_cps {T} (f : list Z -> T) (us : list Z) : T
    := Eval cbv [proj1_sig carry_full_opt_cps_sig] in proj1_sig (carry_full_opt_cps_sig f us).

  Definition carry_full_opt_cps_correct {T} us (f : list Z -> T)
    : length us = length limb_widths
      -> carry_full_opt_cps f us = f (carry_full us)
    := proj2_sig (carry_full_opt_cps_sig f us).

End Carries.

(* Carrying digit tuples along the chain supplied by a [CarryChain]
   instance, in continuation-passing style.  As elsewhere in this file,
   [carry__opt_cps] is the witness extracted from [carry__opt_sig], and
   [carry__opt_cps_correct] is its [proj2_sig]. *)
Section CarryChain.
  Context `{prm : PseudoMersenneBaseParams} {cc : CarryChain limb_widths}.
  Local Notation digits := (tuple Z (length limb_widths)).

  (* [carry__opt_cps f us] equals [f (carry_ carry_chain us)].  The proof:
     - unfolds [carry_];
     - introduces [from_list_default] (default [0]) and swaps it for
       [from_list_default_opt];
     - replaces the list-level carry with [carry_sequence_opt_cps].  Its
       preconditions come from [carry_chain_valid] and [length_to_list]. *)
  Definition carry__opt_sig {T} (f : digits -> T) (us : digits)
    : { x | x = f (carry_ carry_chain us) }.
  Proof.
    eexists.
    cbv [carry_].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change @from_list_default with @from_list_default_opt.
    erewrite <-carry_sequence_opt_cps_correct by eauto using carry_chain_valid, length_to_list.
    cbv [carry_sequence_opt_cps].
    reflexivity.
  Defined.

  Definition carry__opt_cps {T} (f:digits -> T) (us : digits) : T
    := Eval cbv [proj1_sig carry__opt_sig] in proj1_sig (carry__opt_sig f us).

  Definition carry__opt_cps_correct {T} (f:digits -> T) (us : digits)
    : carry__opt_cps f us = f (carry_ carry_chain us)
    := proj2_sig (carry__opt_sig f us).
End CarryChain.

(* Reduced addition of digit tuples, and addition followed by a carry along
   the carry chain.  Each [_opt]/[_opt_cps] is the witness extracted from
   the matching [_sig], and each [_correct] is its [proj2_sig]. *)
Section Addition.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}.
  Local Notation digits := (tuple Z (length limb_widths)).

  (* [add_opt us vs] is [add us vs] itself: the signature is closed by
     [reflexivity] without unfolding anything, presumably only so addition
     has the same sig/opt/correct shape as the other operations. *)
  Definition add_opt_sig (us vs : digits) : { b : digits | b = add us vs }.
  Proof.
    eexists.
    reflexivity.
  Defined.

  Definition add_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig add_opt_sig] in proj1_sig (add_opt_sig us vs).

  Definition add_opt_correct us vs
    : add_opt us vs = add us vs
    := proj2_sig (add_opt_sig us vs).

  (* [carry_add_opt_cps f us vs] equals [f (carry_add carry_chain us vs)].
     The proof:
     - unfolds [carry_add];
     - replaces the carry and the addition by [carry__opt_cps] and [add_opt];
     - unfolds those, together with [carry_sequence_opt_cps] and [add];
     - simplifies with [to_list_from_list]. *)
  Definition carry_add_opt_sig {T} (f:digits -> T)
    (us vs : digits) : { x | x = f (carry_add carry_chain us vs) }.
  Proof.
    eexists.
    cbv [carry_add].
    rewrite <-carry__opt_cps_correct, <-add_opt_correct.
    cbv [carry_sequence_opt_cps carry__opt_cps add_opt add].
    rewrite to_list_from_list.
    reflexivity.
  Defined.

  Definition carry_add_opt_cps {T} (f:digits -> T) (us vs : digits) : T
    := Eval cbv [proj1_sig carry_add_opt_sig] in proj1_sig (carry_add_opt_sig f us vs).

  Definition carry_add_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
    : carry_add_opt_cps f us vs = f (carry_add carry_chain us vs)
    := proj2_sig (carry_add_opt_sig f us vs).

  (* Direct-style version: the CPS form with the identity continuation. *)
  Definition carry_add_opt := carry_add_opt_cps id.

  Definition carry_add_opt_correct (us vs : digits)
    : carry_add_opt us vs = carry_add carry_chain us vs :=
    carry_add_opt_cps_correct id us vs.
End Addition.

(* Reduced subtraction and negation of digit tuples, and their variants
   followed by a carry along the carry chain.  Each [_opt]/[_opt_cps] is the
   witness extracted from the matching [_sig], and each [_correct] is its
   [proj2_sig]. *)
Section Subtraction.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}.
  Local Notation digits := (tuple Z (length limb_widths)).

  (* [sub_opt us vs] is [sub coeff coeff_mod us vs] with
     [ModularBaseSystem.sub], [BaseSystem.sub] and [BaseSystem.add]
     unfolded. *)
  Definition sub_opt_sig (us vs : digits) : { b : digits | b = sub coeff coeff_mod us vs }.
  Proof.
    eexists.
    cbv [BaseSystem.add ModularBaseSystem.sub BaseSystem.sub].
    reflexivity.
  Defined.

  Definition sub_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig sub_opt_sig] in proj1_sig (sub_opt_sig us vs).

  Definition sub_opt_correct us vs
    : sub_opt us vs = sub coeff coeff_mod us vs
    := proj2_sig (sub_opt_sig us vs).

  (* [carry_sub_opt_cps f us vs] equals
     [f (carry_sub carry_chain coeff coeff_mod us vs)].  The proof:
     - unfolds [carry_sub];
     - replaces the carry and the subtraction by [carry__opt_cps] and
       [sub_opt], then unfolds them;
     - simplifies with [to_list_from_list]. *)
  Definition carry_sub_opt_sig {T} (f:digits -> T)
    (us vs : digits) : { x | x = f (carry_sub carry_chain coeff coeff_mod us vs) }.
  Proof.
    eexists.
    cbv [carry_sub].
    rewrite <-carry__opt_cps_correct, <-sub_opt_correct.
    cbv [carry_sequence_opt_cps carry__opt_cps sub_opt].
    rewrite to_list_from_list.
    reflexivity.
  Defined.

  Definition carry_sub_opt_cps {T} (f:digits -> T) (us vs : digits) : T
    := Eval cbv [proj1_sig carry_sub_opt_sig] in proj1_sig (carry_sub_opt_sig f us vs).

  Definition carry_sub_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
    : carry_sub_opt_cps f us vs = f (carry_sub carry_chain coeff coeff_mod us vs)
    := proj2_sig (carry_sub_opt_sig f us vs).

  (* Direct-style version: the CPS form with the identity continuation. *)
  Definition carry_sub_opt := carry_sub_opt_cps id.

  Definition carry_sub_opt_correct (us vs : digits)
    : carry_sub_opt us vs = carry_sub carry_chain coeff coeff_mod us vs :=
    carry_sub_opt_cps_correct id us vs.

  (* [opp_opt us] is [opp coeff coeff_mod us] after unfolding [opp], with the
     inner subtraction replaced by [sub_opt]. *)
  Definition opp_opt_sig (us : digits) : { b : digits | b = opp coeff coeff_mod us }.
  Proof.
    eexists.
    cbv [opp].
    rewrite <-sub_opt_correct.
    reflexivity.
  Defined.

  Definition opp_opt (us : digits) : digits
    := Eval cbv [proj1_sig opp_opt_sig] in proj1_sig (opp_opt_sig us).

  Definition opp_opt_correct us
    : opp_opt us = opp coeff coeff_mod us
    := proj2_sig (opp_opt_sig us).

  (* [carry_opp_opt us] is [carry_opp carry_chain coeff coeff_mod us] after
     unfolding [carry_opp], with the inner carried subtraction replaced by
     [carry_sub_opt]. *)
  Definition carry_opp_opt_sig (us : digits) : { b : digits | b = carry_opp carry_chain coeff coeff_mod us }.
  Proof.
    eexists.
    cbv [carry_opp].
    rewrite <-carry_sub_opt_correct.
    reflexivity.
  Defined.

  Definition carry_opp_opt (us : digits) : digits
    := Eval cbv [proj1_sig carry_opp_opt_sig] in proj1_sig (carry_opp_opt_sig us).

  Definition carry_opp_opt_correct us
    : carry_opp_opt us = carry_opp carry_chain coeff coeff_mod us
    := proj2_sig (carry_opp_opt_sig us).

End Subtraction.

(* Reduced multiplication of digit tuples, and multiplication followed by a
   carry along the carry chain.  The helpers [mul_bi'] and [mul'] are
   re-expressed as fixpoints ([mul_bi'_opt], [mul'_opt]) whose one-step
   bodies are computed through [_sig] definitions.  [mul] is then rewritten
   in terms of them.  Each [_opt]/[_opt_cps] is the witness extracted from
   the matching [_sig], and each [_correct] is its [proj2_sig]. *)
Section Multiplication.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient} {cc : CarryChain limb_widths}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* One unrolling of [mul_bi']: multiplies the head [v] of [vsr] by
     [crosscoef bs i (length vsr')], and recurses on the tail [vsr'] through
     the [mul_bi'] argument. *)
  Definition mul_bi'_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := match vsr with
       | [] => []
       | v :: vsr' => (v * crosscoef bs i (length vsr'))%Z :: mul_bi' i vsr' bs
       end.

  (* Reduced form of [mul_bi'_step].  [opt_step] splits on [vsr].  In the
     [cons] case, [crosscoef] is unfolded and [Z.div], the second [Z.mul]
     and [nth_default] are swapped for their [_opt] copies. *)
  Definition mul_bi'_opt_step_sig
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : { l : list Z | l = mul_bi'_step mul_bi' i vsr bs }.
  Proof.
    eexists.
    cbv [mul_bi'_step].
    opt_step.
    { reflexivity. }
    { cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Defined.

  Definition mul_bi'_opt_step
             (mul_bi' : nat -> list Z -> list Z -> list Z)
             (i : nat) (vsr : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul_bi'_opt_step_sig] in
        proj1_sig (mul_bi'_opt_step_sig mul_bi' i vsr bs).

  (* Ties the knot on [mul_bi'_opt_step]; structurally recursive on [vsr]. *)
  Fixpoint mul_bi'_opt
           (i : nat) (vsr : list Z) (bs : list Z) {struct vsr}
    : list Z
    := mul_bi'_opt_step mul_bi'_opt i vsr bs.

  (* [mul_bi'_opt] agrees with [mul_bi'].  Note the argument order: [mul_bi']
     takes [bs] first. *)
  Definition mul_bi'_opt_correct
             (i : nat) (vsr : list Z) (bs : list Z)
    : mul_bi'_opt i vsr bs = mul_bi' bs i vsr.
  Proof.
    revert i; induction vsr as [|vsr vsrs IHvsr]; intros.
    { reflexivity. }
    { simpl mul_bi'.
      rewrite <- IHvsr; clear IHvsr.
      unfold mul_bi'_opt, mul_bi'_opt_step.
      apply f_equal2; [ | reflexivity ].
      cbv [crosscoef].
      change Z.div with Z_div_opt.
      change Z.mul with Z_mul_opt at 2.
      change @nth_default with @nth_default_opt.
      reflexivity. }
  Qed.

  (* One unrolling of [mul']: for the head [u] of [usr], adds
     [mul_each u (mul_bi bs (length usr') vs)] to the recursive result on the
     tail [usr'], computed through the [mul'] argument. *)
  Definition mul'_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := match usr with
       | [] => []
       | u :: usr' => BaseSystem.add (mul_each u (mul_bi bs (length usr') vs)) (mul' usr' vs bs)
       end.

  (* Scaling every element by [a] commutes with a prefix of [n] zeros. *)
  Lemma map_zeros : forall a n l,
    List.map (Z.mul a) (zeros n ++ l) = zeros n ++ List.map (Z.mul a) l.
  Proof.
    induction n; simpl; [ reflexivity | intros; apply f_equal2; [ omega | congruence ] ].
  Qed.

  (* Reduced form of [mul'_step].  The case split on [usr] uses [opt_step],
     the same tactic as [mul_bi'_opt_step_sig]; previously its body was
     duplicated here inline.  In the [cons] case:
     - the partial product is rewritten through [mul_bi'_opt] and
       [map_zeros];
     - [List.map] is swapped for [map_opt];
     - [zeros] is unfolded. *)
  Definition mul'_opt_step_sig
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : { d : list Z | d = mul'_step mul' usr vs bs }.
  Proof.
    eexists.
    cbv [mul'_step].
    opt_step.
    { reflexivity. }
    { cbv [mul_each mul_bi].
      rewrite <- mul_bi'_opt_correct.
      rewrite map_zeros.
      change @List.map with @map_opt.
      cbv [zeros].
      reflexivity. }
  Defined.

  Definition mul'_opt_step
             (mul' : list Z -> list Z -> list Z -> list Z)
             (usr vs : list Z) (bs : list Z)
    : list Z
    := Eval cbv [proj1_sig mul'_opt_step_sig] in proj1_sig (mul'_opt_step_sig mul' usr vs bs).

  (* Ties the knot on [mul'_opt_step]. *)
  Fixpoint mul'_opt
           (usr vs : list Z) (bs : list Z)
    : list Z
    := mul'_opt_step mul'_opt usr vs bs.

  (* [mul'_opt] agrees with [mul'].  Note the argument order: [mul'] takes
     [bs] first. *)
  Definition mul'_opt_correct
           (usr vs : list Z) (bs : list Z)
    : mul'_opt usr vs bs = mul' bs usr vs.
  Proof.
    revert vs; induction usr as [|usr usrs IHusr]; intros.
    { reflexivity. }
    { simpl.
      rewrite <- IHusr; clear IHusr.
      apply f_equal2; [ | reflexivity ].
      cbv [mul_each mul_bi].
      rewrite map_zeros.
      rewrite <- mul_bi'_opt_correct.
      cbv [zeros].
      reflexivity. }
  Qed.

  (* [mul_opt us vs] is a reduced form of [mul us vs].  The proof:
     - unfolds the multiplication stack ([mul],
       [ModularBaseSystemList.mul], [BaseSystem.mul], [mul_each], [mul_bi],
       [mul_bi'], [zeros], [reduce]);
     - converts the result with [from_list_default_opt] (default [0]);
     - rewrites the extended base with [ext_base_alt] and the product with
       [mul'_opt];
     - swaps [base_from_limb_widths], [List.map] and [Z.shiftl_by] for their
       [_opt] copies;
     - replaces [c] and [k] by the caller-supplied [c_] and [k_]. *)
  Definition mul_opt_sig (us vs : digits) : { b : digits | b = mul us vs }.
  Proof.
    eexists.
    cbv [mul ModularBaseSystemList.mul BaseSystem.mul mul_each mul_bi mul_bi' zeros reduce].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change (@from_list_default Z) with (@from_list_default_opt Z).
    apply f_equal.
    rewrite ext_base_alt by auto using limb_widths_pos with zarith.
    rewrite <- mul'_opt_correct.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    rewrite Z.map_shiftl by apply k_nonneg.
    rewrite c_subst.
    fold k; rewrite k_subst.
    change @List.map with @map_opt.
    change @Z.shiftl_by with @Z_shiftl_by_opt.
    reflexivity.
  Defined.

  Definition mul_opt (us vs : digits) : digits
    := Eval cbv [proj1_sig mul_opt_sig] in proj1_sig (mul_opt_sig us vs).

  Definition mul_opt_correct us vs
    : mul_opt us vs = mul us vs
    := proj2_sig (mul_opt_sig us vs).

  (* [carry_mul_opt_cps f us vs] equals [f (carry_mul carry_chain us vs)]:
     the carry and the product are replaced by [carry__opt_cps] and
     [mul_opt], then unfolded.  The part after [Grab Existential Variables]
     proves the leftover evar. It is presumably the length side condition
     of [from_list_default_eq], proved with [mul'_opt_correct],
     [distr_length] and [limb_widths_nonnil]. *)
  Definition carry_mul_opt_sig {T} (f:digits -> T)
    (us vs : digits) : { x | x = f (carry_mul carry_chain us vs) }.
  Proof.
    eexists.
    cbv [carry_mul].
    rewrite <-carry__opt_cps_correct, <-mul_opt_correct.
    cbv [carry_sequence_opt_cps carry__opt_cps mul_opt].
    erewrite from_list_default_eq.
    rewrite to_list_from_list.
    reflexivity.
    Grab Existential Variables.
    rewrite mul'_opt_correct.
    distr_length.
    assert (0 < length limb_widths)%nat by (pose proof limb_widths_nonnil; destruct limb_widths; congruence || simpl; omega).
    rewrite Min.min_l; break_match; try omega.
    rewrite Max.max_l; omega.
  Defined.

  Definition carry_mul_opt_cps {T} (f:digits -> T) (us vs : digits) : T
    := Eval cbv [proj1_sig carry_mul_opt_sig] in proj1_sig (carry_mul_opt_sig f us vs).

  Definition carry_mul_opt_cps_correct {T} (f:digits -> T) (us vs : digits)
    : carry_mul_opt_cps f us vs = f (carry_mul carry_chain us vs)
    := proj2_sig (carry_mul_opt_sig f us vs).

  (* Direct-style version: the CPS form with the identity continuation. *)
  Definition carry_mul_opt := carry_mul_opt_cps id.

  Definition carry_mul_opt_correct (us vs : digits)
    : carry_mul_opt us vs = carry_mul carry_chain us vs :=
    carry_mul_opt_cps_correct id us vs.

End Multiplication.

Import Morphisms.
(* [fold_chain] respects a setoid equivalence [Teq] on its values.  It gives
   [Teq]-related results for:
   - identical identity elements and identical chains ([Logic.eq]);
   - two operations that map [Teq]-related arguments to [Teq]-related
     results;
   - accumulators that are pointwise [Teq]-related
     ([SetoidList.eqlistA Teq]).
   Proved by induction on the chain. The empty-chain case inspects the
   related accumulators; the step case extends them with the related new
   entries via [Proper_nth_default]. *)
Global Instance Proper_fold_chain {T} {Teq} {Teq_Equivalence : Equivalence Teq}
  : Proper (Logic.eq
              ==> (fun f g => forall x1 x2 y1 y2 : T, Teq x1 x2 -> Teq y1 y2 -> Teq (f x1 y1) (g x2 y2))
              ==> Logic.eq
              ==> SetoidList.eqlistA Teq
              ==> Teq) fold_chain.
Proof.
  do 9 intro.
  subst; induction y1; repeat intro;
    unfold fold_chain; fold @fold_chain.
  + inversion H; assumption || reflexivity.
  + destruct a.
    apply IHy1.
    econstructor; try assumption.
    apply H0; eapply Proper_nth_default; eauto; reflexivity.
Qed.

(* Optimized exponentiation along an addition chain ([pow_opt]) and
   inversion ([inv_opt]).  [k_], [c_] and [one_] are caller-supplied values
   equal to [k], [c] and [one] (see [k_subst], [c_subst], [one_subst]), so
   that the caller can precompute them. *)
Section PowInv.
  Context `{prm : PseudoMersenneBaseParams}
          (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
          {cc : CarryChain limb_widths}.
  Local Notation digits := (tuple Z (length limb_widths)).
  Context (one_ : digits) (one_subst : one = one_).

  (* Local copy of [fold_chain]; [fold_chain_opt_correct] shows the two are
     convertible.  Each chain step [(i, j)] combines accumulator entries [i]
     and [j] with [op], using [id] as the default for an out-of-range index.
     The result is bound with [Let_In] and pushed on the front of [acc].
     When the chain is exhausted, the head of [acc] is returned: the most
     recent result, or the initial head if no step was taken.  If [acc] is
     empty, [id] is returned. *)
  Fixpoint fold_chain_opt {T} (id : T) op chain acc :=
  match chain with
  | [] => match acc with
          | [] => id
          | ret :: _ => ret
          end
  | (i, j) :: chain' =>
      Let_In (op (nth_default id acc i) (nth_default id acc j))
      (fun ijx => fold_chain_opt id op chain' (ijx :: acc))
  end.

  (* [fold_chain_opt] and [fold_chain] are convertible (proof is
     [reflexivity]); this lemma lets proofs switch between the two names. *)
  Lemma fold_chain_opt_correct : forall {T} (id : T) op chain acc,
    fold_chain_opt id op chain acc = fold_chain id op chain acc.
  Proof.
    reflexivity.
  Qed.

  (* Builds a term related by [eq] to [ModularBaseSystem.pow x chain] that
     folds the chain with [one_] and [carry_mul_opt k_ c_] via
     [fold_chain_opt].  Here [eq] is the digit-level equality unfolded by
     [cbv [eq]] below (presumably ModularBaseSystem.eq, not Logic.eq;
     confirm). *)
  Definition pow_opt_sig x chain :
    {y | eq y (ModularBaseSystem.pow x chain)}.
  Proof.
    eexists.
    cbv beta iota delta [ModularBaseSystem.pow].
    transitivity (fold_chain one_ (carry_mul_opt k_ c_) chain [x]).
    (* Justify the switch to [one_] and [carry_mul_opt] with
       [Proper_fold_chain].  The operation goal reduces, via
       [carry_mul_opt_correct] (which needs [k_subst] and [c_subst]) and the
       [rep] lemmas, to a [congruence] step. *)
    Focus 2. {
      apply Proper_fold_chain; auto; try reflexivity.
      cbv [eq]; intros.
      rewrite carry_mul_opt_correct by assumption.
      rewrite carry_mul_rep, mul_rep by reflexivity.
      congruence.
    } Unfocus.
    (* Replace [fold_chain] by [fold_chain_opt] in the witness. *)
    rewrite <-fold_chain_opt_correct.
    reflexivity.
  Defined.

  (* Optimized exponentiation: the witness of [pow_opt_sig]. *)
  Definition pow_opt x chain : digits
    := Eval cbv [proj1_sig pow_opt_sig] in (proj1_sig (pow_opt_sig x chain)).

  (* Specification of [pow_opt] up to [eq]. *)
  Definition pow_opt_correct x chain
    : eq (pow_opt x chain) (ModularBaseSystem.pow x chain)
    := Eval cbv [proj2_sig pow_opt_sig] in (proj2_sig (pow_opt_sig x chain)).

  (* Exponentiation chain for the exponent [modulus - 2]; [inv] takes this
     chain and its correctness proof. *)
  Context {ec : ExponentiationChain (modulus - 2)}.

  (* Unfolds [inv] and replaces its [pow] by [pow_opt]. *)
  Definition inv_opt_sig x:
    {y | eq y (inv chain chain_correct x)}.
  Proof.
    eexists.
    cbv [inv].
    rewrite <-pow_opt_correct.
    reflexivity.
  Defined.

  (* Optimized inversion: the witness of [inv_opt_sig]. *)
  Definition inv_opt x : digits
    := Eval cbv [proj1_sig inv_opt_sig] in (proj1_sig (inv_opt_sig x)).

  (* Specification of [inv_opt] up to [eq]. *)
  Definition inv_opt_correct x
    : eq (inv_opt x) (inv chain chain_correct x)
    := Eval cbv [proj2_sig inv_opt_sig] in (proj2_sig (inv_opt_sig x)).
End PowInv.

(* Optimized conversion between limb representations with different limb
   widths ([convert'_opt]), and the derived [pack_opt] / [unpack_opt]
   between [limb_widths] and [target_widths]. *)
Section Conversion.

  (* Unfolds one step of [convert'] (via [convert'_equation]) and replaces
     each helper with its [_opt] counterpart.  [change] only succeeds when the
     two terms are convertible, so the statement is unaffected.  Presumably
     any recursive occurrence of [convert'] in the unfolded equation remains
     unoptimized (confirm against [convert'_equation]). *)
  Definition convert'_opt_sig {lwA lwB}
             (nonnegA : forall x, In x lwA -> 0 <= x)
             (nonnegB : forall x, In x lwB -> 0 <= x)
             bits_fit inp i out :
    { y | y = convert' nonnegA nonnegB bits_fit inp i out}.
  Proof.
    eexists.
    rewrite convert'_equation.
    change sum_firstn with @sum_firstn_opt.
    change length with length_opt.
    change Z_le_dec with Z_le_dec_opt.
    change Z.of_nat with Z_of_nat_opt.
    change digit_index with digit_index_opt.
    change bit_index with bit_index_opt.
    change Z.min with Z_min_opt.
    change (nth_default 0 lwA) with (nth_default_opt 0 lwA).
    change (nth_default 0 lwB) with (nth_default_opt 0 lwB).
    (* Expose the bit-manipulation helpers so their components can be
       swapped as well. *)
    cbv [update_by_concat_bits concat_bits Z.pow2_mod].
    change Z.ones with Z_ones_opt.
    change @update_nth with @update_nth_opt.
    change plus with plus_opt.
    change Z.sub with Z_sub_opt.
    reflexivity.
  Defined.

  (* Optimized [convert']: the witness of [convert'_opt_sig]. *)
  Definition convert'_opt {lwA lwB}
             (nonnegA : forall x, In x lwA -> 0 <= x)
             (nonnegB : forall x, In x lwB -> 0 <= x)
             bits_fit inp i out :=
    Eval cbv [proj1_sig convert'_opt_sig] in
      proj1_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out).

  (* Specification of [convert'_opt]: Leibniz-equal to [convert']. *)
  Definition convert'_opt_correct {lwA lwB}
             (nonnegA : forall x, In x lwA -> 0 <= x)
             (nonnegB : forall x, In x lwB -> 0 <= x)
             bits_fit inp i out :
    convert'_opt nonnegA nonnegB bits_fit inp i out = convert' nonnegA nonnegB bits_fit inp i out :=
    Eval cbv [proj2_sig convert'_opt_sig] in
      proj2_sig (convert'_opt_sig nonnegA nonnegB bits_fit inp i out).

  (* [target_widths] must be non-negative, and the total bit counts of the
     two representations must agree ([bits_eq]). *)
  Context {modulus} (prm : PseudoMersenneBaseParams modulus)
          {target_widths} (target_widths_nonneg : forall x, In x target_widths -> 0 <= x) (bits_eq : sum_firstn limb_widths (length limb_widths) = sum_firstn target_widths (length target_widths)).
  Local Notation digits := (tuple Z (length limb_widths)).
  Local Notation target_digits := (tuple Z (length target_widths)).

  (* Unfolds [pack] down to the list-level [convert] and swaps in the [_opt]
     helpers ([from_list_default_opt], [length_opt], [sum_firstn_opt],
     [zeros_opt]). *)
  Definition pack_opt_sig (x : digits) : { y | y = pack target_widths_nonneg bits_eq x}.
  Proof.
    eexists.
    cbv [pack].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change @from_list_default with @from_list_default_opt.
    cbv [ModularBaseSystemList.pack convert].
    change length with length_opt.
    change sum_firstn with @sum_firstn_opt.
    change zeros with zeros_opt.
    reflexivity.
  Defined.

  (* Optimized [pack]: from [limb_widths] digits to [target_widths] digits. *)
  Definition pack_opt (x : digits) : target_digits :=
    Eval cbv [proj1_sig pack_opt_sig] in proj1_sig (pack_opt_sig x).

  (* Specification of [pack_opt].  NOTE(review): this is named
     [pack_correct], not [pack_opt_correct] like the other lemmas in this
     file; the name is kept for existing callers. *)
  Definition pack_correct (x : digits) :
    pack_opt x = pack target_widths_nonneg bits_eq x
    := Eval cbv [proj2_sig pack_opt_sig] in proj2_sig (pack_opt_sig x).

  (* Mirror of [pack_opt_sig] for [unpack]. *)
  Definition unpack_opt_sig (x : target_digits) : { y | y = unpack target_widths_nonneg bits_eq x}.
  Proof.
    eexists.
    cbv [unpack].
    rewrite <- from_list_default_eq with (d := 0%Z).
    change @from_list_default with @from_list_default_opt.
    cbv [ModularBaseSystemList.unpack convert].
    change length with length_opt.
    change sum_firstn with @sum_firstn_opt.
    change zeros with zeros_opt.
    reflexivity.
  Defined.

  (* Optimized [unpack]: from [target_widths] digits back to [limb_widths]
     digits. *)
  Definition unpack_opt (x : target_digits) : digits :=
    Eval cbv [proj1_sig unpack_opt_sig] in proj1_sig (unpack_opt_sig x).

  (* Specification of [unpack_opt].  NOTE(review): named [unpack_correct],
     not [unpack_opt_correct]; kept for existing callers. *)
  Definition unpack_correct (x : target_digits) :
    unpack_opt x = unpack target_widths_nonneg bits_eq x
    := Eval cbv [proj2_sig unpack_opt_sig] in proj2_sig (unpack_opt_sig x).

End Conversion.

(* Resolve hints for [auto]/[eauto]; several proofs below close side
   conditions with [auto].  Being [Local], the hints are not exported to
   files that import this one. *)
Local Hint Resolve lt_1_length_limb_widths int_width_pos B_pos B_compat
  c_reduce1 c_reduce2.

(* Optimized canonicalization ([freeze_opt]) on list-of-limbs
   representations: three full carries followed by a conditional
   subtraction of the modulus. *)
Section Canonicalization.
  Context `{prm : PseudoMersenneBaseParams} {sc : SubtractionCoefficient}
    (* allows caller to precompute k and c *)
    (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
    {int_width freeze_input_bound}
    (preconditions : FreezePreconditions freeze_input_bound int_width).
  Local Notation digits := (tuple Z (length limb_widths)).

  (* Three successive full carries, expressed by nesting the CPS variant
     [carry_full_opt_cps c_] around [carry_full_opt c_].  The length
     hypothesis is required by the correctness rewrites, which track lengths
     through [length_carry_full]. *)
  Definition carry_full_3_opt_sig
             (us : list Z)
    : { d : list Z | length us = length limb_widths
                 -> d = carry_full (carry_full (carry_full us)) }.
  Proof.
    eexists.
    transitivity (carry_full_opt_cps c_ (carry_full_opt_cps c_ (carry_full_opt c_)) us).
    Focus 2. {
      rewrite !carry_full_opt_cps_correct; try rewrite carry_full_opt_correct; repeat (autorewrite with distr_length; rewrite ?length_carry_full; auto).
    }
    Unfocus.
    reflexivity.
  Defined.

  (* Optimized triple full carry: the witness of [carry_full_3_opt_sig]. *)
  Definition carry_full_3_opt (us : list Z) : list Z
    := Eval cbv [proj1_sig carry_full_3_opt_sig] in proj1_sig (carry_full_3_opt_sig us).

  (* Specification of [carry_full_3_opt], valid for inputs of the expected
     length. *)
  Definition carry_full_3_opt_correct us
    : length us = length limb_widths
      -> carry_full_3_opt us = carry_full (carry_full (carry_full us))
    := proj2_sig (carry_full_3_opt_sig us).

  (* Continuation lemma for [ge_modulus']: applying [f] to the result
     computed with the identity continuation equals running [ge_modulus']
     with continuation [f].  Proof by induction on [i]. *)
  Lemma ge_modulus'_cps : forall {A} (f : Z -> A) (us : list Z) i b,
    f (ge_modulus' id us b i) = ge_modulus' f us b i.
  Proof.
    induction i; intros; simpl; cbv [Let_In cmovl cmovne]; break_if; try reflexivity;
      apply IHi.
  Qed.

  (* Unfolds [ge_modulus] and swaps in [length_opt], [nth_default_opt] and
     [minus_opt]. *)
  Definition ge_modulus_opt_sig (us : list Z) :
    { a : Z | a = ge_modulus us}.
  Proof.
    eexists.
    cbv [ge_modulus ge_modulus'].
    change length with length_opt.
    change nth_default with @nth_default_opt.
    change minus with minus_opt.
    reflexivity.
  Defined.

  (* Optimized [ge_modulus]: the witness of [ge_modulus_opt_sig]. *)
  Definition ge_modulus_opt us : Z
    := Eval cbv [proj1_sig ge_modulus_opt_sig] in proj1_sig (ge_modulus_opt_sig us).

  (* Specification of [ge_modulus_opt]. *)
  Definition ge_modulus_opt_correct us :
    ge_modulus_opt us= ge_modulus us
    := Eval cbv [proj2_sig ge_modulus_opt_sig] in proj2_sig (ge_modulus_opt_sig us).

  (* Optimized conditional subtraction.  Note that [f] here is the input limb
     list, not a function.  The script abstracts the mask
     [neg int_width (ge_modulus f)] into a [Let_In].  It then uses
     [ge_modulus'_cps] to move the function applied to the [ge_modulus']
     result into the continuation of [ge_modulus'], and finally swaps in the
     [_opt] helpers. *)
  Definition conditional_subtract_modulus_opt_sig (f : list Z):
    { g | g = conditional_subtract_modulus int_width f (ge_modulus f) }.
  Proof.
    eexists.
    cbv [conditional_subtract_modulus].
    (* Rewrite the right-hand side as [Let_In mask RHSf], where [RHSf] is the
       right-hand side with the mask abstracted out by [pattern]. *)
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (neg int_width (ge_modulus f)) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (neg int_width (ge_modulus f)) RHSf).
    cbv [ge_modulus].
    rewrite ge_modulus'_cps.
    cbv beta iota delta [ge_modulus ge_modulus'].
    change length with length_opt.
    change nth_default with @nth_default_opt.
    change @Pow2Base.base_from_limb_widths with base_from_limb_widths_opt.
    change minus with minus_opt.
    reflexivity.
  Defined.

  (* Optimized conditional subtraction: the witness of
     [conditional_subtract_modulus_opt_sig]. *)
  Definition conditional_subtract_modulus_opt f : list Z
    := Eval cbv [proj1_sig conditional_subtract_modulus_opt_sig] in proj1_sig (conditional_subtract_modulus_opt_sig f).

  (* Specification of [conditional_subtract_modulus_opt]. *)
  Definition conditional_subtract_modulus_opt_correct f
    : conditional_subtract_modulus_opt f = conditional_subtract_modulus int_width f (ge_modulus f)
    := Eval cbv [proj2_sig conditional_subtract_modulus_opt_sig] in proj2_sig (conditional_subtract_modulus_opt_sig f).


  (* Optimized [freeze].  It swaps in [conditional_subtract_modulus_opt] and,
     under the length hypothesis, [carry_full_3_opt].  It then binds
     [carry_full_3_opt us] with [Let_In] so that the carried value is named
     once in the generated term. *)
  Definition freeze_opt_sig (us : list Z) :
    { b : list Z | length us = length limb_widths
                   -> b = ModularBaseSystemList.freeze int_width us }.
  Proof.
    eexists.
    cbv [ModularBaseSystemList.freeze].
    rewrite <-conditional_subtract_modulus_opt_correct.
    intros.
    rewrite <-carry_full_3_opt_correct by auto.
    let LHS := match goal with |- ?LHS = ?RHS => LHS end in
    let RHS := match goal with |- ?LHS = ?RHS => RHS end in
    let RHSf := match (eval pattern (carry_full_3_opt us) in RHS) with ?RHSf _ => RHSf end in
    change (LHS = Let_In (carry_full_3_opt us) RHSf).
    reflexivity.
  Defined.

  (* Optimized [freeze]: the witness of [freeze_opt_sig], reduced by
     beta/iota and by unfolding only [proj1_sig] and [freeze_opt_sig]. *)
  Definition freeze_opt (us : list Z) : list Z
    := Eval cbv beta iota delta [proj1_sig freeze_opt_sig] in proj1_sig (freeze_opt_sig us).

  (* Specification of [freeze_opt], valid for inputs of the expected
     length. *)
  Definition freeze_opt_correct us
    : length us = length limb_widths
      -> freeze_opt us = ModularBaseSystemList.freeze int_width us
    := proj2_sig (freeze_opt_sig us).

End Canonicalization.

(* Optimized square roots for the two modulus shapes handled by
   [sqrt_3mod4] and [sqrt_5mod8], built on [pow_opt] and
   [carry_mul_opt]. *)
Section SquareRoots.
  Context `{prm : PseudoMersenneBaseParams}.
  Context {cc : CarryChain limb_widths}.
  Local Notation digits := (tuple Z (length limb_widths)).
  (* allows caller to precompute k, c and one *)
  Context (k_ c_ : Z) (k_subst : k = k_) (c_subst : c = c_)
          (one_ : digits) (one_subst : one = one_).

  (* TODO : where should this lemma go? Alternatively, is there a standard-library
     tactic/lemma for this? *)
  (* Two conditionals are [eqA]-related when their conditions are equal and
     their corresponding branches are [eqA]-related. *)
  Lemma if_equiv : forall {A} (eqA : A -> A -> Prop) (x0 x1 : bool) y0 y1 z0 z1,
    x0 = x1 -> eqA y0 y1 -> eqA z0 z1 ->
    eqA (if x0 then y0 else z0) (if x1 then y1 else z1).
  Proof.
    intros; repeat break_if; congruence.
  Qed.

  Section SquareRoot3mod4.
  (* Exponentiation chain for the exponent [modulus / 4 + 1]. *)
  Context {ec : ExponentiationChain (modulus / 4 + 1)}.

  (* [sqrt_3mod4] unfolds to a [pow], which [pow_opt_correct] optimizes. *)
  Definition sqrt_3mod4_opt_sig (us : digits) :
    { vs : digits | eq vs (sqrt_3mod4 chain chain_correct us)}.
  Proof.
    eexists; cbv [sqrt_3mod4].
    apply @pow_opt_correct; eassumption.
  Defined.

  (* Optimized [sqrt_3mod4]: the witness of [sqrt_3mod4_opt_sig]. *)
  Definition sqrt_3mod4_opt us := Eval cbv [proj1_sig sqrt_3mod4_opt_sig] in
    proj1_sig (sqrt_3mod4_opt_sig us).

  (* Specification of [sqrt_3mod4_opt] up to [eq]. *)
  Definition sqrt_3mod4_opt_correct us
    : eq (sqrt_3mod4_opt us) (sqrt_3mod4 chain chain_correct us)
    := Eval cbv [proj2_sig sqrt_3mod4_opt_sig] in proj2_sig (sqrt_3mod4_opt_sig us).

  End SquareRoot3mod4.

  Section SquareRoot5mod8.
  (* Exponentiation chain for the exponent [modulus / 8 + 1], and a value
     [sqrt_m1] whose square represents -1. *)
  Context {ec : ExponentiationChain (modulus / 8 + 1)}.
  Context (sqrt_m1 : digits) (sqrt_m1_correct : rep (mul sqrt_m1 sqrt_m1) (F.opp 1%F)).
  Context {int_width freeze_input_bound}
          (preconditions : FreezePreconditions freeze_input_bound int_width).

  (* The witness is the unfolded [sqrt_5mod8] with [powx] bound by [Let_In]
     and its [mul a b] else-branch replaced by [carry_mul_opt k_ c_ a b]. *)
  Definition sqrt_5mod8_opt_sig (powx powx_squared us : digits) :
    { vs : digits |
      eq vs (sqrt_5mod8 int_width powx powx_squared chain chain_correct sqrt_m1 us)}.
  Proof.
    cbv [sqrt_5mod8].
    (* Find the conditional whose else-branch is [mul a b] and record that
       [carry_mul_opt k_ c_ a b] is [eq] to it. *)
    match goal with
      |- appcontext[(if ?P then ?t else mul ?a ?b)] =>
      assert (eq (carry_mul_opt k_ c_ a b) (mul a b))
        by (rewrite carry_mul_opt_correct by auto;
           cbv [eq]; rewrite carry_mul_rep, mul_rep; reflexivity)
    end.
    (* Abstract [powx] in the right-hand side into [Let_In powx RHSf]. *)
    let RHS := match goal with |- {vs | eq vs ?RHS} => RHS end in
    let RHSf := match (eval pattern powx in RHS) with ?RHSf _ => RHSf end in
    change ({vs | eq vs (Let_In powx RHSf)}).
    (* Provide the witness: the same shape, with the else-branch function
       [f] replaced by [g], the [carry_mul_opt] version recorded above. *)
    match goal with
    | H : eq (?g powx) (?f powx)
    |- {vs | eq vs (Let_In powx (fun x => if ?P then x else ?f x))} =>
      exists (Let_In powx (fun x => if P then x else g x))
    end.
    (* Then-branch: identical on both sides; else-branch: the recorded fact. *)
    break_if; try reflexivity.
    cbv [Let_In].
    auto.
  Defined.

  (* Optimized [sqrt_5mod8]: the witness of [sqrt_5mod8_opt_sig]. *)
  Definition sqrt_5mod8_opt powx powx_squared us := Eval cbv [proj1_sig sqrt_5mod8_opt_sig] in
    proj1_sig (sqrt_5mod8_opt_sig powx powx_squared us).

  (* Specification of [sqrt_5mod8_opt] up to [eq].  The two underscores stand
     for [powx] and [powx_squared]; they are inferred by unification with the
     type of the [proj2_sig] term. *)
  Definition sqrt_5mod8_opt_correct powx powx_squared us
    : eq (sqrt_5mod8_opt powx powx_squared us) (ModularBaseSystem.sqrt_5mod8 int_width _ _ chain chain_correct sqrt_m1 us)
    := Eval cbv [proj2_sig sqrt_5mod8_opt_sig] in proj2_sig (sqrt_5mod8_opt_sig powx powx_squared us).

  End SquareRoot5mod8.

End SquareRoots.