diff options
author | Jason Gross <jgross@mit.edu> | 2018-04-30 16:59:57 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-05-05 18:01:31 -0400 |
commit | dd204a68a92a7394962cd8183b74ffb215b706f7 (patch) | |
tree | 5632488227ccdfcb29fce626a05ca40af04f4d6b /src/Experiments | |
parent | 850f21f5867b21c5beeb5a23fb54c2622514998d (diff) |
WIP on lists as cons cells
Diffstat (limited to 'src/Experiments')
-rw-r--r-- | src/Experiments/PartialEvaluationWithLetIn.v | 268 |
1 files changed, 221 insertions, 47 deletions
diff --git a/src/Experiments/PartialEvaluationWithLetIn.v b/src/Experiments/PartialEvaluationWithLetIn.v index 5a7b816f6..7fa0641ef 100644 --- a/src/Experiments/PartialEvaluationWithLetIn.v +++ b/src/Experiments/PartialEvaluationWithLetIn.v @@ -1,3 +1,4 @@ +Require Import Coq.Lists.List. Require Import Crypto.Util.Option. Require Import Crypto.Util.Notations. @@ -41,7 +42,7 @@ Bind Scope etype_scope with type.type. Infix "->" := type.arrow : etype_scope. Module base. Module type. - Inductive type := nat | prod (A B : type). + Inductive type := nat | prod (A B : type) | list (A : type). End type. Notation type := type.type. End base. @@ -56,7 +57,8 @@ Module parametric. (base_subst : base_type_with_var -> base_type) (base_interp : base_type_with_var -> Type) (base_subst_interp : base_type -> Type) - (M : Type -> Type). + (M : Type -> Type) + (ret : forall T, T -> M T). Fixpoint subst (t : type base_type_with_var) : type base_type := match t with @@ -74,19 +76,38 @@ Module parametric. | type.base t => base_interp t end -> interp d end. + + Fixpoint interpM_final (t : type base_type_with_var) : Type + := match t with + | type.base t => M (base_interp t) + | type.arrow s d + => match s with + | type.arrow s' d' => type.interpM M base_subst_interp (subst s) + | type.base t => base_interp t + end -> interpM_final d + end. + + Fixpoint interpM_final_of_interp {t} : interp t -> interpM_final t + := match t with + | type.base t => ret _ + | type.arrow s d + => fun f x => @interpM_final_of_interp d (f x) + end. End subst. End type. Local Notation btype := base.type.type. Local Notation bnat := base.type.nat. Local Notation bprod := base.type.prod. + Local Notation blist := base.type.list. Module base. Module type. - Inductive type := nat | prod (A B : type) | var_with_subst (subst : btype). + Inductive type := nat | prod (A B : type) | list (A : type) | var_with_subst (subst : btype). Fixpoint subst (t : type) : btype := match t with | nat => bnat | prod A B => bprod (subst A) (subst B) + | list A => blist (subst A) | var_with_subst s => s end. @@ -97,6 +118,7 @@ Module parametric. := match t with | nat => Datatypes.nat | prod A B => interp A * interp B + | list A => Datatypes.list (interp A) | var_with_subst s => base_interp s end%type. End interp. @@ -107,8 +129,13 @@ Module parametric. Definition subst (t : type base.type) : type btype := type.subst base.type.subst t. - Definition half_interp (M : Type -> Type) (interp : btype -> Type) (t : type base.type) : Type - := type.interp base.type.subst (base.type.interp interp) interp M t. + Definition half_interp (M : Type -> Type) (half_interp : base.type.type -> Type) (interp : btype -> Type) (t : type base.type) : Type + := type.interp base.type.subst half_interp interp M t. + Definition half_interp2 (M : Type -> Type) (half_interp : base.type.type -> Type) (interp : btype -> Type) (t : type base.type) : Type + := type.interpM_final base.type.subst half_interp interp M t. + Definition half_interp2_of_interp {M half_interpf interp t} ret + : half_interp M half_interpf interp t -> half_interp2 M half_interpf interp t + := type.interpM_final_of_interp _ _ _ _ ret. End parametric. Notation ptype := (type.type parametric.base.type). Delimit Scope ptype_scope with ptype. @@ -121,6 +148,7 @@ Fixpoint upperboundT (t : base.type) : Type := match t with | base.type.nat => option nat | base.type.prod A B => upperboundT A * upperboundT B + | base.type.list A => option (list (upperboundT A)) end. Module expr. @@ -168,6 +196,12 @@ Module ident. | Pair {A B : base.type} : pident (#A -> #B -> #A * #B)%ptype | Fst {A B} : pident (#A * #B -> #A)%ptype | Snd {A B} : pident (#A * #B -> #B)%ptype + | Nil {A} : pident (parametric.base.type.list #A)%ptype + | Cons {A} : pident (#A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype + | List_map {A B} : pident ((#A -> #B) -> parametric.base.type.list #A -> parametric.base.type.list #B)%ptype + | List_app {A} : pident (parametric.base.type.list #A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype + | List_flat_map {A B} : pident ((#A -> parametric.base.type.list #B) -> parametric.base.type.list #A -> parametric.base.type.list #B)%ptype + | List_rect {A P} : pident (#P -> (#A -> parametric.base.type.list #A -> #P -> #P) -> parametric.base.type.list #A -> #P)%ptype | Cast {T} (upper_bound : upperboundT T) : pident (#T -> #T)%ptype . @@ -189,6 +223,12 @@ Module ident. Notation "# x" := (expr.Ident (wrap x)) (at level 9, x at level 10, format "# x") : expr_scope. Notation "( x , y , .. , z )" := (expr.App (expr.App (#Pair) .. (expr.App (expr.App (#Pair) x%expr) y%expr) .. ) z%expr) : expr_scope. Notation "x + y" := (#Plus @ x @ y)%expr : expr_scope. + (*Notation "x :: y" := (#Cons @ x @ y)%expr : expr_scope.*) + (* Unification fails if we don't fill in [wident pident] explicitly *) + Notation "x :: y" := (@expr.App base.type.type (wident pident) _ _ _ (#Cons @ x) y)%expr : expr_scope. + Notation "[ ]" := (#Nil)%expr : expr_scope. + Notation "[ x ]" := (#Cons @ x @ (#Nil))%expr : expr_scope. + Notation "[ x ; y ; .. ; z ]" := (@expr.App base.type.type (wident pident) _ _ _ (#Cons @ x) (#Cons @ y @ .. (#Cons @ z @ #Nil) ..))%expr : expr_scope. End Notations. End ident. Import ident.Notations. @@ -214,6 +254,13 @@ Module UnderLets. | UnderLet A x f => UnderLet x (fun v => @splice _ _ (f v) e) end. + Fixpoint splice_list {A B} (ls : list (@UnderLets A)) (e : list A -> @UnderLets B) : @UnderLets B + := match ls with + | nil => e nil + | cons x xs + => splice x (fun x => @splice_list A B xs (fun xs => e (cons x xs))) + end. + Fixpoint to_expr {t} (x : @UnderLets (expr t)) : expr t := match x with | Base v => v @@ -226,6 +273,9 @@ End UnderLets. Delimit Scope under_lets_scope with under_lets. Bind Scope under_lets_scope with UnderLets.UnderLets. Notation "x <-- y ; f" := (UnderLets.splice y (fun x => f%under_lets)) : under_lets_scope. +(** FIXME: MOVE ME *) +Reserved Notation "A <--- X ; B" (at level 70, X at next level, right associativity, format "'[v' A <--- X ; '/' B ']'"). +Notation "x <--- y ; f" := (UnderLets.splice_list y (fun x => f%under_lets)) : under_lets_scope. Module partial. Import UnderLets. @@ -254,7 +304,7 @@ Module partial. Definition value_with_lets (t : type) := UnderLets (value t). - Context (interp_ident : forall t, ident t -> value t). + Context (interp_ident : forall t, ident t -> value_with_lets t). Definition abstract_domain (t : type) := type.interp abstract_domain' t. @@ -312,7 +362,7 @@ Module partial. Fixpoint interp {t} (e : @expr value_with_lets t) : value_with_lets t := match e in expr.expr t return value_with_lets t with - | expr.Ident t idc => Base (interp_ident t idc) + | expr.Ident t idc => interp_ident t idc | expr.Var t v => v | expr.Abs s d f => Base (fun x => @interp d (f (Base x))) | expr.App s d f x @@ -370,15 +420,38 @@ Module partial. barrier in both directions a decent amount *) (ident_Literal : nat -> pident parametric.base.type.nat) (ident_Pair : forall A B, pident (#A -> #B -> #A * #B)%ptype) + (ident_Nil : forall A, pident (parametric.base.type.list #A)%ptype) + (ident_Cons : forall A, pident (#A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype) + (ident_List_app : forall A, pident (parametric.base.type.list #A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype) (ident_Fst : forall A B, pident (#A * #B -> #A)%ptype) - (ident_Snd : forall A B, pident (#A * #B -> #B)%ptype). + (ident_Snd : forall A B, pident (#A * #B -> #B)%ptype) + (hd_tl_list_state : forall A, abstract_domain' (base.type.list A) -> abstract_domain' A * abstract_domain' (base.type.list A)). + Local Notation expr_with_abs A + := (prod (abstract_domain' A) (@expr var A)). + Local Notation expr_or base_value A + := (sum (expr_with_abs A) (base_value A%etype)). Fixpoint base_value (t : base.type) := match t return Type with | base.type.nat as t => nat | base.type.prod A B as t - => (abstract_domain' A * @expr var A + base_value A) * (abstract_domain' B * @expr var B + base_value B) + => (expr_or base_value A) * (expr_or base_value B) + | base.type.list A as t + => list (expr_or base_value A) (* cons cells *) + end%type. + Local Notation value := (@value base.type ident var base_value abstract_domain'). + Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain'). + Fixpoint pbase_value (t : parametric.base.type) + := match t return Type with + | parametric.base.type.nat as t + => nat + | parametric.base.type.prod A B as t + => pbase_value A * pbase_value B + | parametric.base.type.list A as t + => list (pbase_value A) + | parametric.base.type.var_with_subst A as t + => value A end%type. Fixpoint abstraction_function {t} : base_value t -> abstract_domain' t @@ -397,6 +470,19 @@ Module partial. end in abstract_interp_ident _ (ident_Pair A B) sta stb + | base.type.list A + => fun cells + => let st_cells + := List.map + (fun a => match a with + | inl (st, _) => st + | inr a' => @abstraction_function A a' + end) + cells in + List.fold_right + (abstract_interp_ident _ (ident_Cons A)) + (abstract_interp_ident _ (ident_Nil A)) + st_cells end. Fixpoint base_reify {t} : base_value t -> @expr var t @@ -422,11 +508,30 @@ Module partial. | inr v => @base_reify _ v end in (#(ident_Pair A B) @ ea @ eb)%expr + | base.type.list A + => fun cells + => let cells' + := List.map + (fun a + => match a with + | inl (st, e) (* list chunk *) + => match annotate _ st with + | None => e + | Some cst => ###cst @ e + end%expr + | inr v + => @base_reify _ v + end) + cells in + List.fold_right + (fun x xs => (#(ident_Cons A) @ x @ xs)%expr) + (#(ident_Nil A))%expr + cells' end. - Local Notation value := (@value base.type ident var base_value abstract_domain'). - - Context (half_interp : forall {t} (idc : pident t), parametric.half_interp UnderLets value t). + Context (half_interp : forall {t} (idc : pident t), + parametric.half_interp UnderLets pbase_value value t + + parametric.half_interp2 UnderLets pbase_value value t). Fixpoint intersect_state_base_value {t} : abstract_domain' t -> base_value t -> base_value t := match t return abstract_domain' t -> base_value t -> base_value t with @@ -444,6 +549,20 @@ Module partial. | inr v => inr (@intersect_state_base_value _ stb v) end in (a', b') + | base.type.list _ + => fun st cells + => let '(cells', st) + := List.fold_left + (fun '(rest_cells, st) cell + => let '(st0, st') := hd_tl_list_state _ st in + (match cell with + | inl (st0', e) => inl (intersect_state _ st0 st0', e) + | inr v => inr (@intersect_state_base_value _ st0 v) + end :: rest_cells, + st')) + cells + (nil, st) in + cells' end. @@ -459,21 +578,22 @@ Module partial. end. Local Notation reify := (@reify base.type ident var base_value abstract_domain' annotate bottom' (@abstraction_function) (@base_reify)). - Print reflect. Local Notation reflect := (@reflect base.type ident var base_value abstract_domain' annotate bottom' (@abstraction_function) (@base_reify)). - Fixpoint pinterp_base {t : parametric.base.type} : parametric.half_interp UnderLets value (type.base t) -> value (parametric.subst (type.base t)) - := match t return parametric.half_interp UnderLets value (type.base t) -> value (parametric.subst (type.base t)) with + Fixpoint pinterp_base {t : parametric.base.type} : parametric.half_interp UnderLets pbase_value value (type.base t) -> value (parametric.subst (type.base t)) + := match t return parametric.half_interp UnderLets pbase_value value (type.base t) -> value (parametric.subst (type.base t)) with | parametric.base.type.nat => fun v => inr v | parametric.base.type.prod A B => fun '(a, b) => inr (@pinterp_base A a, @pinterp_base B b) + | parametric.base.type.list A + => fun ls => inr (List.map (@pinterp_base A) ls) | parametric.base.type.var_with_subst subst => fun v => v end. - Fixpoint puninterp_base {t : parametric.base.type} : value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets value (type.base t)) - := match t return value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets value (type.base t)) with + Fixpoint puninterp_base {t : parametric.base.type} : value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets pbase_value value (type.base t)) + := match t return value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets pbase_value value (type.base t)) with | parametric.base.type.nat => fun v => match v with @@ -489,41 +609,57 @@ Module partial. b' <- @puninterp_base B b; Some (a', b')) end + | parametric.base.type.list A + => fun ls + => match ls with + | inl rest => None + | inr ls + => List.fold_right + (fun x xs + => (x' <- x; xs' <- xs; Some (x' :: xs'))%option) + (Some nil) + (List.map (@puninterp_base A) ls) + end | parametric.base.type.var_with_subst subst => @Some _ end%option. - Fixpoint pinterp {t} : UnderLets (value (parametric.subst t)) -> parametric.half_interp UnderLets value t -> value (parametric.subst t) - := match t return UnderLets (value (parametric.subst t)) -> parametric.half_interp UnderLets value t -> value (parametric.subst t) with + Fixpoint pinterp {t} : UnderLets (value (parametric.subst t)) -> parametric.half_interp2 UnderLets pbase_value value t -> value_with_lets (parametric.subst t) + := match t return UnderLets (value (parametric.subst t)) -> parametric.half_interp2 UnderLets pbase_value value t -> value_with_lets (parametric.subst t) with | type.base t - => fun default partial => pinterp_base partial + => fun default partial => (partial' <-- partial; + Base (pinterp_base partial')) | type.arrow (type.base s) d - => fun fdefault fpartial (v : value (parametric.subst (type.base s))) - => let default := (fdefault' <-- fdefault; fdefault' v) in - match puninterp_base v return UnderLets (value (parametric.subst d)) with - | Some v' => Base (@pinterp d default (fpartial v')) - | None => default - end + => fun fdefault fpartial + => Base + (fun (v : value (parametric.subst (type.base s))) + => let default := (fdefault' <-- fdefault; fdefault' v) in + match puninterp_base v return UnderLets (value (parametric.subst d)) with + | Some v' => @pinterp d default (fpartial v') + | None => default + end) | type.arrow s d - => fun fdefault fpartial (v : value (parametric.subst s)) + => fun fdefault fpartial => Base - (@pinterp - d (fdefault' <-- fdefault; fdefault' v) - (fpartial v)) + (fun (v : value (parametric.subst s)) + => @pinterp + d (fdefault' <-- fdefault; fdefault' v) + (fpartial v)) end%under_lets. Local Notation bottom := (@bottom base.type abstract_domain' bottom'). - Definition interp {t} (idc : ident t) : value t - := match idc in ident.wident _ t return value t with + Definition interp {t} (idc : ident t) : value_with_lets t + := match idc in ident.wident _ t return value_with_lets t with | ident.wrap T idc' as idc => pinterp (Base (reflect (###idc) (abstract_interp_ident _ idc')))%expr - (half_interp _ idc') + match half_interp _ idc' with + | inl interp_idc => parametric.half_interp2_of_interp (fun T => @Base _ ident var T) interp_idc + | inr interp2_idc => interp2_idc + end end. - Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain'). - Definition eval_with_bound {t} (e : @expr value_with_lets t) (st : type.for_each_lhs_of_arrow abstract_domain t) : expr t @@ -549,22 +685,60 @@ Module partial. (intersect_state : forall A, abstract_domain' A -> abstract_domain' A -> abstract_domain' A) (update_literal_with_state : abstract_domain' base.type.nat -> nat -> nat) (state_of_upperbound : forall T, upperboundT T -> abstract_domain' T) - (bottom' : forall A, abstract_domain' A). + (bottom' : forall A, abstract_domain' A) + (hd_tl_list_state : forall A, abstract_domain' (base.type.list A) -> abstract_domain' A * abstract_domain' (base.type.list A)). Local Notation base_value := (@wident.base_value var pident abstract_domain'). + Local Notation pbase_value := (@wident.pbase_value var pident abstract_domain'). Local Notation value := (@value base.type ident var base_value abstract_domain'). - Local Notation intersect_state_value := (@wident.intersect_state_value var pident abstract_domain' abstract_interp_ident intersect_state update_literal_with_state (@ident.Fst) (@ident.Snd)). - - Definition half_interp {t} (idc : pident t) : parametric.half_interp UnderLets value t - := match idc in ident.pident t return parametric.half_interp UnderLets value t with - | ident.Literal v => v - | ident.Plus => Nat.add - | ident.Pair A B => @pair _ _ - | ident.Fst A B => @fst _ _ - | ident.Snd A B => @snd _ _ + Local Notation intersect_state_value := (@wident.intersect_state_value var pident abstract_domain' abstract_interp_ident intersect_state update_literal_with_state (@ident.Fst) (@ident.Snd) (@hd_tl_list_state)). + + Definition half_interp {t} (idc : pident t) + : parametric.half_interp UnderLets pbase_value value t + + parametric.half_interp2 UnderLets pbase_value value t. + refine match idc in ident.pident t return parametric.half_interp UnderLets pbase_value value t + parametric.half_interp2 UnderLets pbase_value value t with + | ident.Literal v => inl v + | ident.Plus => inl Nat.add + | ident.Pair A B => inl (@pair _ _) + | ident.Fst A B => inl (@fst _ _) + | ident.Snd A B => inl (@snd _ _) + | ident.Nil _ => inl (@nil _) + | ident.Cons _ => inl (@cons _) + | ident.List_app _ => inl (@List.app _) + | ident.List_map _ _ + => inr (fun f ls => fls <--- List.map f ls; Base fls) + | ident.List_flat_map A B + => inr (fun f ls + => list_rect + _ + (Base nil) + (fun x _ flat_map_xs + => (fx <-- f x; + flat_map_xs' <-- flat_map_xs; + _)) + ls + (* + (_ nil) + (fun x _ flat_map_xs => _(*f x ++ flat_map_xs*)) + ls*)) + | ident.List_rect A P + => inr + (fun N_case C_case ls + => list_rect + _ + (Base N_case) + (fun x xs rest + => (rest' <-- rest; + C_case <-- C_case x; + C_case <-- C_case (inr xs); + C_case rest')) + ls) | ident.Cast T upper_bound as idc - => intersect_state_value (t:=T) (state_of_upperbound _ upper_bound) - end. + => inl (intersect_state_value (t:=T) (state_of_upperbound _ upper_bound)) + end%under_lets. + cbn in *. + Print flat_map. + Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain'). |