diff options
author | Jason Gross <jgross@mit.edu> | 2018-03-31 21:05:10 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-04-04 15:39:34 -0400 |
commit | c900290d3297ade2cc2e73fe6b322abe52d1715a (patch) | |
tree | 371e4c692170ae5fbccc4e52449a08040512e0ce /src | |
parent | d97e060a3f8de0b83db89aa6c25eb4157045c275 (diff) |
Add Uncurry
This pass uncurries all applied lambdas. Care is taken to not do beta
reduction and to not introduce spurious `Abs` or `App` nodes.
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 231 |
1 files changed, 231 insertions, 0 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 6ee543a3a..202e21098 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -3228,6 +3228,12 @@ Module Compilers. | _ => None end. + Definition invert_or_expand_Pair {A B} (e : @expr var (type.prod A B)) : @expr var A * @expr var B + := match invert_Pair e with + | Some p => p + | None => (ident.fst @@ e, ident.snd @@ e) + end%core%expr. + (* if we want more code for the below, I would suggest [reify_base_type] and [reflect_base_type] *) Definition reify_primitive {t} (v : type.interp (type.type_primitive t)) : @expr var (type.type_primitive t) := AppIdent (ident.primitive v) TT. @@ -3368,6 +3374,189 @@ Module Compilers. := (Reify_as (type.reify_type_of v) (fun _ => v)) (only parsing). End GallinaReify. + Module Uncurry. + Module type. + Fixpoint uncurried_domain (t : type) : type + := match t with + | type.arrow s d + => match d with + | type.arrow _ _ + => s * uncurried_domain d + | _ => s + end + | _ => type.type_primitive type.unit + end%ctype. + + Definition uncurry (t : type) : type + := type.arrow (uncurried_domain t) (type.final_codomain t). + + Section with_var. + Context (var : type -> Type). + Fixpoint value (t : type) + := match t with + | type.arrow s d + => value s -> value d + | t => @expr var t + end. + End with_var. + End type. + + Fixpoint app_curried {t : type} + : type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t) + := match t return type.interp t -> type.interp (type.uncurried_domain t) -> type.interp (type.final_codomain t) with + | type.arrow s d + => match d + return (type.interp d -> type.interp (type.uncurried_domain d) -> type.interp (type.final_codomain d)) + -> type.interp (type.arrow s d) + -> type.interp (type.uncurried_domain (type.arrow s d)) + -> type.interp (type.final_codomain d) + with + | type.arrow _ _ as d + => fun app_curried_d + (f : type.interp s -> type.interp d) + (x : type.interp s * type.interp (type.uncurried_domain d)) + => app_curried_d (f (fst x)) (snd x) + | d + => fun _ + (f : type.interp s -> type.interp d) + (x : type.interp s) + => f x + end (@app_curried d) + | _ => fun f _ => f + end. + + Module expr. + Section reify_reflect. + Context {var : type -> Type}. + Fixpoint reify {t} + : type.value var t -> @expr var t + := match t with + | type.arrow s d + => fun f => Abs (fun v => @reify d (f (@reflect s (Var v)))) + | _ + => fun e => e + end%expr + with reflect {t} + : @expr var t -> type.value var t + := match t with + | type.arrow s d + => fun e (v : type.value _ s) => @reflect d (e @ (@reify s v)) + | _ + => fun e => e + end%expr. + End reify_reflect. + + Section with_var. + Context {var : type -> Type}. + + Definition reassociate_uncurried_domain_r_to_l {s s' d'} + : @expr var (type.uncurried_domain (s -> s' -> d')) + -> @expr var (type.uncurried_domain (s * s' -> d')) + := match d' + return (expr (type.uncurried_domain (s -> s' -> d')) + -> expr (type.uncurried_domain (s * s' -> d'))) + with + | type.arrow _ _ as d' + => fun (e : expr (s * (s' * type.uncurried_domain d'))) + => let '(e, e') := invert_or_expand_Pair e in + let '(e', e'') := invert_or_expand_Pair e' in + (e, e', e'')%expr + | _ + => fun e => e + end%core%expr. + + Definition reassociate_uncurried_domain_l_to_r {s s' d'} + : @expr var (type.uncurried_domain (s * s' -> d')) + -> @expr var (type.uncurried_domain (s -> s' -> d')) + := match d' + return (expr (type.uncurried_domain (s * s' -> d')) + -> expr (type.uncurried_domain (s -> s' -> d'))) + with + | type.arrow _ _ as d' + => fun (e : expr ((s * s') * type.uncurried_domain d')) + => let '(e, e'') := invert_or_expand_Pair e in + let '(e, e') := invert_or_expand_Pair e in + (e, (e', e''))%expr + | _ + => fun e => e + end%core%expr. + + Fixpoint uncurried_abs {s d} + : (type.value var s -> type.value var d) + -> @expr var (type.uncurried_domain (type.arrow s d)) + -> @expr var (type.final_codomain d) + := match d with + | type.arrow s' d' + => fun f x + => @uncurried_abs + (s * s')%ctype d' + (fun xy + => let '(x, y) := invert_or_expand_Pair xy in + f (reflect x) (reflect y)) + (reassociate_uncurried_domain_r_to_l x) + | _ + => fun f x => f (reflect x) + end%core%expr. + + Fixpoint uncurried_app_to_value {s d} + : (@expr var (type.uncurried_domain (type.arrow s d)) + -> @expr var (type.final_codomain d)) + -> type.value var s + -> type.value var d + := match d with + | type.arrow s' d' + => fun f x (y : type.value var s') + => @uncurried_app_to_value + (s * s')%ctype d' + (fun x' => f (reassociate_uncurried_domain_l_to_r x')) + (reify x, reify y) + | _ + => fun f x + => f (reify x) + end%expr. + + Definition uncurry_value {s d} (f : type.value var (s -> d)) + (x : type.value var s) + : type.value var d + := uncurried_app_to_value + (fun x' => Abs (fun v => uncurried_abs f (Var v)) @ x')%expr + x. + + (** N.B. We only uncurry things when we hit an application of + a lambda; everything else is untouched. *) + Fixpoint uncurry' {t} (e : @expr (type.value var) t) : type.value var t + := match e in expr.expr t return type.value var t with + | Var t v => v + | TT => TT + | AppIdent s d idc args + => reflect (AppIdent idc (reify (@uncurry' _ args))) + | App s d f x + => let f' := @uncurry' _ f in + let x' := @uncurry' _ x in + match invert_Abs f with + | Some _ => uncurry_value f' x' + | None => f' x' + end + | Pair A B a b + => Pair (reify (@uncurry' A a)) (reify (@uncurry' B b)) + | Abs s d f + => fun v : type.value var s => @uncurry' d (f v) + end. + + Definition uncurry {t} (e : @expr (type.value var) t) + : @expr var (type.uncurry t) + := Abs (fun v : var (type.uncurried_domain t) + => match t return type.value var t -> expr (type.uncurried_domain t) -> expr (type.final_codomain t) with + | type.arrow _ _ => uncurried_abs + | _ => fun e _ => e + end (uncurry' e) (Var v)). + End with_var. + + Definition Uncurry {t} (e : Expr t) : Expr (type.uncurry t) + := fun var => uncurry (e _). + End expr. + End Uncurry. + Module CPS. Import Uncurried. Module Import Output. @@ -6385,6 +6574,48 @@ Module test9. exact I. Qed. End test9. +Module test10. + Example test10 : True. + Proof. + let v := Reify (fun (f : Z -> Z -> Z) x y => f (x + y) (x * y))%Z in + pose v as E. + vm_compute in E. + pose (Uncurry.expr.Uncurry (PartialEvaluate true (canonicalize_list_recursion E))) as E'. + lazy in E'. + clear E. + lazymatch (eval cbv delta [E'] in E') with + | (fun var => + (λ v, + ident.fst @@ Var v @ + (ident.fst @@ (ident.snd @@ Var v) + ident.snd @@ (ident.snd @@ Var v)) @ + (ident.fst @@ (ident.snd @@ Var v) * ident.snd @@ (ident.snd @@ Var v)))%expr) + => idtac + end. + constructor. + Qed. +End test10. +Module test11. + Example test11 : True. + Proof. + let v := Reify (fun x y => (fun f a b => f a b) (fun a b => a + b) (x + y) (x * y))%Z in + pose v as E. + vm_compute in E. + pose (Uncurry.expr.Uncurry (canonicalize_list_recursion E)) as E'. + lazy in E'. + clear E. + lazymatch (eval cbv delta [E'] in E') with + | (fun var => + (λ v, + (λ v0, + ident.fst @@ Var v0 @ (ident.fst @@ (ident.snd @@ Var v0)) @ + (ident.snd @@ (ident.snd @@ Var v0))) @ + ((λ v0' v1, Var v0' + Var v1), + (ident.fst @@ Var v + ident.snd @@ Var v, ident.fst @@ Var v * ident.snd @@ Var v)))%expr) + => idtac + end. + constructor. + Qed. +End test11. Axiom admit_pf : False. Notation admit := (match admit_pf with end). Ltac cache_reify _ := |