aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-03-31 21:05:10 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-04-04 15:39:34 -0400
commitc900290d3297ade2cc2e73fe6b322abe52d1715a (patch)
tree371e4c692170ae5fbccc4e52449a08040512e0ce /src
parentd97e060a3f8de0b83db89aa6c25eb4157045c275 (diff)
Add Uncurry
This pass uncurries all applied lambdas. Care is taken to not do beta reduction and to not introduce spurious `Abs` or `App` nodes.
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v231
1 files changed, 231 insertions, 0 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 6ee543a3a..202e21098 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -3228,6 +3228,12 @@ Module Compilers.
| _ => None
end.
+ Definition invert_or_expand_Pair {A B} (e : @expr var (type.prod A B)) : @expr var A * @expr var B
+ := match invert_Pair e with
+ | Some p => p
+ | None => (ident.fst @@ e, ident.snd @@ e)
+ end%core%expr.
+
(* if we want more code for the below, I would suggest [reify_base_type] and [reflect_base_type] *)
Definition reify_primitive {t} (v : type.interp (type.type_primitive t)) : @expr var (type.type_primitive t)
:= AppIdent (ident.primitive v) TT.
@@ -3368,6 +3374,189 @@ Module Compilers.
:= (Reify_as (type.reify_type_of v) (fun _ => v)) (only parsing).
End GallinaReify.
+ Module Uncurry.
+ Module type.
+ Fixpoint uncurried_domain (t : type) : type
+ := match t with
+ | type.arrow s d
+ => match d with
+ | type.arrow _ _
+ => s * uncurried_domain d
+ | _ => s
+ end
+ | _ => type.type_primitive type.unit
+ end%ctype.
+
+ Definition uncurry (t : type) : type
+ := type.arrow (uncurried_domain t) (type.final_codomain t).
+
+ Section with_var.
+ Context (var : type -> Type).
+ Fixpoint value (t : type)
+ := match t with
+ | type.arrow s d
+ => value s -> value d
+ | t => @expr var t
+ end.
+ End with_var.
+ End type.
+
+ Fixpoint app_curried {t : type}
+ : type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t)
+ := match t return type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t) with
+ | type.arrow s d
+ => match d
+ return (type.interp d -> type.interp (type.uncurried_domain d) -> type.interp (type.final_codomain d))
+ -> type.interp (type.arrow s d)
+ -> type.interp (type.uncurried_domain (type.arrow s d))
+ -> type.interp (type.final_codomain d)
+ with
+ | type.arrow _ _ as d
+ => fun app_curried_d
+ (f : type.interp s -> type.interp d)
+ (x : type.interp s * type.interp (type.uncurried_domain d))
+ => app_curried_d (f (fst x)) (snd x)
+ | d
+ => fun _
+ (f : type.interp s -> type.interp d)
+ (x : type.interp s)
+ => f x
+ end (@app_curried d)
+ | _ => fun f _ => f
+ end.
+
+ Module expr.
+ Section reify_reflect.
+ Context {var : type -> Type}.
+ Fixpoint reify {t}
+ : type.value var t -> @expr var t
+ := match t with
+ | type.arrow s d
+ => fun f => Abs (fun v => @reify d (f (@reflect s (Var v))))
+ | _
+ => fun e => e
+ end%expr
+ with reflect {t}
+ : @expr var t -> type.value var t
+ := match t with
+ | type.arrow s d
+ => fun e (v : type.value _ s) => @reflect d (e @ (@reify s v))
+ | _
+ => fun e => e
+ end%expr.
+ End reify_reflect.
+
+ Section with_var.
+ Context {var : type -> Type}.
+
+ Definition reassociate_uncurried_domain_r_to_l {s s' d'}
+ : @expr var (type.uncurried_domain (s -> s' -> d'))
+ -> @expr var (type.uncurried_domain (s * s' -> d'))
+ := match d'
+ return (expr (type.uncurried_domain (s -> s' -> d'))
+ -> expr (type.uncurried_domain (s * s' -> d')))
+ with
+ | type.arrow _ _ as d'
+ => fun (e : expr (s * (s' * type.uncurried_domain d')))
+ => let '(e, e') := invert_or_expand_Pair e in
+ let '(e', e'') := invert_or_expand_Pair e' in
+ (e, e', e'')%expr
+ | _
+ => fun e => e
+ end%core%expr.
+
+ Definition reassociate_uncurried_domain_l_to_r {s s' d'}
+ : @expr var (type.uncurried_domain (s * s' -> d'))
+ -> @expr var (type.uncurried_domain (s -> s' -> d'))
+ := match d'
+ return (expr (type.uncurried_domain (s * s' -> d'))
+ -> expr (type.uncurried_domain (s -> s' -> d')))
+ with
+ | type.arrow _ _ as d'
+ => fun (e : expr ((s * s') * type.uncurried_domain d'))
+ => let '(e, e'') := invert_or_expand_Pair e in
+ let '(e, e') := invert_or_expand_Pair e in
+ (e, (e', e''))%expr
+ | _
+ => fun e => e
+ end%core%expr.
+
+ Fixpoint uncurried_abs {s d}
+ : (type.value var s -> type.value var d)
+ -> @expr var (type.uncurried_domain (type.arrow s d))
+ -> @expr var (type.final_codomain d)
+ := match d with
+ | type.arrow s' d'
+ => fun f x
+ => @uncurried_abs
+ (s * s')%ctype d'
+ (fun xy
+ => let '(x, y) := invert_or_expand_Pair xy in
+ f (reflect x) (reflect y))
+ (reassociate_uncurried_domain_r_to_l x)
+ | _
+ => fun f x => f (reflect x)
+ end%core%expr.
+
+ Fixpoint uncurried_app_to_value {s d}
+ : (@expr var (type.uncurried_domain (type.arrow s d))
+ -> @expr var (type.final_codomain d))
+ -> type.value var s
+ -> type.value var d
+ := match d with
+ | type.arrow s' d'
+ => fun f x (y : type.value var s')
+ => @uncurried_app_to_value
+ (s * s')%ctype d'
+ (fun x' => f (reassociate_uncurried_domain_l_to_r x'))
+ (reify x, reify y)
+ | _
+ => fun f x
+ => f (reify x)
+ end%expr.
+
+ Definition uncurry_value {s d} (f : type.value var (s -> d))
+ (x : type.value var s)
+ : type.value var d
+ := uncurried_app_to_value
+ (fun x' => Abs (fun v => uncurried_abs f (Var v)) @ x')%expr
+ x.
+
+ (** N.B. We only uncurry things when we hit an application of
+ a lambda; everything else is untouched. *)
+ Fixpoint uncurry' {t} (e : @expr (type.value var) t) : type.value var t
+ := match e in expr.expr t return type.value var t with
+ | Var t v => v
+ | TT => TT
+ | AppIdent s d idc args
+ => reflect (AppIdent idc (reify (@uncurry' _ args)))
+ | App s d f x
+ => let f' := @uncurry' _ f in
+ let x' := @uncurry' _ x in
+ match invert_Abs f with
+ | Some _ => uncurry_value f' x'
+ | None => f' x'
+ end
+ | Pair A B a b
+ => Pair (reify (@uncurry' A a)) (reify (@uncurry' B b))
+ | Abs s d f
+ => fun v : type.value var s => @uncurry' d (f v)
+ end.
+
+ Definition uncurry {t} (e : @expr (type.value var) t)
+ : @expr var (type.uncurry t)
+ := Abs (fun v : var (type.uncurried_domain t)
+ => match t return type.value var t -> expr (type.uncurried_domain t) -> expr (type.final_codomain t) with
+ | type.arrow _ _ => uncurried_abs
+ | _ => fun e _ => e
+ end (uncurry' e) (Var v)).
+ End with_var.
+
+ Definition Uncurry {t} (e : Expr t) : Expr (type.uncurry t)
+ := fun var => uncurry (e _).
+ End expr.
+ End Uncurry.
+
Module CPS.
Import Uncurried.
Module Import Output.
@@ -6385,6 +6574,48 @@ Module test9.
exact I.
Qed.
End test9.
+Module test10.
+ Example test10 : True.
+ Proof.
+ let v := Reify (fun (f : Z -> Z -> Z) x y => f (x + y) (x * y))%Z in
+ pose v as E.
+ vm_compute in E.
+ pose (Uncurry.expr.Uncurry (PartialEvaluate true (canonicalize_list_recursion E))) as E'.
+ lazy in E'.
+ clear E.
+ lazymatch (eval cbv delta [E'] in E') with
+ | (fun var =>
+ (λ v,
+ ident.fst @@ Var v @
+ (ident.fst @@ (ident.snd @@ Var v) + ident.snd @@ (ident.snd @@ Var v)) @
+ (ident.fst @@ (ident.snd @@ Var v) * ident.snd @@ (ident.snd @@ Var v)))%expr)
+ => idtac
+ end.
+ constructor.
+ Qed.
+End test10.
+Module test11.
+ Example test11 : True.
+ Proof.
+ let v := Reify (fun x y => (fun f a b => f a b) (fun a b => a + b) (x + y) (x * y))%Z in
+ pose v as E.
+ vm_compute in E.
+ pose (Uncurry.expr.Uncurry (canonicalize_list_recursion E)) as E'.
+ lazy in E'.
+ clear E.
+ lazymatch (eval cbv delta [E'] in E') with
+ | (fun var =>
+ (λ v,
+ (λ v0,
+ ident.fst @@ Var v0 @ (ident.fst @@ (ident.snd @@ Var v0)) @
+ (ident.snd @@ (ident.snd @@ Var v0))) @
+ ((λ v0' v1, Var v0' + Var v1),
+ (ident.fst @@ Var v + ident.snd @@ Var v, ident.fst @@ Var v * ident.snd @@ Var v)))%expr)
+ => idtac
+ end.
+ constructor.
+ Qed.
+End test11.
Axiom admit_pf : False.
Notation admit := (match admit_pf with end).
Ltac cache_reify _ :=