diff options
author | jadephilipoom <jade.philipoom@gmail.com> | 2017-02-22 18:44:33 -0500 |
---|---|---|
committer | Andres Erbsen <andreser@mit.edu> | 2017-02-22 18:44:33 -0500 |
commit | ce10def144ca9a21c3b1ca4a262b1c94336513e5 (patch) | |
tree | 02d40658aee71f6170032ee360a0fb03fa23974f /src/Util | |
parent | 57a0a97fdbeee2954128d0917d534a7ed8c433cb (diff) |
Merge new base system (#112)
* Added sketch of new low-level base system code
* Implemented and proved addition
* Implemented carrying, which requires defining over Z rather than arbitrary ring
* Proved carry and proved ring-ness of base system ops
* Implemented split operation
* Started implementing modular reduction
* NewBaseSystem: prettify some proofs
* andres base
* improve andresbase
* new base
* first draft of goldilocks karatsuba
* Factored out goldilocks karatsuba
* Implement and prove karatsuba
* goldilocks cleanup
remodularize
* merge karatsuba and goldilocs karatsuba parameter blocks
* carry impl and proofs (not yet synthesis-ready)
* newbasesystem: use rewrite databases
* carry index range fix (TODO: allow for carry-then-reduce)
* simpler carry implementation
* Added compact operation for saturated base systems (this handles carries after multiplying or adding)
* debugging reduction for compact_rows
* rewrote compact
* Converted saturated section to CPS
* some progress on cps conversion for non-saturated stuff
* Converted associational non-saturated code to CPS, temporarily commented out examples
* pushed cps conversion through Positional
* moved list/tuple stuff to top of file
* proved lingering lemma
* worked on generic-style goal for simplified operations
* finished proving the generic-form example goal, revising a couple earlier lemmas
* revised previous lemmas
* finished revising previous lemmas
* removed commented-out code
* fixed non-terminating string in comment
* fix for 8.5
* removed old file
* better automation part 1
* better automation part 2 (goodbye proofs)
* better automation part 3/3
* some work on freeze
* remove saturated code and clean up exported-operations code
* Move helper lemmas for list/tuple CPS stuff to new CPSUtil file
* qualified imports
* fix runtime notations and module-level Let as per comments
* moved push_id to CPSUtil and cancel_pair lemmas to Prod
* fixed typo
* correctly generalized and moved lift_tuple2 (now called lift2_sig) and converted chained_carries into a fold
* moved karatsuba section to new file
* rename lemmas and definitions (now cps definitions are consistently <name>_cps and non-cps equivalents have no suffix)
* updated timing on mulT
* renamed push_eval to push_basesystem_eval
Diffstat (limited to 'src/Util')
-rw-r--r-- | src/Util/CPSUtil.v | 244 | ||||
-rw-r--r-- | src/Util/Prod.v | 4 | ||||
-rw-r--r-- | src/Util/Sigma.v | 9 | ||||
-rw-r--r-- | src/Util/ZUtil.v | 1 |
4 files changed, 257 insertions, 1 deletions
diff --git a/src/Util/CPSUtil.v b/src/Util/CPSUtil.v new file mode 100644 index 000000000..5d2a80399 --- /dev/null +++ b/src/Util/CPSUtil.v @@ -0,0 +1,244 @@ +Require Import Coq.Lists.List. Import ListNotations. +Require Import Coq.ZArith.ZArith Coq.omega.Omega. +Require Import Crypto.Util.ListUtil Crypto.Util.Tactics. +Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple. +Local Open Scope Z_scope. + +Lemma push_id {A} (a:A) : id a = a. reflexivity. Qed. +Create HintDb push_id discriminated. Hint Rewrite @push_id : push_id. + +Lemma update_nth_id {T} i (xs:list T) : ListUtil.update_nth i id xs = xs. +Proof. + revert xs; induction i; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma map_fst_combine {A B} (xs:list A) (ys:list B) : List.map fst (List.combine xs ys) = List.firstn (length ys) xs. +Proof. + revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma map_snd_combine {A B} (xs:list A) (ys:list B) : List.map snd (List.combine xs ys) = List.firstn (length xs) ys. +Proof. + revert xs; induction ys; destruct xs; simpl; solve [ trivial | congruence ]. +Qed. + +Lemma nth_default_seq_inbouns d s n i (H:(i < n)%nat) : + List.nth_default d (List.seq s n) i = (s+i)%nat. +Proof. + progress cbv [List.nth_default]. + rewrite ListUtil.nth_error_seq. + break_innermost_match; solve [ trivial | omega ]. +Qed. + +Lemma mod_add_mul_full a b c k m : m <> 0 -> c mod m = k mod m -> + (a + b * c) mod m = (a + b * k) mod m. +Proof. + intros; rewrite Z.add_mod, Z.mul_mod by auto. + match goal with H : _ mod _ = _ mod _ |- _ => rewrite H end. + rewrite <-Z.mul_mod, <-Z.add_mod by auto; reflexivity. +Qed. + +(* TODO +Lemma to_nat_neg : forall x, x < 0 -> Z.to_nat x = 0%nat. +Proof. destruct x; try reflexivity; intros. pose proof (Pos2Z.is_pos p). omega. Qed. + *) + +Fixpoint map_cps {A B} (g : A->B) ls + {T} (f:list B->T):= + match ls with + | nil => f nil + | a :: t => map_cps g t (fun r => f (g a :: r)) + end. +Lemma map_cps_correct {A B} g ls: forall {T} f, + @map_cps A B g ls T f = f (map g ls). +Proof. induction ls; simpl; intros; rewrite ?IHls; reflexivity. Qed. +Create HintDb uncps discriminated. Hint Rewrite @map_cps_correct : uncps. + +Fixpoint flat_map_cps {A B} (g:A->forall {T}, (list B->T)->T) (ls : list A) {T} (f:list B->T) := + match ls with + | nil => f nil + | (x::tl)%list => g x (fun r => flat_map_cps g tl (fun rr => f (r ++ rr))%list) + end. +Lemma flat_map_cps_correct {A B} (g:A->forall {T}, (list B->T)->T) ls : + forall {T} (f:list B->T), + (forall x T h, @g x T h = h (g x id)) -> + @flat_map_cps A B g ls T f = f (List.flat_map (fun x => g x id) ls). +Proof. + induction ls; intros; [reflexivity|]. + simpl flat_map_cps. simpl flat_map. + rewrite H; erewrite IHls by eassumption. + reflexivity. +Qed. +Hint Rewrite @flat_map_cps_correct using (intros; autorewrite with uncps; auto): uncps. + +Fixpoint from_list_default'_cps {A} (d y:A) n xs: + forall {T}, (Tuple.tuple' A n -> T) -> T:= + match n as n0 return (forall {T}, (Tuple.tuple' A n0 ->T) ->T) with + | O => fun T f => f y + | S n' => fun T f => + match xs with + | nil => from_list_default'_cps d d n' nil (fun r => f (r, y)) + | x :: xs' => from_list_default'_cps d x n' xs' (fun r => f (r, y)) + end + end. +Lemma from_list_default'_cps_correct {A} n : forall d y l {T} f, + @from_list_default'_cps A d y n l T f = f (Tuple.from_list_default' d y n l). +Proof. + induction n; intros; simpl; [reflexivity|]. + break_match; subst; apply IHn. +Qed. +Definition from_list_default_cps {A} (d:A) n (xs:list A) : + forall {T}, (Tuple.tuple A n -> T) -> T:= + match n as n0 return (forall {T}, (Tuple.tuple A n0 ->T) ->T) with + | O => fun T f => f tt + | S n' => fun T f => + match xs with + | nil => from_list_default'_cps d d n' nil f + | x :: xs' => from_list_default'_cps d x n' xs' f + end + end. +Lemma from_list_default_cps_correct {A} n : forall d l {T} f, + @from_list_default_cps A d n l T f = f (Tuple.from_list_default d n l). +Proof. + destruct n; intros; simpl; [reflexivity|]. + break_match; auto using from_list_default'_cps_correct. +Qed. +Hint Rewrite @from_list_default_cps_correct : uncps. +Fixpoint to_list'_cps {A} n + {T} (f:list A -> T) : Tuple.tuple' A n -> T := + match n as n0 return (Tuple.tuple' A n0 -> T) with + | O => fun x => f [x] + | S n' => fun (xs: Tuple.tuple' A (S n')) => + let (xs', x) := xs in + to_list'_cps n' (fun r => f (x::r)) xs' + end. +Lemma to_list'_cps_correct {A} n: forall t {T} f, + @to_list'_cps A n T f t = f (Tuple.to_list' n t). +Proof. + induction n; simpl; intros; [reflexivity|]. + destruct_head prod. apply IHn. +Qed. +Definition to_list_cps' {A} n {T} (f:list A->T) + : Tuple.tuple A n -> T := + match n as n0 return (Tuple.tuple A n0 ->T) with + | O => fun _ => f nil + | S n' => to_list'_cps n' f + end. +Definition to_list_cps {A} n t {T} f := + @to_list_cps' A n T f t. +Lemma to_list_cps_correct {A} n t {T} f : + @to_list_cps A n t T f = f (Tuple.to_list n t). +Proof. cbv [to_list_cps to_list_cps' Tuple.to_list]; break_match; auto using to_list'_cps_correct. Qed. +Hint Rewrite @to_list_cps_correct : uncps. + +Definition on_tuple_cps {A B} (d:B) (g:list A ->forall {T},(list B->T)->T) {n m} + (xs : Tuple.tuple A n) {T} (f:tuple B m ->T) := + to_list_cps n xs (fun r => g r (fun rr => from_list_default_cps d m rr f)). +Lemma on_tuple_cps_correct {A B} d (g:list A -> forall {T}, (list B->T)->T) + {n m} xs {T} f + (Hg : forall x {T} h, @g x T h = h (g x id)) : forall H, + @on_tuple_cps A B d g n m xs T f = f (@Tuple.on_tuple A B (fun x => g x id) n m H xs). +Proof. + cbv [on_tuple_cps Tuple.on_tuple]; intros. + rewrite to_list_cps_correct, Hg, from_list_default_cps_correct. + rewrite (Tuple.from_list_default_eq _ _ _ (H _ (Tuple.length_to_list _))). + reflexivity. +Qed. Hint Rewrite @on_tuple_cps_correct using (intros; autorewrite with uncps; auto): uncps. + +Fixpoint update_nth_cps {A} n (g:A->A) xs {T} (f:list A->T) := + match n with + | O => + match xs with + | [] => f [] + | x' :: xs' => f (g x' :: xs') + end + | S n' => + match xs with + | [] => f [] + | x' :: xs' => update_nth_cps n' g xs' (fun r => f (x' :: r)) + end + end. +Lemma update_nth_cps_correct {A} n g: forall xs T f, + @update_nth_cps A n g xs T f = f (update_nth n g xs). +Proof. induction n; intros; simpl; break_match; try apply IHn; reflexivity. Qed. +Hint Rewrite @update_nth_cps_correct : uncps. + +Fixpoint combine_cps {A B} (la :list A) (lb : list B) + {T} (f:list (A*B)->T) := + match la with + | nil => f nil + | a :: tla => + match lb with + | nil => f nil + | b :: tlb => combine_cps tla tlb (fun lab => f ((a,b)::lab)) + end + end. +Lemma combine_cps_correct {A B} la: forall lb {T} f, + @combine_cps A B la lb T f = f (combine la lb). +Proof. + induction la; simpl combine_cps; simpl combine; intros; + try break_match; try apply IHla; reflexivity. +Qed. +Hint Rewrite @combine_cps_correct: uncps. + +(* differs from fold_right_cps in that the functional argument `g` is also a CPS function *) +Fixpoint fold_right_cps2 {A B} (g : B -> A -> forall {T}, (A->T)->T) (a0 : A) (l : list B) {T} (f : A -> T) := + match l with + | nil => f a0 + | b :: tl => fold_right_cps2 g a0 tl (fun r => g b r f) + end. +Lemma fold_right_cps2_correct {A B} g a0 l : forall {T} f, + (forall b a T h, @g b a T h = h (@g b a A id)) -> + @fold_right_cps2 A B g a0 l T f = f (List.fold_right (fun b a => @g b a A id) a0 l). +Proof. + induction l; intros; [reflexivity|]. + simpl fold_right_cps2. simpl fold_right. + rewrite H; erewrite IHl by eassumption. + rewrite H; reflexivity. +Qed. +Hint Rewrite @fold_right_cps2_correct using (intros; autorewrite with uncps; auto): uncps. + +Definition fold_right_no_starter {A} (f:A->A->A) ls : option A := + match ls with + | nil => None + | cons x tl => Some (List.fold_right f x tl) + end. +Lemma fold_right_min ls x : + x = List.fold_right Z.min x ls + \/ List.In (List.fold_right Z.min x ls) ls. +Proof. + induction ls; intros; simpl in *; try tauto. + match goal with |- context [Z.min ?x ?y] => + destruct (Z.min_spec x y) as [[? Hmin]|[? Hmin]] + end; rewrite Hmin; tauto. +Qed. +Lemma fold_right_no_starter_min ls : forall x, + fold_right_no_starter Z.min ls = Some x -> + List.In x ls. +Proof. + cbv [fold_right_no_starter]; intros; destruct ls; try discriminate. + inversion H; subst; clear H. + destruct (fold_right_min ls z); + simpl List.In; tauto. +Qed. +Fixpoint fold_right_cps {A B} (g:B->A->A) (a0:A) (l:list B) {T} (f:A->T) := + match l with + | nil => f a0 + | cons a tl => fold_right_cps g a0 tl (fun r => f (g a r)) + end. +Lemma fold_right_cps_correct {A B} g a0 l: forall {T} f, + @fold_right_cps A B g a0 l T f = f (List.fold_right g a0 l). +Proof. induction l; intros; simpl; rewrite ?IHl; auto. Qed. +Hint Rewrite @fold_right_cps_correct : uncps. + +Definition fold_right_no_starter_cps {A} g ls {T} (f:option A->T) := + match ls with + | nil => f None + | cons x tl => f (Some (List.fold_right g x tl)) + end. +Lemma fold_right_no_starter_cps_correct {A} g ls {T} f : + @fold_right_no_starter_cps A g ls T f = f (fold_right_no_starter g ls). +Proof. + cbv [fold_right_no_starter_cps fold_right_no_starter]; break_match; reflexivity. +Qed. +Hint Rewrite @fold_right_no_starter_cps_correct : uncps. diff --git a/src/Util/Prod.v b/src/Util/Prod.v index bcd9404a6..6e6c7d3c4 100644 --- a/src/Util/Prod.v +++ b/src/Util/Prod.v @@ -16,6 +16,10 @@ Local Arguments f_equal {_ _} _ {_ _} _. Scheme Equality for prod. +Definition fst_pair {A B} (a:A) (b:B) : fst (a,b) = a := eq_refl. +Definition snd_pair {A B} (a:A) (b:B) : snd (a,b) = b := eq_refl. +Create HintDb cancel_pair discriminated. Hint Rewrite @fst_pair @snd_pair : cancel_pair. + (** ** Equality for [prod] *) Section prod. (** *** Projecting an equality of a pair to equality of the first components *) diff --git a/src/Util/Sigma.v b/src/Util/Sigma.v index 57c82df68..7a1d0cacb 100644 --- a/src/Util/Sigma.v +++ b/src/Util/Sigma.v @@ -16,6 +16,15 @@ Local Arguments f_equal {_ _} _ {_ _} _. (** ** Equality for [sigT] *) Section sigT. + (* Lift foralls out of sigT proofs and leave a sig goal *) + Definition lift2_sig {R S T} f (g:R->S) + (X : forall a b, {prod | g prod = f a b}) : + { op : T -> T -> R & forall a b, g (op a b) = f a b }. + Proof. + exists (fun a b => proj1_sig (X a b)). + exact (fun a b => proj2_sig (X a b)). + Defined. + (** *** Projecting an equality of a pair to equality of the first components *) Definition pr1_path {A} {P : A -> Type} {u v : sigT P} (p : u = v) : projT1 u = projT1 v diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index ee280ca06..4c6e2441d 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -766,7 +766,6 @@ Module Z. apply Z.mod_mul, Z.pow_nonzero; omega. } Qed. - Lemma odd_mod : forall a b, (b <> 0)%Z -> Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a. Proof. |