aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-04-30 16:59:57 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-05-05 18:01:31 -0400
commitdd204a68a92a7394962cd8183b74ffb215b706f7 (patch)
tree5632488227ccdfcb29fce626a05ca40af04f4d6b /src/Experiments
parent850f21f5867b21c5beeb5a23fb54c2622514998d (diff)
WIP on lists as cons cells
Diffstat (limited to 'src/Experiments')
-rw-r--r--src/Experiments/PartialEvaluationWithLetIn.v268
1 files changed, 221 insertions, 47 deletions
diff --git a/src/Experiments/PartialEvaluationWithLetIn.v b/src/Experiments/PartialEvaluationWithLetIn.v
index 5a7b816f6..7fa0641ef 100644
--- a/src/Experiments/PartialEvaluationWithLetIn.v
+++ b/src/Experiments/PartialEvaluationWithLetIn.v
@@ -1,3 +1,4 @@
+Require Import Coq.Lists.List.
Require Import Crypto.Util.Option.
Require Import Crypto.Util.Notations.
@@ -41,7 +42,7 @@ Bind Scope etype_scope with type.type.
Infix "->" := type.arrow : etype_scope.
Module base.
Module type.
- Inductive type := nat | prod (A B : type).
+ Inductive type := nat | prod (A B : type) | list (A : type).
End type.
Notation type := type.type.
End base.
@@ -56,7 +57,8 @@ Module parametric.
(base_subst : base_type_with_var -> base_type)
(base_interp : base_type_with_var -> Type)
(base_subst_interp : base_type -> Type)
- (M : Type -> Type).
+ (M : Type -> Type)
+ (ret : forall T, T -> M T).
Fixpoint subst (t : type base_type_with_var) : type base_type
:= match t with
@@ -74,19 +76,38 @@ Module parametric.
| type.base t => base_interp t
end -> interp d
end.
+
+ Fixpoint interpM_final (t : type base_type_with_var) : Type
+ := match t with
+ | type.base t => M (base_interp t)
+ | type.arrow s d
+ => match s with
+ | type.arrow s' d' => type.interpM M base_subst_interp (subst s)
+ | type.base t => base_interp t
+ end -> interpM_final d
+ end.
+
+ Fixpoint interpM_final_of_interp {t} : interp t -> interpM_final t
+ := match t with
+ | type.base t => ret _
+ | type.arrow s d
+ => fun f x => @interpM_final_of_interp d (f x)
+ end.
End subst.
End type.
Local Notation btype := base.type.type.
Local Notation bnat := base.type.nat.
Local Notation bprod := base.type.prod.
+ Local Notation blist := base.type.list.
Module base.
Module type.
- Inductive type := nat | prod (A B : type) | var_with_subst (subst : btype).
+ Inductive type := nat | prod (A B : type) | list (A : type) | var_with_subst (subst : btype).
Fixpoint subst (t : type) : btype
:= match t with
| nat => bnat
| prod A B => bprod (subst A) (subst B)
+ | list A => blist (subst A)
| var_with_subst s => s
end.
@@ -97,6 +118,7 @@ Module parametric.
:= match t with
| nat => Datatypes.nat
| prod A B => interp A * interp B
+ | list A => Datatypes.list (interp A)
| var_with_subst s => base_interp s
end%type.
End interp.
@@ -107,8 +129,13 @@ Module parametric.
Definition subst (t : type base.type) : type btype
:= type.subst base.type.subst t.
- Definition half_interp (M : Type -> Type) (interp : btype -> Type) (t : type base.type) : Type
- := type.interp base.type.subst (base.type.interp interp) interp M t.
+ Definition half_interp (M : Type -> Type) (half_interp : base.type.type -> Type) (interp : btype -> Type) (t : type base.type) : Type
+ := type.interp base.type.subst half_interp interp M t.
+ Definition half_interp2 (M : Type -> Type) (half_interp : base.type.type -> Type) (interp : btype -> Type) (t : type base.type) : Type
+ := type.interpM_final base.type.subst half_interp interp M t.
+ Definition half_interp2_of_interp {M half_interpf interp t} ret
+ : half_interp M half_interpf interp t -> half_interp2 M half_interpf interp t
+ := type.interpM_final_of_interp _ _ _ _ ret.
End parametric.
Notation ptype := (type.type parametric.base.type).
Delimit Scope ptype_scope with ptype.
@@ -121,6 +148,7 @@ Fixpoint upperboundT (t : base.type) : Type
:= match t with
| base.type.nat => option nat
| base.type.prod A B => upperboundT A * upperboundT B
+ | base.type.list A => option (list (upperboundT A))
end.
Module expr.
@@ -168,6 +196,12 @@ Module ident.
| Pair {A B : base.type} : pident (#A -> #B -> #A * #B)%ptype
| Fst {A B} : pident (#A * #B -> #A)%ptype
| Snd {A B} : pident (#A * #B -> #B)%ptype
+ | Nil {A} : pident (parametric.base.type.list #A)%ptype
+ | Cons {A} : pident (#A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype
+ | List_map {A B} : pident ((#A -> #B) -> parametric.base.type.list #A -> parametric.base.type.list #B)%ptype
+ | List_app {A} : pident (parametric.base.type.list #A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype
+ | List_flat_map {A B} : pident ((#A -> parametric.base.type.list #B) -> parametric.base.type.list #A -> parametric.base.type.list #B)%ptype
+ | List_rect {A P} : pident (#P -> (#A -> parametric.base.type.list #A -> #P -> #P) -> parametric.base.type.list #A -> #P)%ptype
| Cast {T} (upper_bound : upperboundT T) : pident (#T -> #T)%ptype
.
@@ -189,6 +223,12 @@ Module ident.
Notation "# x" := (expr.Ident (wrap x)) (at level 9, x at level 10, format "# x") : expr_scope.
Notation "( x , y , .. , z )" := (expr.App (expr.App (#Pair) .. (expr.App (expr.App (#Pair) x%expr) y%expr) .. ) z%expr) : expr_scope.
Notation "x + y" := (#Plus @ x @ y)%expr : expr_scope.
+ (*Notation "x :: y" := (#Cons @ x @ y)%expr : expr_scope.*)
+ (* Unification fails if we don't fill in [wident pident] explicitly *)
+ Notation "x :: y" := (@expr.App base.type.type (wident pident) _ _ _ (#Cons @ x) y)%expr : expr_scope.
+ Notation "[ ]" := (#Nil)%expr : expr_scope.
+ Notation "[ x ]" := (#Cons @ x @ (#Nil))%expr : expr_scope.
+ Notation "[ x ; y ; .. ; z ]" := (@expr.App base.type.type (wident pident) _ _ _ (#Cons @ x) (#Cons @ y @ .. (#Cons @ z @ #Nil) ..))%expr : expr_scope.
End Notations.
End ident.
Import ident.Notations.
@@ -214,6 +254,13 @@ Module UnderLets.
| UnderLet A x f => UnderLet x (fun v => @splice _ _ (f v) e)
end.
+ Fixpoint splice_list {A B} (ls : list (@UnderLets A)) (e : list A -> @UnderLets B) : @UnderLets B
+ := match ls with
+ | nil => e nil
+ | cons x xs
+ => splice x (fun x => @splice_list A B xs (fun xs => e (cons x xs)))
+ end.
+
Fixpoint to_expr {t} (x : @UnderLets (expr t)) : expr t
:= match x with
| Base v => v
@@ -226,6 +273,9 @@ End UnderLets.
Delimit Scope under_lets_scope with under_lets.
Bind Scope under_lets_scope with UnderLets.UnderLets.
Notation "x <-- y ; f" := (UnderLets.splice y (fun x => f%under_lets)) : under_lets_scope.
+(** FIXME: MOVE ME *)
+Reserved Notation "A <--- X ; B" (at level 70, X at next level, right associativity, format "'[v' A <--- X ; '/' B ']'").
+Notation "x <--- y ; f" := (UnderLets.splice_list y (fun x => f%under_lets)) : under_lets_scope.
Module partial.
Import UnderLets.
@@ -254,7 +304,7 @@ Module partial.
Definition value_with_lets (t : type)
:= UnderLets (value t).
- Context (interp_ident : forall t, ident t -> value t).
+ Context (interp_ident : forall t, ident t -> value_with_lets t).
Definition abstract_domain (t : type)
:= type.interp abstract_domain' t.
@@ -312,7 +362,7 @@ Module partial.
Fixpoint interp {t} (e : @expr value_with_lets t) : value_with_lets t
:= match e in expr.expr t return value_with_lets t with
- | expr.Ident t idc => Base (interp_ident t idc)
+ | expr.Ident t idc => interp_ident t idc
| expr.Var t v => v
| expr.Abs s d f => Base (fun x => @interp d (f (Base x)))
| expr.App s d f x
@@ -370,15 +420,38 @@ Module partial.
barrier in both directions a decent amount *)
(ident_Literal : nat -> pident parametric.base.type.nat)
(ident_Pair : forall A B, pident (#A -> #B -> #A * #B)%ptype)
+ (ident_Nil : forall A, pident (parametric.base.type.list #A)%ptype)
+ (ident_Cons : forall A, pident (#A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype)
+ (ident_List_app : forall A, pident (parametric.base.type.list #A -> parametric.base.type.list #A -> parametric.base.type.list #A)%ptype)
(ident_Fst : forall A B, pident (#A * #B -> #A)%ptype)
- (ident_Snd : forall A B, pident (#A * #B -> #B)%ptype).
+ (ident_Snd : forall A B, pident (#A * #B -> #B)%ptype)
+ (hd_tl_list_state : forall A, abstract_domain' (base.type.list A) -> abstract_domain' A * abstract_domain' (base.type.list A)).
+ Local Notation expr_with_abs A
+ := (prod (abstract_domain' A) (@expr var A)).
+ Local Notation expr_or base_value A
+ := (sum (expr_with_abs A) (base_value A%etype)).
Fixpoint base_value (t : base.type)
:= match t return Type with
| base.type.nat as t
=> nat
| base.type.prod A B as t
- => (abstract_domain' A * @expr var A + base_value A) * (abstract_domain' B * @expr var B + base_value B)
+ => (expr_or base_value A) * (expr_or base_value B)
+ | base.type.list A as t
+ => list (expr_or base_value A) (* cons cells *)
+ end%type.
+ Local Notation value := (@value base.type ident var base_value abstract_domain').
+ Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain').
+ Fixpoint pbase_value (t : parametric.base.type)
+ := match t return Type with
+ | parametric.base.type.nat as t
+ => nat
+ | parametric.base.type.prod A B as t
+ => pbase_value A * pbase_value B
+ | parametric.base.type.list A as t
+ => list (pbase_value A)
+ | parametric.base.type.var_with_subst A as t
+ => value A
end%type.
Fixpoint abstraction_function {t} : base_value t -> abstract_domain' t
@@ -397,6 +470,19 @@ Module partial.
end in
abstract_interp_ident
_ (ident_Pair A B) sta stb
+ | base.type.list A
+ => fun cells
+ => let st_cells
+ := List.map
+ (fun a => match a with
+ | inl (st, _) => st
+ | inr a' => @abstraction_function A a'
+ end)
+ cells in
+ List.fold_right
+ (abstract_interp_ident _ (ident_Cons A))
+ (abstract_interp_ident _ (ident_Nil A))
+ st_cells
end.
Fixpoint base_reify {t} : base_value t -> @expr var t
@@ -422,11 +508,30 @@ Module partial.
| inr v => @base_reify _ v
end in
(#(ident_Pair A B) @ ea @ eb)%expr
+ | base.type.list A
+ => fun cells
+ => let cells'
+ := List.map
+ (fun a
+ => match a with
+ | inl (st, e) (* list chunk *)
+ => match annotate _ st with
+ | None => e
+ | Some cst => ###cst @ e
+ end%expr
+ | inr v
+ => @base_reify _ v
+ end)
+ cells in
+ List.fold_right
+ (fun x xs => (#(ident_Cons A) @ x @ xs)%expr)
+ (#(ident_Nil A))%expr
+ cells'
end.
- Local Notation value := (@value base.type ident var base_value abstract_domain').
-
- Context (half_interp : forall {t} (idc : pident t), parametric.half_interp UnderLets value t).
+ Context (half_interp : forall {t} (idc : pident t),
+ parametric.half_interp UnderLets pbase_value value t
+ + parametric.half_interp2 UnderLets pbase_value value t).
Fixpoint intersect_state_base_value {t} : abstract_domain' t -> base_value t -> base_value t
:= match t return abstract_domain' t -> base_value t -> base_value t with
@@ -444,6 +549,20 @@ Module partial.
| inr v => inr (@intersect_state_base_value _ stb v)
end in
(a', b')
+ | base.type.list _
+ => fun st cells
+ => let '(cells', st)
+ := List.fold_left
+ (fun '(rest_cells, st) cell
+ => let '(st0, st') := hd_tl_list_state _ st in
+ (match cell with
+ | inl (st0', e) => inl (intersect_state _ st0 st0', e)
+ | inr v => inr (@intersect_state_base_value _ st0 v)
+ end :: rest_cells,
+ st'))
+ cells
+ (nil, st) in
+ cells'
end.
@@ -459,21 +578,22 @@ Module partial.
end.
Local Notation reify := (@reify base.type ident var base_value abstract_domain' annotate bottom' (@abstraction_function) (@base_reify)).
- Print reflect.
Local Notation reflect := (@reflect base.type ident var base_value abstract_domain' annotate bottom' (@abstraction_function) (@base_reify)).
- Fixpoint pinterp_base {t : parametric.base.type} : parametric.half_interp UnderLets value (type.base t) -> value (parametric.subst (type.base t))
- := match t return parametric.half_interp UnderLets value (type.base t) -> value (parametric.subst (type.base t)) with
+ Fixpoint pinterp_base {t : parametric.base.type} : parametric.half_interp UnderLets pbase_value value (type.base t) -> value (parametric.subst (type.base t))
+ := match t return parametric.half_interp UnderLets pbase_value value (type.base t) -> value (parametric.subst (type.base t)) with
| parametric.base.type.nat
=> fun v => inr v
| parametric.base.type.prod A B
=> fun '(a, b) => inr (@pinterp_base A a, @pinterp_base B b)
+ | parametric.base.type.list A
+ => fun ls => inr (List.map (@pinterp_base A) ls)
| parametric.base.type.var_with_subst subst
=> fun v => v
end.
- Fixpoint puninterp_base {t : parametric.base.type} : value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets value (type.base t))
- := match t return value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets value (type.base t)) with
+ Fixpoint puninterp_base {t : parametric.base.type} : value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets pbase_value value (type.base t))
+ := match t return value (parametric.subst (type.base t)) -> option (parametric.half_interp UnderLets pbase_value value (type.base t)) with
| parametric.base.type.nat
=> fun v
=> match v with
@@ -489,41 +609,57 @@ Module partial.
b' <- @puninterp_base B b;
Some (a', b'))
end
+ | parametric.base.type.list A
+ => fun ls
+ => match ls with
+ | inl rest => None
+ | inr ls
+ => List.fold_right
+ (fun x xs
+ => (x' <- x; xs' <- xs; Some (x' :: xs'))%option)
+ (Some nil)
+ (List.map (@puninterp_base A) ls)
+ end
| parametric.base.type.var_with_subst subst
=> @Some _
end%option.
- Fixpoint pinterp {t} : UnderLets (value (parametric.subst t)) -> parametric.half_interp UnderLets value t -> value (parametric.subst t)
- := match t return UnderLets (value (parametric.subst t)) -> parametric.half_interp UnderLets value t -> value (parametric.subst t) with
+ Fixpoint pinterp {t} : UnderLets (value (parametric.subst t)) -> parametric.half_interp2 UnderLets pbase_value value t -> value_with_lets (parametric.subst t)
+ := match t return UnderLets (value (parametric.subst t)) -> parametric.half_interp2 UnderLets pbase_value value t -> value_with_lets (parametric.subst t) with
| type.base t
- => fun default partial => pinterp_base partial
+ => fun default partial => (partial' <-- partial;
+ Base (pinterp_base partial'))
| type.arrow (type.base s) d
- => fun fdefault fpartial (v : value (parametric.subst (type.base s)))
- => let default := (fdefault' <-- fdefault; fdefault' v) in
- match puninterp_base v return UnderLets (value (parametric.subst d)) with
- | Some v' => Base (@pinterp d default (fpartial v'))
- | None => default
- end
+ => fun fdefault fpartial
+ => Base
+ (fun (v : value (parametric.subst (type.base s)))
+ => let default := (fdefault' <-- fdefault; fdefault' v) in
+ match puninterp_base v return UnderLets (value (parametric.subst d)) with
+ | Some v' => @pinterp d default (fpartial v')
+ | None => default
+ end)
| type.arrow s d
- => fun fdefault fpartial (v : value (parametric.subst s))
+ => fun fdefault fpartial
=> Base
- (@pinterp
- d (fdefault' <-- fdefault; fdefault' v)
- (fpartial v))
+ (fun (v : value (parametric.subst s))
+ => @pinterp
+ d (fdefault' <-- fdefault; fdefault' v)
+ (fpartial v))
end%under_lets.
Local Notation bottom := (@bottom base.type abstract_domain' bottom').
- Definition interp {t} (idc : ident t) : value t
- := match idc in ident.wident _ t return value t with
+ Definition interp {t} (idc : ident t) : value_with_lets t
+ := match idc in ident.wident _ t return value_with_lets t with
| ident.wrap T idc' as idc
=> pinterp
(Base (reflect (###idc) (abstract_interp_ident _ idc')))%expr
- (half_interp _ idc')
+ match half_interp _ idc' with
+ | inl interp_idc => parametric.half_interp2_of_interp (fun T => @Base _ ident var T) interp_idc
+ | inr interp2_idc => interp2_idc
+ end
end.
- Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain').
-
Definition eval_with_bound {t} (e : @expr value_with_lets t)
(st : type.for_each_lhs_of_arrow abstract_domain t)
: expr t
@@ -549,22 +685,60 @@ Module partial.
(intersect_state : forall A, abstract_domain' A -> abstract_domain' A -> abstract_domain' A)
(update_literal_with_state : abstract_domain' base.type.nat -> nat -> nat)
(state_of_upperbound : forall T, upperboundT T -> abstract_domain' T)
- (bottom' : forall A, abstract_domain' A).
+ (bottom' : forall A, abstract_domain' A)
+ (hd_tl_list_state : forall A, abstract_domain' (base.type.list A) -> abstract_domain' A * abstract_domain' (base.type.list A)).
Local Notation base_value := (@wident.base_value var pident abstract_domain').
+ Local Notation pbase_value := (@wident.pbase_value var pident abstract_domain').
Local Notation value := (@value base.type ident var base_value abstract_domain').
- Local Notation intersect_state_value := (@wident.intersect_state_value var pident abstract_domain' abstract_interp_ident intersect_state update_literal_with_state (@ident.Fst) (@ident.Snd)).
-
- Definition half_interp {t} (idc : pident t) : parametric.half_interp UnderLets value t
- := match idc in ident.pident t return parametric.half_interp UnderLets value t with
- | ident.Literal v => v
- | ident.Plus => Nat.add
- | ident.Pair A B => @pair _ _
- | ident.Fst A B => @fst _ _
- | ident.Snd A B => @snd _ _
+ Local Notation intersect_state_value := (@wident.intersect_state_value var pident abstract_domain' abstract_interp_ident intersect_state update_literal_with_state (@ident.Fst) (@ident.Snd) (@hd_tl_list_state)).
+
+ Definition half_interp {t} (idc : pident t)
+ : parametric.half_interp UnderLets pbase_value value t
+ + parametric.half_interp2 UnderLets pbase_value value t.
+ refine match idc in ident.pident t return parametric.half_interp UnderLets pbase_value value t + parametric.half_interp2 UnderLets pbase_value value t with
+ | ident.Literal v => inl v
+ | ident.Plus => inl Nat.add
+ | ident.Pair A B => inl (@pair _ _)
+ | ident.Fst A B => inl (@fst _ _)
+ | ident.Snd A B => inl (@snd _ _)
+ | ident.Nil _ => inl (@nil _)
+ | ident.Cons _ => inl (@cons _)
+ | ident.List_app _ => inl (@List.app _)
+ | ident.List_map _ _
+ => inr (fun f ls => fls <--- List.map f ls; Base fls)
+ | ident.List_flat_map A B
+ => inr (fun f ls
+ => list_rect
+ _
+ (Base nil)
+ (fun x _ flat_map_xs
+ => (fx <-- f x;
+ flat_map_xs' <-- flat_map_xs;
+ _))
+ ls
+ (*
+ (_ nil)
+ (fun x _ flat_map_xs => _(*f x ++ flat_map_xs*))
+ ls*))
+ | ident.List_rect A P
+ => inr
+ (fun N_case C_case ls
+ => list_rect
+ _
+ (Base N_case)
+ (fun x xs rest
+ => (rest' <-- rest;
+ C_case <-- C_case x;
+ C_case <-- C_case (inr xs);
+ C_case rest'))
+ ls)
| ident.Cast T upper_bound as idc
- => intersect_state_value (t:=T) (state_of_upperbound _ upper_bound)
- end.
+ => inl (intersect_state_value (t:=T) (state_of_upperbound _ upper_bound))
+ end%under_lets.
+ cbn in *.
+ Print flat_map.
+
Local Notation value_with_lets := (@value_with_lets base.type ident var base_value abstract_domain').