diff options
author | Jason Gross <jgross@mit.edu> | 2018-02-12 21:42:29 -0500 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-02-18 19:25:23 -0500 |
commit | ae4f8c9a53717c3733a29abfb8dfe716c38d21a8 (patch) | |
tree | e04fc47af4630d56c73973ad728822fa1f50d233 /src | |
parent | 7b84df346c6a56b1b51cce73078411767122caa4 (diff) |
WIP on more general continuations
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 545 |
1 files changed, 458 insertions, 87 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index c033b4c88..f9db6517c 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -278,6 +278,44 @@ Module Positional. Section Positional. End Carries. + + Section carry_mulmod. + Context (m:Z) (s:Z) + (c:list (Z*Z)) + (n : nat) + (len_c : nat) + (idxs : list nat) + (len_idxs : nat) + (fg : list Z * list Z). + + Derive carry_mulmod + SuchThat (forall (f := fst fg) (g := snd fg) + (m_nz:m <> 0) (s_nz:s <> 0) (Hm:m = s - Associational.eval c) + (Hf : length f = n) + (Hg : length g = n) + (Hn_nz : n <> 0%nat) + (Hc : length c = len_c) + (Hidxs : length idxs = len_idxs) + (Hw_div_nz : forall i : nat, weight (S i) / weight i <> 0), + (eval n carry_mulmod) mod (s - Associational.eval c) + = (eval n f * eval n g) mod (s - Associational.eval c)) + As carry_mulmod_correct. + Proof. + intros. + erewrite <-eval_mulmod with (s:=s) (c:=c) + by (subst; try assumption; try reflexivity). + etransitivity; + [ | rewrite <- @eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) (modulo:=fun x y => Z.modulo x y) (div:=fun x y => Z.div x y) + by (subst; try assumption; auto using Z.div_mod); reflexivity ]. + eapply f_equal2; [|trivial]. eapply f_equal. + erewrite <- (expand_list_correct _ (-1)%Z f), + <- (expand_list_correct _ (-1)%Z g), + <- (expand_list_correct _ 0%nat idxs), + <- (expand_list_correct _ (-1,-1)%Z c) + by eassumption. + subst carry_mulmod; reflexivity. + Qed. + End carry_mulmod. End Positional. End Positional. Module Compilers. @@ -299,6 +337,24 @@ Module Compilers. | bool => Datatypes.bool end%type. + Fixpoint final_codomain (t : type) : type + := match t with + | type_primitive _ as t + | prod _ _ as t + | list _ as t + => t + | arrow s d => final_codomain d + end. + + Fixpoint under_arrows (t : type) (f : type -> type) : type + := match t with + | type_primitive _ as t + | prod _ _ as t + | list _ as t + => f t + | arrow s d => arrow s (under_arrows d f) + end. + Ltac reify_primitive ty := lazymatch eval cbv beta in ty with | Datatypes.unit => unit @@ -1620,6 +1676,59 @@ Module Compilers. => (untranslate R t -> R)%ctype | type.list A => Compilers.type.list (untranslate R A) end%cpstype. + Fixpoint try_untranslate (t : type) + : option Compilers.type.type + := match t with + | type.type_primitive t => @Some Compilers.type.type t + | type.list A => option_map Compilers.type.list (try_untranslate A) + | A * B + => match try_untranslate A, try_untranslate B with + | Some A, Some B => Some (A * B)%ctype + | _, _ => None + end + | (s * (d --->) --->) + => match try_untranslate s, try_untranslate d with + | Some s, Some d => Some (s -> d)%ctype + | _, _ => None + end + | (_ --->) => None + end%cpstype. + + Fixpoint try_transport_untranslate (P : type -> Type) (t : type) + : P t -> option { t : _ & P (translate t) } + := match t with + | type.type_primitive t + => fun v => Some (existT _ (t : Compilers.type.type) v) + | type.list A + => fun v => option_map + (fun '(existT A v) => existT (fun t => P (translate t)) + (Compilers.type.list A) + v) + (@try_transport_untranslate (fun t => P (type.list t)) A v) + | A * B + => fun v : P (A * B) + => match @try_transport_untranslate (fun a => P (a * _)) A v with + | Some (existT A v) + => match @try_transport_untranslate (fun b => P (_ * b)) B v with + | Some (existT B v) + => Some (existT _ (A * B)%ctype v) + | None => None + end + | None => None + end + | (s * (d --->) --->) + => fun v + => match @try_transport_untranslate (fun s => P ((s * _) --->)) s v with + | Some (existT s v) + => match @try_transport_untranslate (fun d => P (_ * (d --->) --->)) d v with + | Some (existT d v) + => Some (existT _ (s -> d)%ctype v) + | None => None + end + | None => None + end + | (_ --->) => fun _ => None + end%cpstype. End translate. End type. @@ -2103,23 +2212,332 @@ Module Compilers. : @Compilers.Uncurried.expr.default.Expr R := expr.CallWithContinuation (@ident.untranslate _) (@ident.fst) (@ident.snd) (@ident.bool_rect) e k. + Module type_descr. + Import Compilers.type. + Inductive flat_type := + | type_primitive (_ : primitive) + | prod (_ : flat_type) (_ : flat_type) + | list (_ : flat_type). + Inductive argtype := + | flat_arg (_ : flat_type) + | arrow_arg (s : flat_type) (d : argtype) + | prod_arg (_ _ : argtype). + Inductive type := + | flat (_ : flat_type) + | arrow (s : argtype) (d : type). + + Module Export Coercions. + Coercion type_primitive : primitive >-> flat_type. + Coercion flat_arg : flat_type >-> argtype. + Coercion flat : flat_type >-> type. + End Coercions. + + Fixpoint flat_to_type (t : flat_type) : Compilers.type.type + := match t with + | type_primitive x => x + | prod A B => Compilers.type.prod (flat_to_type A) (flat_to_type B) + | list A => Compilers.type.list (flat_to_type A) + end. + + Fixpoint arg_to_type (t : argtype) : Compilers.type.type + := match t with + | flat_arg t => flat_to_type t + | arrow_arg s d => Compilers.type.arrow (flat_to_type s) (arg_to_type d) + | prod_arg A B => Compilers.type.prod (arg_to_type A) (arg_to_type B) + end. + + Fixpoint to_type (t : type) : Compilers.type.type + := match t with + | flat t => flat_to_type t + | arrow s d => Compilers.type.arrow (arg_to_type s) (to_type d) + end. + + Fixpoint flat_of_type (t : Compilers.type.type) : option flat_type + := match t with + | Compilers.type.type_primitive x => @Some flat_type x + | Compilers.type.prod A B + => match flat_of_type A, flat_of_type B with + | Some A, Some B => @Some flat_type (prod A B) + | _, _ => None + end + | type.arrow s d => None + | Compilers.type.list A + => option_map list (flat_of_type A) + end. + + Fixpoint arg_of_type (t : Compilers.type.type) : option argtype + := match t with + | Compilers.type.prod A B as t + => match flat_of_type t, arg_of_type A, arg_of_type B with + | Some t, _, _ + => @Some argtype t + | None, Some A, Some B + => @Some argtype (prod_arg A B) + | _, _, _ => None + end + | type.arrow s d + => match flat_of_type s, arg_of_type d with + | Some s, Some d => Some (arrow_arg s d) + | _, _ => None + end + | Compilers.type.type_primitive _ as t + | Compilers.type.list _ as t + => option_map flat_arg (flat_of_type t) + end. + + Fixpoint of_type (t : Compilers.type.type) : option type + := match t with + | type.arrow s d + => match arg_of_type s, of_type d with + | Some s, Some d => Some (arrow s d) + | _, _ => None + end + | Compilers.type.prod _ _ as t + | Compilers.type.type_primitive _ as t + | Compilers.type.list _ as t + => option_map flat (flat_of_type t) + end. + + Fixpoint try_transport_flat_of_type P (t : Compilers.type.type) + : P t -> option { t' : _ & P (flat_to_type t') } + := match t with + | Compilers.type.type_primitive x + => fun v => Some (existT _ (x : flat_type) v) + | Compilers.type.prod A B + => fun v + => match try_transport_flat_of_type (fun a => P (a * _)%ctype) A v with + | Some (existT A v) + => match try_transport_flat_of_type (fun b => P (_ * b)%ctype) B v with + | Some (existT B v) + => Some (existT _ (prod A B) v) + | None => None + end + | None => None + end + | type.arrow s d => fun _ => None + | Compilers.type.list A + => fun v + => option_map + (fun '(existT A v) => existT (fun t => P (flat_to_type t)) (list A) v) + (try_transport_flat_of_type (fun a => P (Compilers.type.list a)) A v) + end. + + Fixpoint try_transport_arg_of_type P (t : Compilers.type.type) + : P t -> option { t' : _ & P (arg_to_type t') } + := match t with + | Compilers.type.prod A B as t + => fun v + => match try_transport_flat_of_type P t v with + | Some (existT t v) => Some (existT (fun t' => P (arg_to_type t')) t v) + | None + => match try_transport_arg_of_type (fun a => P (a * _)%ctype) A v with + | Some (existT A v) + => match try_transport_arg_of_type (fun b => P (_ * b)%ctype) B v with + | Some (existT B v) + => Some (existT _ (prod_arg A B) v) + | None => None + end + | None => None + end + end + | type.arrow s d + => fun v + => match try_transport_flat_of_type (fun s => P (s -> _)%ctype) s v with + | Some (existT s v) + => match try_transport_flat_of_type (fun d => P (_ -> d)%ctype) d v with + | Some (existT d v) + => Some (existT (fun t' => P (arg_to_type t')) (arrow_arg s d) v) + | None => None + end + | None => None + end + | Compilers.type.type_primitive _ as t + | Compilers.type.list _ as t + => fun v + => option_map + (fun '(existT t v) => existT (fun t => P (arg_to_type t)) (flat_arg t) v) + (try_transport_flat_of_type P t v) + end. + + Fixpoint try_transport_of_type P (t : Compilers.type.type) + : P t -> option { t' : _ & P (to_type t') } + := match t with + | type.arrow s d + => fun v + => match try_transport_arg_of_type (fun s => P (s -> _)%ctype) s v with + | Some (existT s v) + => match try_transport_of_type (fun d => P (_ -> d)%ctype) d v with + | Some (existT d v) + => Some (existT (fun t' => P (to_type t')) (arrow s d) v) + | None => None + end + | None => None + end + | Compilers.type.prod _ _ as t + | Compilers.type.type_primitive _ as t + | Compilers.type.list _ as t + => fun v + => option_map + (fun '(existT t v) => existT (fun t => P (to_type t)) (flat t) v) + (try_transport_flat_of_type P t v) + end. + End type_descr. + Import type_descr.Coercions. + + Fixpoint untranslate_translate_flat + (P : Compilers.type.type -> Type) + {R} + {t : type_descr.flat_type} + (e : P (type_descr.to_type t)) + {struct t} + : P (type.untranslate R (type.translate (type_descr.to_type t))) + := match t return P (type_descr.to_type t) + -> P (type.untranslate R (type.translate (type_descr.to_type t))) + with + | type_descr.type_primitive x => id + | type_descr.prod A B + => fun ab : P (type_descr.flat_to_type A * type_descr.flat_to_type B)%ctype + => @untranslate_translate_flat + (fun a => P (a * _)%ctype) + R A + (@untranslate_translate_flat + (fun b => P (_ * b)%ctype) + R B + ab) + | type_descr.list A + => @untranslate_translate_flat + (fun t => P (Compilers.type.list t)) + R A + end e. + + Fixpoint untranslate_translate_flat' + (P : Compilers.type.type -> Type) + {R} + {t : type_descr.flat_type} + (e : P (type.untranslate R (type.translate (type_descr.to_type t)))) + {struct t} + : P (type_descr.to_type t) + := match t return P (type.untranslate R (type.translate (type_descr.to_type t))) + -> P (type_descr.to_type t) + with + | type_descr.type_primitive x => id + | type_descr.prod A B + => fun ab : + (* ignore this line *) P (type.untranslate R (type.translate (type_descr.flat_to_type A)) * type.untranslate R (type.translate (type_descr.flat_to_type B)))%ctype + => @untranslate_translate_flat' + (fun a => P (a * _)%ctype) + R A + (@untranslate_translate_flat' + (fun b => P (_ * b)%ctype) + R B + ab) + | type_descr.list A + => @untranslate_translate_flat' + (fun t => P (Compilers.type.list t)) + R A + end e. + + Definition transport_final_codomain_flat P {t} + : P (type_descr.flat_to_type t) + -> P (type.final_codomain (type_descr.flat_to_type t)) + := match t with + | type_descr.type_primitive x => id + | type_descr.prod x x0 => id + | type_descr.list x => id + end. + + Definition transport_final_codomain_flat' P {t} + : P (type.final_codomain (type_descr.flat_to_type t)) + -> P (type_descr.flat_to_type t) + := match t with + | type_descr.type_primitive x => id + | type_descr.prod x x0 => id + | type_descr.list x => id + end. + + Fixpoint untranslate_translate_arg + {var} + {R} + {t : type_descr.argtype} + (e : @Compilers.Uncurried.expr.default.expr var (type_descr.arg_to_type t)) + {struct t} + : @Compilers.Uncurried.expr.default.expr var (type.untranslate R (type.translate (type_descr.arg_to_type t))) + := match t return Compilers.Uncurried.expr.default.expr (type_descr.arg_to_type t) + -> Compilers.Uncurried.expr.default.expr (type.untranslate R (type.translate (type_descr.arg_to_type t))) + with + | type_descr.flat_arg t + => untranslate_translate_flat _ + | type_descr.arrow_arg s d + => fun e' + => Abs (fun v : + (* ignore this line *) var (type.untranslate R (type.translate (type_descr.flat_to_type s)) * (type.untranslate R (type.translate (type_descr.arg_to_type d)) -> R))%ctype + => (ident.snd @@ Var v) + @ (@untranslate_translate_arg + var R d + (e' @ (untranslate_translate_flat' _ (ident.fst @@ Var v)))))%expr + | type_descr.prod_arg A B + => fun e' : expr.default.expr (type_descr.arg_to_type A * type_descr.arg_to_type B) + => ((Abs (fun a => Abs (fun b => (Var a, Var b)))) + @ (@untranslate_translate_arg var R A (ident.fst @@ e')) + @ (@untranslate_translate_arg var R B (ident.snd @@ e')))%expr + end e. + + Local Notation "x <-- e1 ; e2" := (expr.splice e1 (fun x => e2%cpsexpr)) : cpsexpr_scope. + + Fixpoint call_fun_with_id_continuation' + {var} + {t : type_descr.type} + (R := type.final_codomain (type_descr.to_type t)) + (e : @expr (fun t0 => + @Uncurried.expr.expr default.ident.ident var (type.untranslate R t0)) + (type.translate (type_descr.to_type t))) + {struct t} + : @Compilers.Uncurried.expr.default.expr var (type_descr.to_type t) + := match t + return (@expr (fun t0 => + @Uncurried.expr.expr default.ident.ident var (type.untranslate (type.final_codomain (type_descr.to_type t)) t0)) + (type.translate (type_descr.to_type t))) + -> @Compilers.Uncurried.expr.default.expr var (type_descr.to_type t) + with + | type_descr.flat t + => fun e' + => transport_final_codomain_flat' + _ + (@call_with_continuation + var _ _ e' + (fun e'' => transport_final_codomain_flat _ (untranslate_translate_flat' _ e''))) + | type_descr.arrow s d + => fun e' : + (* ignore this line *) expr (type.translate (type_descr.arg_to_type s) * (type.translate (type_descr.to_type d) --->) --->) + => Abs (s:=type_descr.arg_to_type s) (d:=type_descr.to_type d) + (fun v + => @call_fun_with_id_continuation' + var d + (f <-- e'; + k <- (λ r, expr.Halt r); + p <- (untranslate_translate_arg (Var v), k); + f @ p)%cpsexpr) + end e. Definition CallFunWithIdContinuation' - {R} - {s d} (e : Expr (type.translate (s -> d))) - (k : forall var, var (type.untranslate R (type.translate d)) -> var R) - : @Compilers.Uncurried.expr.default.Expr (type.untranslate R (type.translate s) -> R) - := fun var - => Abs (fun x => @call_with_continuation - var R _ (e _) - (fun e : expr.default.expr (type.untranslate _ (type.translate s) * (type.untranslate _ (type.translate d) -> _) -> _) - => e @ (Var x, λ v , Var (k _ v)))%expr). - Notation CallFunWithIdContinuation e - := (@CallFunWithIdContinuation' - ((fun s d (e' : Expr (type.translate (s -> d))) => d) _ _ e) - _ _ - e - (fun _ => id)) - (only parsing). + {t : type_descr.type} + (e : Expr (type.translate (type_descr.to_type t))) + : @Compilers.Uncurried.expr.default.Expr (type_descr.to_type t) + := fun var => @call_fun_with_id_continuation' _ t (e _). + + Definition CallFunWithIdContinuation + {t} + (e : Expr t) + := match type.try_transport_untranslate Expr _ e as o return match o with None => _ | _ => _ end with + | Some v + => match type_descr.try_transport_of_type (fun t => Expr (type.translate t)) _ + (projT2 v) + as o return match o with None => _ | _ => _ end + with + | Some v => CallFunWithIdContinuation' (projT2 v) + | None => I + end + | None => I + end. End default. Include default. End CPS. @@ -3348,39 +3766,6 @@ Local Coercion QArith_base.inject_Z : Z >-> Q. - reassociation - indexed + bounds analysis + of phoas *) - -(* TODO: is this the right way to do things? *) -Definition expand_list_helper {A} (default : A) (ls : list A) (n : nat) (idx : nat) : list A - := nat_rect - (fun _ => nat -> list A) - (fun _ => nil) - (fun n' rec_call idx - => cons (List.nth_default default ls idx) (rec_call (S idx))) - n - idx. -Definition expand_list {A} (default : A) (ls : list A) (n : nat) : list A - := expand_list_helper default ls n 0. -Require Import Coq.micromega.Lia. -(* TODO: MOVE ME *) -Lemma expand_list_helper_correct {A} (default : A) (ls : list A) (n idx : nat) (H : (idx + n <= length ls)%nat) - : expand_list_helper default ls n idx - = List.firstn n (List.skipn idx ls). -Proof. - cbv [expand_list_helper]; revert idx H. - induction n as [|n IHn]; cbn; intros. - { reflexivity. } - { rewrite IHn by omega. - erewrite (@skipn_nth_default _ idx ls) by omega. - reflexivity. } -Qed. - -Lemma expand_list_correct (n : nat) {A} (default : A) (ls : list A) (H : List.length ls = n) - : expand_list default ls n = ls. -Proof. - subst; cbv [expand_list]; rewrite expand_list_helper_correct by reflexivity. - rewrite skipn_0, firstn_all; reflexivity. -Qed. - Delimit Scope RT_expr_scope with RT_expr. Import Uncurried. Import expr. @@ -3463,11 +3848,9 @@ Module test3. (z * z)%RT) in pose v as E. vm_compute in E. - pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'. + pose (PartialReduce (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. vm_compute in E'. - pose (PartialReduce E') as E''. - lazy in E''. - lazymatch (eval cbv delta [E''] in E'') with + lazymatch (eval cbv delta [E'] in E') with | (fun var : type -> Type => (λ x : var (type.type_primitive type.Z), expr_let x0 := Var x * Var x in @@ -3480,7 +3863,7 @@ Module test3. Import BoundsAnalysis.ident. Import BoundsAnalysis.Notations. pose (projT2 (Option.invert_Some (BoundsAnalysis.OfPHOAS.AnalyzeBounds - (fun x => Some x) E'' r[0~>10]%zrange))) as E'''. + (fun x => Some x) E' r[0~>10]%zrange))) as E'''. lazy in E'''. lazymatch (eval cbv delta [E'''] in E''') with | (expr_let 2 := mul r[0 ~> 10]%btype r[0 ~> 10]%btype r[0 ~> 100]%btype @@ (x_ 1, x_ 1) in @@ -3579,7 +3962,6 @@ End test6. Axiom admit : forall {T}, T. -(** TODO: split this into [carry_mul_gen] which does not use PHOAS stuff, and version that synthesizes a reified thing *) Derive carry_mul_gen SuchThat (forall (w : nat -> Z) (fg : list Z * list Z) @@ -3600,40 +3982,30 @@ Derive carry_mul_gen (Hsc_nz : s - Associational.eval c <> 0) (Hs_nz : s <> 0) (Hn_nz : n <> 0%nat), - let fg' := carry_mul_gen w fg n s c len_c idxs len_idxs in - (eval w n fg') mod (s - Associational.eval c) - = (eval w n f * eval w n g) mod (s - Associational.eval c)) + (* N.B. type must be a tuple for CPS.CallFunWithIdContinuation to work *) + Interp (t:=((type.nat*type.Z*type.list (type.Z * type.Z)*type.nat*type.list type.nat*type.nat*(type.nat->type.Z))*(type.list type.Z * type.list type.Z)->type.list type.Z)%ctype) + carry_mul_gen ((n, s, c, len_c, idxs, len_idxs, w), fg) + = carry_mulmod w s c n len_c idxs len_idxs fg) As carry_mul_gen_correct. Proof. - intros; subst carry_mul_gen. - erewrite <-eval_mulmod with (s:=s) (c:=c) - by (try assumption; try reflexivity). - (* eval w n (fg' w fg n s c len_c) mod (s - Associational.eval c) = - eval w n (mulmod w s c n f g) mod (s - Associational.eval c) *) - etransitivity; - [ | rewrite <- eval_chained_carries with (s:=s) (c:=c) (idxs:=idxs) (modulo:=fun x y => Z.modulo x y) (div:=fun x y => Z.div x y) - by (try assumption; auto using Z.div_mod); reflexivity ]. - eapply f_equal2; [|trivial]. eapply f_equal. - erewrite <- (expand_list_correct _ (-1)%Z f), - <- (expand_list_correct _ (-1)%Z g), - <- (expand_list_correct _ 0%nat idxs), - <- (expand_list_correct _ (-1,-1)%Z c) - by eassumption. - pose (idxs, len_idxs, n, s, c, len_c, w, fg) as args. - subst f g. - change fg with (snd args). - change w with (snd (fst args)). - change len_c with (snd (fst (fst args))). - change c with (snd (fst (fst (fst args)))). - change s with (snd (fst (fst (fst (fst args))))). - change n with (snd (fst (fst (fst (fst (fst args)))))). - change len_idxs with (snd (fst (fst (fst (fst (fst (fst args))))))). - change idxs with (fst (fst (fst (fst (fst (fst (fst args))))))). - remember args as args' eqn:Hargs. + intros; subst carry_mul_gen; cbv [carry_mulmod]. + clear. + repeat match goal with + | [ |- context[(?x, ?y)] ] + => is_var x; is_var y; + let args := fresh "args" in + let args' := fresh "args'" in + let Hargs := fresh "Hargs" in + set (args := (x, y)); + change x with (fst args); + change y with (snd args); + remember args as args' eqn:Hargs; subst args; + try subst x; try subst y + end. etransitivity. Focus 2. - { subst fg'. - repeat match goal with H : _ |- _ => clear H end; revert args'. + { repeat match goal with H : _ |- _ => clear H end. + repeat match goal with H : _ |- _ => revert H end. lazymatch goal with | [ |- forall args, ?ev = @?RHS args ] => refine (fun args => f_equal (fun F => F args) (_ : _ = RHS)) @@ -3648,8 +4020,7 @@ Proof. let E' := (eval vm_compute in E') in (* 0.131 for vm, about 0.6 for lazy, slower for native and cbv *) pose E' as E''. transitivity (Interp E'' (fst (fst args'), (fun '((i, k) : nat * (Z -> list Z)) => k (w i)), snd args')); [ clear E | exact admit ]. - subst args' args; cbn [fst snd]. - subst fg'. + subst args'; cbn [fst snd]. reflexivity. Qed. |