aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar jadephilipoom <jade.philipoom@gmail.com>2017-02-22 18:44:33 -0500
committerGravatar Andres Erbsen <andreser@mit.edu>2017-02-22 18:44:33 -0500
commitce10def144ca9a21c3b1ca4a262b1c94336513e5 (patch)
tree02d40658aee71f6170032ee360a0fb03fa23974f
parent57a0a97fdbeee2954128d0917d534a7ed8c433cb (diff)
Merge new base system (#112)
* Added sketch of new low-level base system code * Implemented and proved addition * Implemented carrying, which requires defining over Z rather than arbitrary ring * Proved carry and proved ring-ness of base system ops * Implemented split operation * Started implementing modular reduction * NewBaseSystem: prettify some proofs * andres base * improve andresbase * new base * first draft of goldilocks karatsuba * Factored out goldilocks karatsuba * Implement and prove karatsuba * goldilocks cleanup remodularize * merge karatsuba and goldilocs karatsuba parameter blocks * carry impl and proofs (not yet synthesis-ready) * newbasesystem: use rewrite databases * carry index range fix (TODO: allow for carry-then-reduce) * simpler carry implementation * Added compact operation for saturated base systems (this handles carries after multiplying or adding) * debugging reduction for compact_rows * rewrote compact * Converted saturated section to CPS * some progress on cps conversion for non-saturated stuff * Converted associational non-saturated code to CPS, temporarily commented out examples * pushed cps conversion through Positional * moved list/tuple stuff to top of file * proved lingering lemma * worked on generic-style goal for simplified operations * finished proving the generic-form example goal, revising a couple earlier lemmas * revised previous lemmas * finished revising previous lemmas * removed commented-out code * fixed non-terminating string in comment * fix for 8.5 * removed old file * better automation part 1 * better automation part 2 (goodbye proofs) * better automation part 3/3 * some work on freeze * remove saturated code and clean up exported-operations code * Move helper lemmas for list/tuple CPS stuff to new CPSUtil file * qualified imports * fix runtime notations and module-level Let as per comments * moved push_id to CPSUtil and cancel_pair lemmas to Prod * fixed typo * correctly generalized and moved lift_tuple2 (now called lift2_sig) and converted chained_carries into a fold * moved karatsuba section to new file * rename lemmas and definitions (now cps definitions are consistently <name>_cps and non-cps equivalents have no suffix) * updated timing on mulT * renamed push_eval to push_basesystem_eval
-rw-r--r--_CoqProject2
-rw-r--r--src/Karatsuba.v49
-rw-r--r--src/NewBaseSystem.v458
-rw-r--r--src/Util/CPSUtil.v244
-rw-r--r--src/Util/Prod.v4
-rw-r--r--src/Util/Sigma.v9
-rw-r--r--src/Util/ZUtil.v1
7 files changed, 766 insertions, 1 deletions
diff --git a/_CoqProject b/_CoqProject
index 69fdc4dac..ad5fd28b5 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -7,6 +7,7 @@ src/BaseSystem.v
src/BaseSystemProofs.v
src/EdDSARepChange.v
src/MxDHRepChange.v
+src/NewBaseSystem.v
src/Testbit.v
src/Algebra/ZToRing.v
src/Assembly/Bounds.v
@@ -438,6 +439,7 @@ src/Util/AdditionChainExponentiation.v
src/Util/AutoRewrite.v
src/Util/Bool.v
src/Util/CaseUtil.v
+src/Util/CPSUtil.v
src/Util/Curry.v
src/Util/Decidable.v
src/Util/Equality.v
diff --git a/src/Karatsuba.v b/src/Karatsuba.v
new file mode 100644
index 000000000..47ae2facf
--- /dev/null
+++ b/src/Karatsuba.v
@@ -0,0 +1,49 @@
+Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
+Require Import Crypto.Util.ZUtil.
+Local Open Scope Z_scope.
+
+Section Karatsuba.
+ Context {T : Type} (eval : T -> Z)
+ (sub : T -> T -> T)
+ (eval_sub : forall x y, eval (sub x y) = eval x - eval y)
+ (mul : T -> T -> T)
+ (eval_mul : forall x y, eval (mul x y) = eval x * eval y)
+ (add : T -> T -> T)
+ (eval_add : forall x y, eval (add x y) = eval x + eval y)
+ (scmul : Z -> T -> T)
+ (eval_scmul : forall c x, eval (scmul c x) = c * eval x)
+ (split : Z -> T -> T * T)
+ (eval_split : forall s x, s <> 0 -> eval (fst (split s x)) + s * (eval (snd (split s x))) = eval x)
+ .
+
+ Definition karatsuba_mul s (x y : T) : T :=
+ let xab := split s x in
+ let yab := split s y in
+ let xy0 := mul (fst xab) (fst yab) in
+ let xy2 := mul (snd xab) (snd yab) in
+ let xy1 := sub (mul (add (fst xab) (snd xab)) (add (fst yab) (snd yab))) (add xy2 xy0) in
+ add (add (scmul (s^2) xy2) (scmul s xy1)) xy0.
+
+ Lemma eval_karatsuba_mul s x y (s_nonzero:s <> 0) :
+ eval (karatsuba_mul s x y) = eval x * eval y.
+ Proof. cbv [karatsuba_mul]; repeat rewrite ?eval_sub, ?eval_mul, ?eval_add, ?eval_scmul.
+ rewrite <-(eval_split s x), <-(eval_split s y) by assumption; ring. Qed.
+
+
+ Definition goldilocks_mul s (xs ys : T) : T :=
+ let a_b := split s xs in
+ let c_d := split s ys in
+ let ac := mul (fst a_b) (fst c_d) in
+ (add (add ac (mul (snd a_b) (snd c_d)))
+ (scmul s (sub (mul (add (fst a_b) (snd a_b)) (add (fst c_d) (snd c_d))) ac))).
+
+ Local Existing Instances Z.equiv_modulo_Reflexive RelationClasses.eq_Reflexive Z.equiv_modulo_Symmetric Z.equiv_modulo_Transitive Z.mul_mod_Proper Z.add_mod_Proper Z.modulo_equiv_modulo_Proper.
+
+ Lemma goldilocks_mul_correct (p : Z) (p_nonzero : p <> 0) s (s_nonzero : s <> 0) (s2_modp : (s^2) mod p = (s+1) mod p) xs ys :
+ (eval (goldilocks_mul s xs ys)) mod p = (eval xs * eval ys) mod p.
+ Proof. cbv [goldilocks_mul]; Zmod_to_equiv_modulo.
+ repeat rewrite ?eval_mul, ?eval_add, ?eval_sub, ?eval_scmul, <-?(eval_split s xs), <-?(eval_split s ys) by assumption; ring_simplify.
+ setoid_rewrite s2_modp.
+ apply f_equal2; nsatz. Qed.
+End Karatsuba.
diff --git a/src/NewBaseSystem.v b/src/NewBaseSystem.v
new file mode 100644
index 000000000..549ec84a0
--- /dev/null
+++ b/src/NewBaseSystem.v
@@ -0,0 +1,458 @@
+Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
+Require Import Coq.ZArith.BinIntDef.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
+Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn.
+Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
+Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.
+
+Require Import Coq.Lists.List. Import ListNotations.
+Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
+
+Local Ltac prove_id :=
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress simpl
+ | _ => progress cbv [Let_In]
+ | _ => progress (autorewrite with uncps push_id in * )
+ | _ => break_if
+ | _ => break_match
+ | _ => contradiction
+ | _ => reflexivity
+ | _ => nsatz
+ | _ => solve [auto]
+ end.
+
+Create HintDb push_basesystem_eval discriminated.
+Local Ltac prove_eval :=
+ repeat match goal with
+ | _ => progress intros
+ | _ => progress simpl
+ | _ => progress cbv [Let_In]
+ | _ => progress (autorewrite with push_basesystem_eval uncps push_id cancel_pair in * )
+ | _ => break_if
+ | _ => break_match
+ | _ => split
+ | H : _ /\ _ |- _ => destruct H
+ | H : Some _ = Some _ |- _ => progress (inversion H; subst)
+ | _ => discriminate
+ | _ => reflexivity
+ | _ => nsatz
+ end.
+
+Delimit Scope runtime_scope with RT.
+Definition runtime_mul := Z.mul.
+Global Notation "a * b" := (runtime_mul a%RT b%RT) : runtime_scope.
+Definition runtime_add := Z.add.
+Global Notation "a + b" := (runtime_add a%RT b%RT) : runtime_scope.
+Definition runtime_fst {A B} := @fst A B.
+Definition runtime_snd {A B} := @snd A B.
+
+Module B.
+ Local Definition limb := (Z*Z)%type. (* position coefficient and run-time value *)
+ Module Associational.
+ Definition eval (p:list limb) : Z :=
+ List.fold_right Z.add 0%Z (List.map (fun t => fst t * snd t) p).
+
+ Lemma eval_nil : eval nil = 0. Proof. reflexivity. Qed.
+ Lemma eval_cons p q : eval (p::q) = (fst p) * (snd p) + eval q. Proof. reflexivity. Qed.
+ Lemma eval_app p q: eval (p++q) = eval p + eval q.
+ Proof. induction p; simpl eval; rewrite ?eval_nil, ?eval_cons; nsatz. Qed.
+ Hint Rewrite eval_nil eval_cons eval_app : push_basesystem_eval.
+
+ Definition multerm (t t' : limb) : limb :=
+ (fst t * fst t', (snd t * snd t')%RT).
+ Definition mul_cps (p q:list limb) {T} (f : list limb->T) :=
+ flat_map_cps (fun t => @map_cps _ _ (multerm t) q) p f.
+ Definition mul (p q:list limb) := mul_cps p q id.
+ Hint Opaque mul : uncps.
+ Lemma eval_map_mul (a:limb) (q:list limb) : eval (List.map (multerm a) q) = fst a * snd a * eval q.
+ Proof.
+ induction q; cbv [multerm]; simpl List.map;
+ autorewrite with push_basesystem_eval cancel_pair; nsatz.
+ Qed. Hint Rewrite eval_map_mul : push_basesystem_eval.
+ Lemma mul_cps_id p q: forall {T} f,
+ @mul_cps p q T f = f (mul p q).
+ Proof. cbv [mul_cps mul]; prove_id. Qed. Hint Rewrite mul_cps_id : uncps.
+ Lemma eval_mul_noncps p q:
+ eval (mul p q) = eval p * eval q.
+ Proof.
+ cbv [mul mul_cps]; induction p; prove_eval. Qed. Hint Rewrite eval_mul_noncps : push_basesystem_eval.
+
+ Fixpoint split (s:Z) (xs:list limb)
+ {T} (f :list limb*list limb->T) :=
+ match xs with
+ | nil => f (nil, nil)
+ | cons x xs' =>
+ split s xs'
+ (fun sxs' =>
+ if dec (fst x mod s = 0)
+ then f (fst sxs', cons (fst x / s, snd x) (snd sxs'))
+ else f (cons x (fst sxs'), snd sxs'))
+ end.
+ Definition split_noncps s xs := split s xs id.
+ Hint Opaque split_noncps : uncps.
+ Lemma split_id s p: forall {T} f,
+ @split s p T f = f (split_noncps s p).
+ Proof.
+ induction p;
+ repeat match goal with
+ | _ => rewrite IHp
+ | _ => progress (cbv [split_noncps]; prove_id)
+ end.
+ Qed. Hint Rewrite split_id : uncps.
+ Lemma eval_split_noncps s p (s_nonzero:s<>0):
+ eval (fst (split_noncps s p)) + s*eval (snd (split_noncps s p)) = eval p.
+ Proof.
+ cbv [split_noncps]; induction p; prove_eval.
+ match goal with H:_ |- _ =>
+ unique pose proof (Z_div_exact_full_2 _ _ s_nonzero H)
+ end; nsatz.
+ Qed. Hint Rewrite @eval_split_noncps using auto : push_basesystem_eval.
+
+ Definition reduce_cps (s:Z) (c:list limb) (p:list limb)
+ {T} (f : list limb->T) :=
+ split s p (fun ab =>mul_cps c (snd ab) (fun rr =>f (fst ab ++ rr))).
+ Definition reduce s c p := reduce_cps s c p id.
+ Hint Opaque reduce : uncps.
+ Lemma reduction_rule a b s c (modulus_nonzero:s-c<>0) :
+ (a + s * b) mod (s - c) = (a + c * b) mod (s - c).
+ Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz.
+ rewrite Z.add_mod, Z_mod_mult, Z.add_0_r, Z.mod_mod; trivial. Qed.
+ Lemma reduce_cps_id s c p {T} f:
+ @reduce_cps s c p T f = f (reduce s c p).
+ Proof. cbv [reduce_cps reduce]; prove_id. Qed. Hint Rewrite reduce_cps_id : uncps.
+ Lemma eval_reduce s c p (s_nonzero:s<>0) (modulus_nonzero:s-eval c<>0):
+ eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c).
+ Proof.
+ cbv [reduce reduce_cps]; prove_eval;
+ rewrite <-reduction_rule by auto; prove_eval.
+ Qed. Hint Rewrite eval_reduce : push_basesystem_eval.
+
+ Section Carries.
+ Context {modulo div:Z->Z->Z}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+
+ Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) :=
+ if dec (fst t = w)
+ then dlet d := div (snd t) fw in
+ dlet m := modulo (snd t) fw in
+ f ((w*fw, d) :: (w, m) :: @nil limb)
+ else f [t].
+ Definition carry_cps(w fw:Z) (p:list limb) {T} (f:list limb->T) :=
+ flat_map_cps (carryterm_cps w fw) p f.
+ Definition carryterm w fw t := carryterm_cps w fw t id.
+ Hint Opaque carryterm : uncps.
+ Definition carry w fw p := carry_cps w fw p id.
+ Hint Opaque carry : uncps.
+ Lemma carryterm_cps_id w fw t {T} f :
+ @carryterm_cps w fw t T f
+ = f (@carryterm w fw t).
+ Proof. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed. Hint Rewrite carryterm_cps_id : uncps.
+ Lemma eval_carryterm w fw (t:limb) (fw_nonzero:fw<>0):
+ eval (carryterm w fw t) = eval [t].
+ Proof.
+ cbv [carryterm_cps carryterm Let_In]; prove_eval.
+ specialize (div_mod (snd t) fw fw_nonzero).
+ nsatz.
+ Qed. Hint Rewrite eval_carryterm using auto : push_basesystem_eval.
+ Lemma carry_cps_id w fw p {T} f:
+ @carry_cps w fw p T f = f (carry w fw p).
+ Proof. cbv [carry_cps carry]; prove_id. Qed.
+ Hint Rewrite carry_cps_id : uncps.
+ Lemma eval_carry w fw p (fw_nonzero:fw<>0):
+ eval (carry w fw p) = eval p.
+ Proof. cbv [carry_cps carry]; induction p; prove_eval. Qed.
+ Hint Rewrite eval_carry using auto : push_basesystem_eval.
+ End Carries.
+
+ Section Saturated.
+ Context {word_max : Z} {word_max_pos : 1 < word_max}
+ {add : Z -> Z -> Z * Z}
+ {add_correct : forall x y, fst (add x y) + word_max * snd (add x y) = x + y}
+ {mul : Z -> Z -> Z * Z}
+ {mul_correct : forall x y, fst (mul x y) + word_max * snd (mul x y) = x * y}
+ {end_wt:Z} {end_wt_pos : 0 < end_wt}
+ .
+
+ Definition sat_multerm_cps (t t' : limb) {T} (f:list limb->T) :=
+ dlet tt' := mul (snd t) (snd t') in
+ f ((fst t*fst t', runtime_fst tt') :: (fst t*fst t'*word_max, runtime_snd tt') :: nil)%list.
+ Definition sat_mul_cps (p q : list limb) {T} (f:list limb->T) :=
+ flat_map_cps (fun t => @flat_map_cps _ _ (sat_multerm_cps t) q) p f.
+ (* TODO (jgross): kind of an interesting behavior--it infers the type arguments like this but fails to check if I leave them implicit *)
+ Definition sat_multerm t t' := sat_multerm_cps t t' id.
+ Definition sat_mul p q := sat_mul_cps p q id.
+ Hint Opaque sat_multerm sat_mul : uncps.
+ Lemma sat_multerm_cps_id t t' : forall {T} (f:list limb->T),
+ sat_multerm_cps t t' f = f (sat_multerm t t').
+ Proof. reflexivity. Qed. Hint Rewrite sat_multerm_cps_id : uncps.
+ Lemma eval_map_sat_multerm_cps t q :
+ eval (flat_map (fun x => sat_multerm_cps t x id) q) = fst t * snd t * eval q.
+ Proof.
+ cbv [sat_multerm sat_multerm_cps Let_In runtime_fst runtime_snd];
+ induction q; prove_eval;
+ try match goal with |- context [mul ?a ?b] =>
+ specialize (mul_correct a b) end;
+ nsatz.
+ Qed. Hint Rewrite eval_map_sat_multerm_cps : push_basesystem_eval.
+ Lemma sat_mul_cps_id p q {T} f : @sat_mul_cps p q T f = f (sat_mul p q).
+ Proof. cbv [sat_mul_cps sat_mul]; prove_id. Qed. Hint Rewrite sat_mul_cps_id : uncps.
+ Lemma eval_sat_mul p q : eval (sat_mul p q) = eval p * eval q.
+ Proof. cbv [sat_mul_cps sat_mul]; induction p; prove_eval. Qed.
+ Hint Rewrite eval_sat_mul : push_basesystem_eval.
+
+ End Saturated.
+ End Associational.
+ Hint Rewrite
+ @Associational.sat_mul_cps_id
+ @Associational.sat_multerm_cps_id
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ @Associational.reduce_cps_id
+ @Associational.split_id
+ @Associational.mul_cps_id : uncps.
+
+ Module Positional.
+ Section Positional.
+ Import Associational.
+ Context (weight : nat -> Z) (* [weight i] is the weight of position [i] *)
+ (weight_0 : weight 0%nat = 1%Z)
+ (weight_nonzero : forall i, weight i <> 0).
+
+ (** Converting from positional to associational *)
+
+ Definition to_associational_cps {n:nat} (xs:tuple Z n)
+ {T} (f:list limb->T) :=
+ map_cps weight (seq 0 n)
+ (fun r =>
+ to_list_cps n xs (fun rr => combine_cps r rr f)).
+ Definition to_associational {n} xs := @to_associational_cps n xs _ id.
+ Definition eval {n} x := @to_associational_cps n x _ Associational.eval.
+ Lemma to_associational_cps_id {n} x {T} f:
+ @to_associational_cps n x T f = f (to_associational x).
+ Proof. cbv [to_associational_cps to_associational]; prove_id. Qed.
+ Hint Rewrite @to_associational_cps_id : uncps.
+ Lemma eval_to_associational {n} x :
+ Associational.eval (@to_associational n x) = eval x.
+ Proof. cbv [to_associational_cps eval to_associational]; prove_eval. Qed.
+ Hint Rewrite @eval_to_associational : push_basesystem_eval.
+
+ (** Converting from associational to positional *)
+
+ Program Definition zeros n : tuple Z n := Tuple.from_list n (List.map (fun _ => 0) (List.seq 0 n)) _.
+ Next Obligation. autorewrite with distr_length; reflexivity. Qed.
+ Lemma eval_zeros n : eval (zeros n) = 0.
+ Proof.
+ cbv [eval Associational.eval to_associational_cps zeros];
+ autorewrite with uncps; rewrite Tuple.to_list_from_list.
+ generalize dependent (List.seq 0 n); intro xs; induction xs; simpl; nsatz.
+ Qed. Hint Rewrite eval_zeros : push_basesystem_eval.
+
+ Definition add_to_nth_cps {n} i x t {T} (f:tuple Z n->T) :=
+ @on_tuple_cps _ _ 0 (update_nth_cps i (runtime_add x)) n n t _ f.
+ Definition add_to_nth {n} i x t := @add_to_nth_cps n i x t _ id.
+ Hint Opaque add_to_nth : uncps.
+ Lemma add_to_nth_cps_id {n} i x xs {T} f:
+ @add_to_nth_cps n i x xs T f = f (add_to_nth i x xs).
+ Proof.
+ cbv [add_to_nth_cps add_to_nth]; erewrite !on_tuple_cps_correct
+ by (intros; autorewrite with uncps; reflexivity); prove_id.
+ Unshelve.
+ intros; subst. autorewrite with uncps push_id. distr_length.
+ Qed. Hint Rewrite @add_to_nth_cps_id : uncps.
+ Lemma eval_add_to_nth {n} (i:nat) (x:Z) (H:(i<n)%nat) (xs:tuple Z n):
+ eval (@add_to_nth n i x xs) = weight i * x + eval xs.
+ Proof.
+ cbv [eval to_associational_cps add_to_nth add_to_nth_cps runtime_add].
+ erewrite on_tuple_cps_correct by (intros; autorewrite with uncps; reflexivity).
+ prove_eval.
+ cbv [Tuple.on_tuple].
+ rewrite !Tuple.to_list_from_list.
+ autorewrite with uncps push_id.
+ rewrite ListUtil.combine_update_nth_r at 1.
+ rewrite <-(update_nth_id i (List.combine _ _)) at 2.
+ rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _ (weight 0, 0)); cbv [ListUtil.splice_nth id];
+ repeat match goal with
+ | _ => progress (apply Zminus_eq; ring_simplify)
+ | _ => progress autorewrite with push_basesystem_eval cancel_pair distr_length
+ | _ => progress rewrite <-?ListUtil.map_nth_default_always, ?map_fst_combine, ?List.firstn_all2, ?ListUtil.map_nth_default_always, ?nth_default_seq_inbouns, ?plus_O_n
+ end; trivial; lia.
+ Unshelve.
+ intros; subst. autorewrite with uncps push_id. distr_length.
+ Qed. Hint Rewrite @eval_add_to_nth using omega : push_basesystem_eval.
+
+ Fixpoint place_cps (t:limb) (i:nat) {T} (f:nat * Z->T) :=
+ if dec (fst t mod weight i = 0)
+ then f (i, let c := fst t / weight i in (c * snd t)%RT)
+ else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+ Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat.
+ Proof. induction n; simpl; break_match; simpl; omega. Qed.
+ Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t.
+ Proof.
+ induction i; cbv [id]; simpl place_cps; break_match;
+ autorewrite with cancel_pair;
+ try find_apply_lem_hyp Z_div_exact_full_2; nsatz || auto.
+ Qed.
+ Definition place t i := place_cps t i id.
+ Hint Opaque place : uncps.
+ Lemma place_cps_id t i {T} f :
+ @place_cps t i T f = f (place t i).
+ Proof. cbv [place]; induction i; prove_id. Qed.
+ Hint Rewrite place_cps_id : uncps.
+ Definition from_associational_cps n (p:list limb) {T} (f:tuple Z n->T):=
+ fold_right_cps (fun t st => place_cps t (pred n) (fun p=> add_to_nth_cps (fst p) (snd p) st id)) (zeros n) p f.
+ Definition from_associational n p := from_associational_cps n p id.
+ Hint Opaque from_associational : uncps.
+ Lemma from_associational_cps_id {n} p {T} f:
+ @from_associational_cps n p T f = f (from_associational n p).
+ Proof. cbv [from_associational_cps from_associational]; prove_id. Qed.
+ Hint Rewrite @from_associational_cps_id : uncps.
+ Lemma eval_from_associational {n} p (n_nonzero:n<>O):
+ eval (from_associational n p) = Associational.eval p.
+ Proof.
+ cbv [from_associational_cps from_associational]; induction p;
+ [|pose proof (place_cps_in_range a (pred n))]; prove_eval.
+ cbv [place]; rewrite weight_place_cps. nsatz.
+ Qed. Hint Rewrite @eval_from_associational using omega : push_basesystem_eval.
+
+ Section Carries.
+ Context {modulo div : Z->Z->Z}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
+ a = b * (div a b) + modulo a b}.
+ Definition carry_cps(index:nat) (p:list limb) {T} (f:list limb->T) :=
+ @Associational.carry_cps modulo div (weight index) (weight (S index) / weight index) p T f.
+ Definition carry i p := carry_cps i p id.
+ Hint Opaque carry : uncps.
+ Lemma carry_cps_id i p {T} f:
+ @carry_cps i p T f = f (carry i p).
+ Proof. cbv [carry_cps carry]; prove_id; rewrite carry_cps_id; reflexivity. Qed.
+ Hint Rewrite carry_cps_id : uncps.
+ Lemma eval_carry i p: weight (S i) / weight i <> 0 ->
+ Associational.eval (carry i p) = Associational.eval p.
+ Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed.
+ Hint Rewrite @eval_carry : push_basesystem_eval.
+ End Carries.
+ End Positional.
+ End Positional.
+ Hint Rewrite
+ @Associational.sat_mul_cps_id
+ @Associational.sat_multerm_cps_id
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ @Associational.reduce_cps_id
+ @Associational.split_id
+ @Associational.mul_cps_id
+ @Positional.carry_cps_id
+ @Positional.from_associational_cps_id
+ @Positional.place_cps_id
+ @Positional.add_to_nth_cps_id
+ @Positional.to_associational_cps_id
+ : uncps.
+ Hint Rewrite
+ @Associational.eval_sat_mul
+ @Associational.eval_mul_noncps
+ @Positional.eval_to_associational
+ @Associational.eval_carry
+ @Associational.eval_carryterm
+ @Associational.eval_reduce
+ @Associational.eval_split_noncps
+ @Positional.eval_carry
+ @Positional.eval_from_associational
+ @Positional.eval_add_to_nth
+ using (omega || assumption) : push_basesystem_eval.
+End B.
+
+Local Coercion Z.of_nat : nat >-> Z.
+Import Coq.Lists.List.ListNotations. Local Open Scope list_scope.
+Import B.
+
+Ltac assert_preconditions :=
+ repeat match goal with
+ | |- context [Positional.from_associational_cps ?wt ?n] =>
+ unique assert (wt 0%nat = 1) by (cbv; congruence)
+ | |- context [Positional.from_associational_cps ?wt ?n] =>
+ unique assert (forall i, wt i <> 0) by (intros; apply Z.pow_nonzero; try (cbv; congruence); solve [zero_bounds])
+ | |- context [Positional.from_associational_cps ?wt ?n] =>
+ unique assert (n <> 0%nat) by (cbv; congruence)
+ | |- context [Positional.carry_cps?wt ?i] =>
+ unique assert (wt (S i) / wt i <> 0) by (cbv; congruence)
+ end.
+
+Ltac op_simplify :=
+ cbv - [runtime_add runtime_mul Let_In];
+ cbv [runtime_add runtime_mul];
+ repeat progress rewrite ?Z.mul_1_l, ?Z.mul_1_r, ?Z.add_0_l, ?Z.add_0_r.
+
+Ltac prove_op sz x :=
+ cbv [Tuple.tuple Tuple.tuple'] in *;
+ repeat match goal with p : _ * Z |- _ => destruct p end;
+ apply lift2_sig;
+ eexists; cbv zeta beta; intros;
+ match goal with |- Positional.eval ?wt _ = ?op (Positional.eval ?wt ?a) (Positional.eval ?wt ?b) =>
+ transitivity (Positional.eval wt (x wt a b))
+ end;
+ [ apply f_equal; op_simplify; reflexivity
+ | assert_preconditions;
+ progress autorewrite with uncps push_id push_basesystem_eval;
+ reflexivity ]
+.
+
+Section Ops.
+ Context
+ (modulo : Z -> Z -> Z)
+ (div : Z -> Z -> Z)
+ (div_mod : forall a b : Z, b <> 0 ->
+ a = b * div a b + modulo a b).
+ Local Infix "^" := tuple : type_scope.
+
+ Let wt := fun i : nat => 2^(25 * (i / 2) + 26 * ((i + 1) / 2)).
+ Let sz := 10%nat.
+ Let sz2 := Eval compute in ((sz * 2) - 1)%nat.
+
+ (* shorthand for many carries in a row *)
+ Definition chained_carries (w : nat -> Z) (p:list B.limb) (idxs : list nat)
+ {T} (f:list B.limb->T) :=
+ fold_right_cps2 (@Positional.carry_cps w modulo div) p idxs f.
+
+ Definition addT :
+ { add : (Z^sz -> Z^sz -> Z^sz)%type &
+ forall a b : Z^sz,
+ let eval {n} := Positional.eval (n := n) wt in
+ eval (add a b) = eval a + eval b }.
+ Proof.
+ prove_op sz (
+ fun wt a b =>
+ Positional.to_associational_cps (n := sz) wt a
+ (fun r => Positional.to_associational_cps (n := sz) wt b
+ (fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id
+ ))).
+ Defined.
+
+
+ Definition mulT :
+ {mul : (Z^sz -> Z^sz -> Z^sz2)%type &
+ forall a b : Z^sz,
+ let eval {n} := Positional.eval (n := n) wt in
+ eval (mul a b) = eval a * eval b }.
+ Proof.
+ let x := (eval cbv [chained_carries seq fold_right_cps2 sz2] in
+ (fun w a b =>
+ Positional.to_associational_cps (n := sz) w a
+ (fun r => Positional.to_associational_cps (n := sz) w b
+ (fun r0 => Associational.mul_cps r r0
+ (fun r1 => Positional.from_associational_cps w sz2 r1
+ (fun r2 => Positional.to_associational_cps w r2
+ (fun r3 => chained_carries w r3 (seq 0 sz2)
+ (fun r13 => Positional.from_associational_cps w sz2 r13 id
+ )))))))) in
+ prove_op sz x.
+ Time Defined. (* Finished transaction in 139.086 secs *)
+
+End Ops.
+
+Eval cbv [projT1 addT lift2_sig proj1_sig] in (projT1 addT).
+Eval cbv [projT1 mulT lift2_sig proj1_sig] in
+ (fun m d div_mod => projT1 (mulT m d div_mod)).
diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v
new file mode 100644
index 000000000..5d2a80399
--- /dev/null
+++ b/src/Util/CPSUtil.v
@@ -0,0 +1,244 @@
+Require Import Coq.Lists.List. Import ListNotations.
+Require Import Coq.ZArith.ZArith Coq.omega.Omega.
+Require Import Crypto.Util.ListUtil Crypto.Util.Tactics.
+Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
+Local Open Scope Z_scope.
+
+Lemma push_id {A} (a:A) : id a = a. reflexivity. Qed.
+Create HintDb push_id discriminated. Hint Rewrite @push_id : push_id.
+
+Lemma update_nth_id {T} i (xs:list T) : ListUtil.update_nth i id xs = xs.
+Proof.
+ revert xs; induction i; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma map_fst_combine {A B} (xs:list A) (ys:list B) : List.map fst (List.combine xs ys) = List.firstn (length ys) xs.
+Proof.
+ revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma map_snd_combine {A B} (xs:list A) (ys:list B) : List.map snd (List.combine xs ys) = List.firstn (length xs) ys.
+Proof.
+ revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma nth_default_seq_inbouns d s n i (H:(i < n)%nat) :
+ List.nth_default d (List.seq s n) i = (s+i)%nat.
+Proof.
+ progress cbv [List.nth_default].
+ rewrite ListUtil.nth_error_seq.
+ break_innermost_match; solve [ trivial | omega ].
+Qed.
+
+Lemma mod_add_mul_full a b c k m : m <> 0 -> c mod m = k mod m ->
+ (a + b * c) mod m = (a + b * k) mod m.
+Proof.
+ intros; rewrite Z.add_mod, Z.mul_mod by auto.
+ match goal with H : _ mod _ = _ mod _ |- _ => rewrite H end.
+ rewrite <-Z.mul_mod, <-Z.add_mod by auto; reflexivity.
+Qed.
+
+(* TODO
+Lemma to_nat_neg : forall x, x < 0 -> Z.to_nat x = 0%nat.
+Proof. destruct x; try reflexivity; intros. pose proof (Pos2Z.is_pos p). omega. Qed.
+ *)
+
+Fixpoint map_cps {A B} (g : A->B) ls
+ {T} (f:list B->T):=
+ match ls with
+ | nil => f nil
+ | a :: t => map_cps g t (fun r => f (g a :: r))
+ end.
+Lemma map_cps_correct {A B} g ls: forall {T} f,
+ @map_cps A B g ls T f = f (map g ls).
+Proof. induction ls; simpl; intros; rewrite ?IHls; reflexivity. Qed.
+Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps.
+
+Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) :=
+ match ls with
+ | nil => f nil
+ | (x::tl)%list => g x (fun r => flat_map_cps g tl (fun rr => f (r ++ rr))%list)
+ end.
+Lemma flat_map_cps_correct {A B} (g:A->forall {T}, (list B->T)->T) ls :
+ forall {T} (f:list B->T),
+ (forall x T h, @g x T h = h (g x id)) ->
+ @flat_map_cps A B g ls T f = f (List.flat_map (fun x => g x id) ls).
+Proof.
+ induction ls; intros; [reflexivity|].
+ simpl flat_map_cps. simpl flat_map.
+ rewrite H; erewrite IHls by eassumption.
+ reflexivity.
+Qed.
+Hint Rewrite @flat_map_cps_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Fixpoint from_list_default'_cps {A} (d y:A) n xs:
+ forall {T}, (Tuple.tuple' A n -> T) -> T:=
+ match n as n0 return (forall {T}, (Tuple.tuple' A n0 ->T) ->T) with
+ | O => fun T f => f y
+ | S n' => fun T f =>
+ match xs with
+ | nil => from_list_default'_cps d d n' nil (fun r => f (r, y))
+ | x :: xs' => from_list_default'_cps d x n' xs' (fun r => f (r, y))
+ end
+ end.
+Lemma from_list_default'_cps_correct {A} n : forall d y l {T} f,
+ @from_list_default'_cps A d y n l T f = f (Tuple.from_list_default' d y n l).
+Proof.
+ induction n; intros; simpl; [reflexivity|].
+ break_match; subst; apply IHn.
+Qed.
+Definition from_list_default_cps {A} (d:A) n (xs:list A) :
+ forall {T}, (Tuple.tuple A n -> T) -> T:=
+ match n as n0 return (forall {T}, (Tuple.tuple A n0 ->T) ->T) with
+ | O => fun T f => f tt
+ | S n' => fun T f =>
+ match xs with
+ | nil => from_list_default'_cps d d n' nil f
+ | x :: xs' => from_list_default'_cps d x n' xs' f
+ end
+ end.
+Lemma from_list_default_cps_correct {A} n : forall d l {T} f,
+ @from_list_default_cps A d n l T f = f (Tuple.from_list_default d n l).
+Proof.
+ destruct n; intros; simpl; [reflexivity|].
+ break_match; auto using from_list_default'_cps_correct.
+Qed.
+Hint Rewrite @from_list_default_cps_correct : uncps.
+Fixpoint to_list'_cps {A} n
+ {T} (f:list A -> T) : Tuple.tuple' A n -> T :=
+ match n as n0 return (Tuple.tuple' A n0 -> T) with
+ | O => fun x => f [x]
+ | S n' => fun (xs: Tuple.tuple' A (S n')) =>
+ let (xs', x) := xs in
+ to_list'_cps n' (fun r => f (x::r)) xs'
+ end.
+Lemma to_list'_cps_correct {A} n: forall t {T} f,
+ @to_list'_cps A n T f t = f (Tuple.to_list' n t).
+Proof.
+ induction n; simpl; intros; [reflexivity|].
+ destruct_head prod. apply IHn.
+Qed.
+Definition to_list_cps' {A} n {T} (f:list A->T)
+ : Tuple.tuple A n -> T :=
+ match n as n0 return (Tuple.tuple A n0 ->T) with
+ | O => fun _ => f nil
+ | S n' => to_list'_cps n' f
+ end.
+Definition to_list_cps {A} n t {T} f :=
+ @to_list_cps' A n T f t.
+Lemma to_list_cps_correct {A} n t {T} f :
+ @to_list_cps A n t T f = f (Tuple.to_list n t).
+Proof. cbv [to_list_cps to_list_cps' Tuple.to_list]; break_match; auto using to_list'_cps_correct. Qed.
+Hint Rewrite @to_list_cps_correct : uncps.
+
+Definition on_tuple_cps {A B} (d:B) (g:list A ->forall {T},(list B->T)->T) {n m}
+ (xs : Tuple.tuple A n) {T} (f:tuple B m ->T) :=
+ to_list_cps n xs (fun r => g r (fun rr => from_list_default_cps d m rr f)).
+Lemma on_tuple_cps_correct {A B} d (g:list A -> forall {T}, (list B->T)->T)
+ {n m} xs {T} f
+ (Hg : forall x {T} h, @g x T h = h (g x id)) : forall H,
+ @on_tuple_cps A B d g n m xs T f = f (@Tuple.on_tuple A B (fun x => g x id) n m H xs).
+Proof.
+ cbv [on_tuple_cps Tuple.on_tuple]; intros.
+ rewrite to_list_cps_correct, Hg, from_list_default_cps_correct.
+ rewrite (Tuple.from_list_default_eq _ _ _ (H _ (Tuple.length_to_list _))).
+ reflexivity.
+Qed. Hint Rewrite @on_tuple_cps_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Fixpoint update_nth_cps {A} n (g:A->A) xs {T} (f:list A->T) :=
+ match n with
+ | O =>
+ match xs with
+ | [] => f []
+ | x' :: xs' => f (g x' :: xs')
+ end
+ | S n' =>
+ match xs with
+ | [] => f []
+ | x' :: xs' => update_nth_cps n' g xs' (fun r => f (x' :: r))
+ end
+ end.
+Lemma update_nth_cps_correct {A} n g: forall xs T f,
+ @update_nth_cps A n g xs T f = f (update_nth n g xs).
+Proof. induction n; intros; simpl; break_match; try apply IHn; reflexivity. Qed.
+Hint Rewrite @update_nth_cps_correct : uncps.
+
+Fixpoint combine_cps {A B} (la :list A) (lb : list B)
+ {T} (f:list (A*B)->T) :=
+ match la with
+ | nil => f nil
+ | a :: tla =>
+ match lb with
+ | nil => f nil
+ | b :: tlb => combine_cps tla tlb (fun lab => f ((a,b)::lab))
+ end
+ end.
+Lemma combine_cps_correct {A B} la: forall lb {T} f,
+ @combine_cps A B la lb T f = f (combine la lb).
+Proof.
+ induction la; simpl combine_cps; simpl combine; intros;
+ try break_match; try apply IHla; reflexivity.
+Qed.
+Hint Rewrite @combine_cps_correct: uncps.
+
+(* differs from fold_right_cps in that the functional argument `g` is also a CPS function *)
+Fixpoint fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) :=
+ match l with
+ | nil => f a0
+ | b :: tl => fold_right_cps2 g a0 tl (fun r => g b r f)
+ end.
+Lemma fold_right_cps2_correct {A B} g a0 l : forall {T} f,
+ (forall b a T h, @g b a T h = h (@g b a A id)) ->
+ @fold_right_cps2 A B g a0 l T f = f (List.fold_right (fun b a => @g b a A id) a0 l).
+Proof.
+ induction l; intros; [reflexivity|].
+ simpl fold_right_cps2. simpl fold_right.
+ rewrite H; erewrite IHl by eassumption.
+ rewrite H; reflexivity.
+Qed.
+Hint Rewrite @fold_right_cps2_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Definition fold_right_no_starter {A} (f:A->A->A) ls : option A :=
+ match ls with
+ | nil => None
+ | cons x tl => Some (List.fold_right f x tl)
+ end.
+Lemma fold_right_min ls x :
+ x = List.fold_right Z.min x ls
+ \/ List.In (List.fold_right Z.min x ls) ls.
+Proof.
+ induction ls; intros; simpl in *; try tauto.
+ match goal with |- context [Z.min ?x ?y] =>
+ destruct (Z.min_spec x y) as [[? Hmin]|[? Hmin]]
+ end; rewrite Hmin; tauto.
+Qed.
+Lemma fold_right_no_starter_min ls : forall x,
+ fold_right_no_starter Z.min ls = Some x ->
+ List.In x ls.
+Proof.
+ cbv [fold_right_no_starter]; intros; destruct ls; try discriminate.
+ inversion H; subst; clear H.
+ destruct (fold_right_min ls z);
+ simpl List.In; tauto.
+Qed.
+Fixpoint fold_right_cps {A B} (g:B->A->A) (a0:A) (l:list B) {T} (f:A->T) :=
+ match l with
+ | nil => f a0
+ | cons a tl => fold_right_cps g a0 tl (fun r => f (g a r))
+ end.
+Lemma fold_right_cps_correct {A B} g a0 l: forall {T} f,
+ @fold_right_cps A B g a0 l T f = f (List.fold_right g a0 l).
+Proof. induction l; intros; simpl; rewrite ?IHl; auto. Qed.
+Hint Rewrite @fold_right_cps_correct : uncps.
+
+Definition fold_right_no_starter_cps {A} g ls {T} (f:option A->T) :=
+ match ls with
+ | nil => f None
+ | cons x tl => f (Some (List.fold_right g x tl))
+ end.
+Lemma fold_right_no_starter_cps_correct {A} g ls {T} f :
+ @fold_right_no_starter_cps A g ls T f = f (fold_right_no_starter g ls).
+Proof.
+ cbv [fold_right_no_starter_cps fold_right_no_starter]; break_match; reflexivity.
+Qed.
+Hint Rewrite @fold_right_no_starter_cps_correct : uncps.
diff --git a/src/Util/Prod.v b/src/Util/Prod.v
index bcd9404a6..6e6c7d3c4 100644
--- a/src/Util/Prod.v
+++ b/src/Util/Prod.v
@@ -16,6 +16,10 @@ Local Arguments f_equal {_ _} _ {_ _} _.
Scheme Equality for prod.
+Definition fst_pair {A B} (a:A) (b:B) : fst (a,b) = a := eq_refl.
+Definition snd_pair {A B} (a:A) (b:B) : snd (a,b) = b := eq_refl.
+Create HintDb cancel_pair discriminated. Hint Rewrite @fst_pair @snd_pair : cancel_pair.
+
(** ** Equality for [prod] *)
Section prod.
(** *** Projecting an equality of a pair to equality of the first components *)
diff --git a/src/Util/Sigma.v b/src/Util/Sigma.v
index 57c82df68..7a1d0cacb 100644
--- a/src/Util/Sigma.v
+++ b/src/Util/Sigma.v
@@ -16,6 +16,15 @@ Local Arguments f_equal {_ _} _ {_ _} _.
(** ** Equality for [sigT] *)
Section sigT.
+ (* Lift foralls out of sigT proofs and leave a sig goal *)
+ Definition lift2_sig {R S T} f (g:R->S)
+ (X : forall a b, {prod | g prod = f a b}) :
+ { op : T -> T -> R & forall a b, g (op a b) = f a b }.
+ Proof.
+ exists (fun a b => proj1_sig (X a b)).
+ exact (fun a b => proj2_sig (X a b)).
+ Defined.
+
(** *** Projecting an equality of a pair to equality of the first components *)
Definition pr1_path {A} {P : A -> Type} {u v : sigT P} (p : u = v)
: projT1 u = projT1 v
diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v
index ee280ca06..4c6e2441d 100644
--- a/src/Util/ZUtil.v
+++ b/src/Util/ZUtil.v
@@ -766,7 +766,6 @@ Module Z.
apply Z.mod_mul, Z.pow_nonzero; omega. }
Qed.
-
Lemma odd_mod : forall a b, (b <> 0)%Z ->
Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a.
Proof.