aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-12-15 17:49:23 -0500
committerGravatar Jason Gross <jasongross9@gmail.com>2018-01-29 18:04:58 -0500
commit22945da8d42e693f53f2d43d82cb60c6962ebd0b (patch)
tree77023fef82adf29718a5a01bb858e32baf1169d5 /src
parentb54a55a122653e62f222628e1e83e5de5a7abef9 (diff)
Add a CPS conversion pass
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v192
1 files changed, 188 insertions, 4 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index e6faf462f..9c4ceb851 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -82,12 +82,12 @@ Module Associational.
Proof. cbv [reduce]; push.
rewrite <-reduction_rule, eval_split; trivial. Qed.
Hint Rewrite eval_reduce : push_eval.
-
+
Section Carries.
Context {modulo div : Z -> Z -> Z}.
Context {div_mod : forall a b:Z, b <> 0 ->
a = b * (div a b) + modulo a b}.
-
+
Definition carryterm (w fw:Z) (t:Z * Z) :=
if (Z.eqb (fst t) w)
then dlet t2 := snd t in
@@ -249,7 +249,7 @@ Module Positional. Section Positional.
(* N.B. It is important to reverse [idxs] here, because fold_right is
written such that the first terms in the list are actually used
last in the computation. For example, running:
-
+
`Eval cbv - [Z.add] in (fun a b c d => fold_right Z.add d [a;b;c]).`
will produce [fun a b c d => (a + (b + (c + d)))].*)
@@ -336,7 +336,7 @@ Module Compilers.
Bind Scope expr_scope with expr.
Delimit Scope expr_scope with expr.
- Notation "f x" := (App f x) (only printing) : expr_scope.
+ Infix "@" := App : expr_scope.
Notation "'λ' x .. y , t" := (Abs (fun x => .. (Abs (fun y => t%expr)) ..)) : expr_scope.
End Notations.
@@ -1142,6 +1142,190 @@ Module Compilers.
Import expr.
Import expr.default.
+ Module CPS.
+ Module type.
+ Import Compilers.type.
+ Section translate.
+ Context (R : type).
+ Fixpoint translate (t : type) : type
+ := match t with
+ | A * B => translate A * translate B
+ | s -> d => translate s -> (translate d -> R) -> R
+ | list A => list (translate A)
+ | type_opaque _ as t
+ | unit as t
+ | nat as t
+ | bool as t
+ => t
+ end%ctype.
+ End translate.
+ End type.
+
+ Module ident.
+ Section with_var.
+ Context {var : type -> Type}.
+ Let Ident' := @Ident ident var.
+ Local Coercion Ident' : ident >-> expr.
+
+ Definition translate {t} {R}
+ (idc : ident t)
+ (k : @expr var (type.translate R t) -> @expr var R)
+ : @expr var R
+ := match idc in ident.ident t return (expr (type.translate R t) -> expr R) -> expr R with
+ | ident.opaque _ _ as idc
+ | ident.tt as idc
+ | ident.O as idc
+ | ident.true as idc
+ | ident.false as idc
+ => fun k => k idc
+ | ident.nil t
+ => fun k => k (@ident.nil (type.translate R t))
+ | ident.Let_In tx tC as idc
+ => fun k
+ => k (λ (x : var (type.translate R tx))
+ (xk :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (((type.translate _ tx -> ((type.translate _ tC -> R) -> R)) -> ((type.translate _ tC -> R) -> R)) -> R)) ,
+ Var xk @ (λ f fk,
+ ident.Let_In
+ @ Var x
+ @ (λ x, Var f @ Var x @ Var fk)))
+ | ident.S as idc
+ | ident.pred as idc
+ | ident.Z_opp as idc
+ | ident.Z_of_nat as idc
+ => fun k
+ => k (λ x k, Var k @ (idc @ Var x))
+ | ident.cons t as idc
+ => fun k
+ => k (λ (x : var (type.translate R t))
+ (xk :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((type.list (type.translate _ t) -> ((type.list (type.translate _ t) -> R) -> R)) -> R)),
+ Var xk @ (λ xs k,
+ Var k @ (ident.cons @ Var x @ Var xs)))
+ | ident.pair A B
+ => fun k
+ => k (λ (x : var (type.translate R A))
+ (xk :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((type.translate _ B -> ((type.translate _ A * type.translate _ B -> R) -> R)) -> R)),
+ Var xk @ (λ y k,
+ Var k @ (Var x, Var y)))
+ | ident.fst A B
+ => fun k
+ => k (λ (x : var (type.translate R A * type.translate R B))
+ (k : var (type.translate _ A -> R)),
+ Var k @ (ident.fst @ Var x))
+ | ident.snd A B
+ => fun k
+ => k (λ (x : var (type.translate R A * type.translate R B))
+ (k : var (type.translate _ B -> R)),
+ Var k @ (ident.snd @ Var x))
+ | ident.bool_rect T
+ => fun k
+ => k (λ (true_case : var (type.translate R T))
+ (k0 :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var ((type.translate R T -> ((type.bool -> (type.translate R T -> R) -> R) -> R) -> R) -> R)),
+ Var k0 @ (λ false_case k1,
+ Var k1 @ (λ b k,
+ ident.bool_rect
+ @ (Var k @ Var true_case)
+ @ (Var k @ Var false_case)
+ @ Var b)))
+ | ident.nat_rect P
+ => fun k
+ => k (λ (O_case : var (type.translate R P))
+ (k0 :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (((type.nat -> ((type.translate R P -> (type.translate R P -> R) -> R) -> R) -> R) -> ((type.nat -> (type.translate R P -> R) -> R) -> R) -> R) -> R)),
+ Var k0 @ (λ S_case k1,
+ Var k1 @ (λ n k,
+ (@ident.nat_rect ((type.translate R P -> R) -> R))
+ @ (λ k, Var k @ Var O_case)
+ @ (λ n' rec k,
+ (Var rec)
+ @ (λ rec,
+ (Var S_case)
+ @ (Var n')
+ @ (λ K, Var K @ Var rec @ Var k)))
+ @ (Var n)
+ @ (Var k))))
+ | ident.list_rect A P
+ => fun k
+ => k (λ (nil_case : var (type.translate R P))
+ (k0 :
+ (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (((type.translate R A -> ((type.list (type.translate R A) -> ((type.translate R P -> (type.translate R P -> R) -> R) -> R) -> R) -> R) -> R) -> ((type.list (type.translate R A) -> (type.translate R P -> R) -> R) -> R) -> R) -> R)),
+ (Var k0)
+ @ (λ cons_case k1,
+ (Var k1)
+ @ (λ ls k,
+ (@ident.list_rect _ ((type.translate R P -> R) -> R))
+ @ (λ k, Var k @ Var nil_case)
+ @ (λ x xs rec k,
+ (Var rec)
+ @ (λ rec,
+ (Var cons_case)
+ @ (Var x)
+ @ (λ K,
+ (Var K)
+ @ (Var xs)
+ @ (λ K, Var K @ Var rec @ Var k))))
+ @ (Var ls)
+ @ (Var k))))
+ | ident.Z_runtime_mul as idc
+ | ident.Z_runtime_add as idc
+ | ident.Z_add as idc
+ | ident.Z_mul as idc
+ | ident.Z_pow as idc
+ | ident.Z_div as idc
+ | ident.Z_modulo as idc
+ | ident.Z_eqb as idc
+ => fun k
+ => k (λ x xk,
+ (Var xk)
+ @ (λ y yk,
+ (Var yk @ (idc @ Var x @ Var y))))
+ end%expr k.
+ End with_var.
+ End ident.
+
+ Module expr.
+ Section with_var.
+ Context {var : type -> Type}
+ {R : type}.
+ Notation var' R := (fun t => var (type.translate R t)).
+
+ Fixpoint translate {t}
+ (e : @expr (var' R) t)
+ (k : @expr var (type.translate R t) -> @expr var R)
+ {struct e}
+ : @expr var R
+ := match e in expr.expr t return (expr (type.translate R t) -> expr R) -> expr R with
+ | Var t v
+ => fun k => k (Var v)
+ | Ident t idc => ident.translate idc
+ | App s d f x
+ => fun k
+ => @translate
+ _ f
+ (fun fv
+ => @translate
+ _ x
+ (fun xv
+ => App (App fv xv) (Abs (fun v => k (Var v)))))
+ | Abs s d f
+ => fun k
+ => k (Abs (fun (x : var (type.translate R s))
+ => Abs (fun (k : var (type.translate _ _ -> _))
+ => @translate
+ _ (f x)
+ (fun v => App (Var k) v))))
+ end k.
+ End with_var.
+
+ Definition Translate {R t} (e : Expr t) (k : forall var, @expr var (type.translate R t) -> @expr var R)
+ : Expr R
+ := fun var => @translate var R t (e _) (k _).
+ End expr.
+ End CPS.
+
Section option_partition.
Context {A : Type} (f : A -> option Datatypes.bool).
Fixpoint option_partition (l : list A) : option (list A * list A)