aboutsummaryrefslogtreecommitdiff
path: root/src/Util/CPSUtil.v
diff options
context:
space:
mode:
authorGravatar jadephilipoom <jade.philipoom@gmail.com>2017-02-22 18:44:33 -0500
committerGravatar Andres Erbsen <andreser@mit.edu>2017-02-22 18:44:33 -0500
commitce10def144ca9a21c3b1ca4a262b1c94336513e5 (patch)
tree02d40658aee71f6170032ee360a0fb03fa23974f /src/Util/CPSUtil.v
parent57a0a97fdbeee2954128d0917d534a7ed8c433cb (diff)
Merge new base system (#112)
* Added sketch of new low-level base system code * Implemented and proved addition * Implemented carrying, which requires defining over Z rather than arbitrary ring * Proved carry and proved ring-ness of base system ops * Implemented split operation * Started implementing modular reduction * NewBaseSystem: prettify some proofs * andres base * improve andresbase * new base * first draft of goldilocks karatsuba * Factored out goldilocks karatsuba * Implement and prove karatsuba * goldilocks cleanup remodularize * merge karatsuba and goldilocs karatsuba parameter blocks * carry impl and proofs (not yet synthesis-ready) * newbasesystem: use rewrite databases * carry index range fix (TODO: allow for carry-then-reduce) * simpler carry implementation * Added compact operation for saturated base systems (this handles carries after multiplying or adding) * debugging reduction for compact_rows * rewrote compact * Converted saturated section to CPS * some progress on cps conversion for non-saturated stuff * Converted associational non-saturated code to CPS, temporarily commented out examples * pushed cps conversion through Positional * moved list/tuple stuff to top of file * proved lingering lemma * worked on generic-style goal for simplified operations * finished proving the generic-form example goal, revising a couple earlier lemmas * revised previous lemmas * finished revising previous lemmas * removed commented-out code * fixed non-terminating string in comment * fix for 8.5 * removed old file * better automation part 1 * better automation part 2 (goodbye proofs) * better automation part 3/3 * some work on freeze * remove saturated code and clean up exported-operations code * Move helper lemmas for list/tuple CPS stuff to new CPSUtil file * qualified imports * fix runtime notations and module-level Let as per comments * moved push_id to CPSUtil and cancel_pair lemmas to Prod * fixed typo * correctly generalized and moved lift_tuple2 (now called lift2_sig) and converted chained_carries into a fold * moved karatsuba section to new file * rename lemmas and definitions (now cps definitions are consistently <name>_cps and non-cps equivalents have no suffix) * updated timing on mulT * renamed push_eval to push_basesystem_eval
Diffstat (limited to 'src/Util/CPSUtil.v')
-rw-r--r--src/Util/CPSUtil.v244
1 files changed, 244 insertions, 0 deletions
diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v
new file mode 100644
index 000000000..5d2a80399
--- /dev/null
+++ b/src/Util/CPSUtil.v
@@ -0,0 +1,244 @@
+Require Import Coq.Lists.List. Import ListNotations.
+Require Import Coq.ZArith.ZArith Coq.omega.Omega.
+Require Import Crypto.Util.ListUtil Crypto.Util.Tactics.
+Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.
+Local Open Scope Z_scope.
+
+Lemma push_id {A} (a:A) : id a = a. reflexivity. Qed.
+Create HintDb push_id discriminated. Hint Rewrite @push_id : push_id.
+
+Lemma update_nth_id {T} i (xs:list T) : ListUtil.update_nth i id xs = xs.
+Proof.
+ revert xs; induction i; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma map_fst_combine {A B} (xs:list A) (ys:list B) : List.map fst (List.combine xs ys) = List.firstn (length ys) xs.
+Proof.
+ revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma map_snd_combine {A B} (xs:list A) (ys:list B) : List.map snd (List.combine xs ys) = List.firstn (length xs) ys.
+Proof.
+ revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ].
+Qed.
+
+Lemma nth_default_seq_inbouns d s n i (H:(i < n)%nat) :
+ List.nth_default d (List.seq s n) i = (s+i)%nat.
+Proof.
+ progress cbv [List.nth_default].
+ rewrite ListUtil.nth_error_seq.
+ break_innermost_match; solve [ trivial | omega ].
+Qed.
+
+Lemma mod_add_mul_full a b c k m : m <> 0 -> c mod m = k mod m ->
+ (a + b * c) mod m = (a + b * k) mod m.
+Proof.
+ intros; rewrite Z.add_mod, Z.mul_mod by auto.
+ match goal with H : _ mod _ = _ mod _ |- _ => rewrite H end.
+ rewrite <-Z.mul_mod, <-Z.add_mod by auto; reflexivity.
+Qed.
+
+(* TODO
+Lemma to_nat_neg : forall x, x < 0 -> Z.to_nat x = 0%nat.
+Proof. destruct x; try reflexivity; intros. pose proof (Pos2Z.is_pos p). omega. Qed.
+ *)
+
+Fixpoint map_cps {A B} (g : A->B) ls
+ {T} (f:list B->T):=
+ match ls with
+ | nil => f nil
+ | a :: t => map_cps g t (fun r => f (g a :: r))
+ end.
+Lemma map_cps_correct {A B} g ls: forall {T} f,
+ @map_cps A B g ls T f = f (map g ls).
+Proof. induction ls; simpl; intros; rewrite ?IHls; reflexivity. Qed.
+Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps.
+
+Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) :=
+ match ls with
+ | nil => f nil
+ | (x::tl)%list => g x (fun r => flat_map_cps g tl (fun rr => f (r ++ rr))%list)
+ end.
+Lemma flat_map_cps_correct {A B} (g:A->forall {T}, (list B->T)->T) ls :
+ forall {T} (f:list B->T),
+ (forall x T h, @g x T h = h (g x id)) ->
+ @flat_map_cps A B g ls T f = f (List.flat_map (fun x => g x id) ls).
+Proof.
+ induction ls; intros; [reflexivity|].
+ simpl flat_map_cps. simpl flat_map.
+ rewrite H; erewrite IHls by eassumption.
+ reflexivity.
+Qed.
+Hint Rewrite @flat_map_cps_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Fixpoint from_list_default'_cps {A} (d y:A) n xs:
+ forall {T}, (Tuple.tuple' A n -> T) -> T:=
+ match n as n0 return (forall {T}, (Tuple.tuple' A n0 ->T) ->T) with
+ | O => fun T f => f y
+ | S n' => fun T f =>
+ match xs with
+ | nil => from_list_default'_cps d d n' nil (fun r => f (r, y))
+ | x :: xs' => from_list_default'_cps d x n' xs' (fun r => f (r, y))
+ end
+ end.
+Lemma from_list_default'_cps_correct {A} n : forall d y l {T} f,
+ @from_list_default'_cps A d y n l T f = f (Tuple.from_list_default' d y n l).
+Proof.
+ induction n; intros; simpl; [reflexivity|].
+ break_match; subst; apply IHn.
+Qed.
+Definition from_list_default_cps {A} (d:A) n (xs:list A) :
+ forall {T}, (Tuple.tuple A n -> T) -> T:=
+ match n as n0 return (forall {T}, (Tuple.tuple A n0 ->T) ->T) with
+ | O => fun T f => f tt
+ | S n' => fun T f =>
+ match xs with
+ | nil => from_list_default'_cps d d n' nil f
+ | x :: xs' => from_list_default'_cps d x n' xs' f
+ end
+ end.
+Lemma from_list_default_cps_correct {A} n : forall d l {T} f,
+ @from_list_default_cps A d n l T f = f (Tuple.from_list_default d n l).
+Proof.
+ destruct n; intros; simpl; [reflexivity|].
+ break_match; auto using from_list_default'_cps_correct.
+Qed.
+Hint Rewrite @from_list_default_cps_correct : uncps.
+Fixpoint to_list'_cps {A} n
+ {T} (f:list A -> T) : Tuple.tuple' A n -> T :=
+ match n as n0 return (Tuple.tuple' A n0 -> T) with
+ | O => fun x => f [x]
+ | S n' => fun (xs: Tuple.tuple' A (S n')) =>
+ let (xs', x) := xs in
+ to_list'_cps n' (fun r => f (x::r)) xs'
+ end.
+Lemma to_list'_cps_correct {A} n: forall t {T} f,
+ @to_list'_cps A n T f t = f (Tuple.to_list' n t).
+Proof.
+ induction n; simpl; intros; [reflexivity|].
+ destruct_head prod. apply IHn.
+Qed.
+Definition to_list_cps' {A} n {T} (f:list A->T)
+ : Tuple.tuple A n -> T :=
+ match n as n0 return (Tuple.tuple A n0 ->T) with
+ | O => fun _ => f nil
+ | S n' => to_list'_cps n' f
+ end.
+Definition to_list_cps {A} n t {T} f :=
+ @to_list_cps' A n T f t.
+Lemma to_list_cps_correct {A} n t {T} f :
+ @to_list_cps A n t T f = f (Tuple.to_list n t).
+Proof. cbv [to_list_cps to_list_cps' Tuple.to_list]; break_match; auto using to_list'_cps_correct. Qed.
+Hint Rewrite @to_list_cps_correct : uncps.
+
+Definition on_tuple_cps {A B} (d:B) (g:list A ->forall {T},(list B->T)->T) {n m}
+ (xs : Tuple.tuple A n) {T} (f:tuple B m ->T) :=
+ to_list_cps n xs (fun r => g r (fun rr => from_list_default_cps d m rr f)).
+Lemma on_tuple_cps_correct {A B} d (g:list A -> forall {T}, (list B->T)->T)
+ {n m} xs {T} f
+ (Hg : forall x {T} h, @g x T h = h (g x id)) : forall H,
+ @on_tuple_cps A B d g n m xs T f = f (@Tuple.on_tuple A B (fun x => g x id) n m H xs).
+Proof.
+ cbv [on_tuple_cps Tuple.on_tuple]; intros.
+ rewrite to_list_cps_correct, Hg, from_list_default_cps_correct.
+ rewrite (Tuple.from_list_default_eq _ _ _ (H _ (Tuple.length_to_list _))).
+ reflexivity.
+Qed. Hint Rewrite @on_tuple_cps_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Fixpoint update_nth_cps {A} n (g:A->A) xs {T} (f:list A->T) :=
+ match n with
+ | O =>
+ match xs with
+ | [] => f []
+ | x' :: xs' => f (g x' :: xs')
+ end
+ | S n' =>
+ match xs with
+ | [] => f []
+ | x' :: xs' => update_nth_cps n' g xs' (fun r => f (x' :: r))
+ end
+ end.
+Lemma update_nth_cps_correct {A} n g: forall xs T f,
+ @update_nth_cps A n g xs T f = f (update_nth n g xs).
+Proof. induction n; intros; simpl; break_match; try apply IHn; reflexivity. Qed.
+Hint Rewrite @update_nth_cps_correct : uncps.
+
+Fixpoint combine_cps {A B} (la :list A) (lb : list B)
+ {T} (f:list (A*B)->T) :=
+ match la with
+ | nil => f nil
+ | a :: tla =>
+ match lb with
+ | nil => f nil
+ | b :: tlb => combine_cps tla tlb (fun lab => f ((a,b)::lab))
+ end
+ end.
+Lemma combine_cps_correct {A B} la: forall lb {T} f,
+ @combine_cps A B la lb T f = f (combine la lb).
+Proof.
+ induction la; simpl combine_cps; simpl combine; intros;
+ try break_match; try apply IHla; reflexivity.
+Qed.
+Hint Rewrite @combine_cps_correct: uncps.
+
+(* differs from fold_right_cps in that the functional argument `g` is also a CPS function *)
+Fixpoint fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) :=
+ match l with
+ | nil => f a0
+ | b :: tl => fold_right_cps2 g a0 tl (fun r => g b r f)
+ end.
+Lemma fold_right_cps2_correct {A B} g a0 l : forall {T} f,
+ (forall b a T h, @g b a T h = h (@g b a A id)) ->
+ @fold_right_cps2 A B g a0 l T f = f (List.fold_right (fun b a => @g b a A id) a0 l).
+Proof.
+ induction l; intros; [reflexivity|].
+ simpl fold_right_cps2. simpl fold_right.
+ rewrite H; erewrite IHl by eassumption.
+ rewrite H; reflexivity.
+Qed.
+Hint Rewrite @fold_right_cps2_correct using (intros; autorewrite with uncps; auto): uncps.
+
+Definition fold_right_no_starter {A} (f:A->A->A) ls : option A :=
+ match ls with
+ | nil => None
+ | cons x tl => Some (List.fold_right f x tl)
+ end.
+Lemma fold_right_min ls x :
+ x = List.fold_right Z.min x ls
+ \/ List.In (List.fold_right Z.min x ls) ls.
+Proof.
+ induction ls; intros; simpl in *; try tauto.
+ match goal with |- context [Z.min ?x ?y] =>
+ destruct (Z.min_spec x y) as [[? Hmin]|[? Hmin]]
+ end; rewrite Hmin; tauto.
+Qed.
+Lemma fold_right_no_starter_min ls : forall x,
+ fold_right_no_starter Z.min ls = Some x ->
+ List.In x ls.
+Proof.
+ cbv [fold_right_no_starter]; intros; destruct ls; try discriminate.
+ inversion H; subst; clear H.
+ destruct (fold_right_min ls z);
+ simpl List.In; tauto.
+Qed.
+Fixpoint fold_right_cps {A B} (g:B->A->A) (a0:A) (l:list B) {T} (f:A->T) :=
+ match l with
+ | nil => f a0
+ | cons a tl => fold_right_cps g a0 tl (fun r => f (g a r))
+ end.
+Lemma fold_right_cps_correct {A B} g a0 l: forall {T} f,
+ @fold_right_cps A B g a0 l T f = f (List.fold_right g a0 l).
+Proof. induction l; intros; simpl; rewrite ?IHl; auto. Qed.
+Hint Rewrite @fold_right_cps_correct : uncps.
+
+Definition fold_right_no_starter_cps {A} g ls {T} (f:option A->T) :=
+ match ls with
+ | nil => f None
+ | cons x tl => f (Some (List.fold_right g x tl))
+ end.
+Lemma fold_right_no_starter_cps_correct {A} g ls {T} f :
+ @fold_right_no_starter_cps A g ls T f = f (fold_right_no_starter g ls).
+Proof.
+ cbv [fold_right_no_starter_cps fold_right_no_starter]; break_match; reflexivity.
+Qed.
+Hint Rewrite @fold_right_no_starter_cps_correct : uncps.