summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/rpcify.sml204
1 files changed, 67 insertions, 137 deletions
diff --git a/src/rpcify.sml b/src/rpcify.sml
index a70d94fe..4ed90228 100644
--- a/src/rpcify.sml
+++ b/src/rpcify.sml
@@ -170,144 +170,74 @@ fun frob file =
DValRec vis =>
if List.exists (fn (_, _, _, e, _) => makesServerCall e) vis then
let
- val all = foldl (fn ((_, n, _, _, _), all) => IS.add (all, n)) IS.empty vis
-
- val usesRec = U.Exp.exists {kind = fn _ => false,
- con = fn _ => false,
- exp = fn ENamed n => IS.member (all, n)
- | _ => false}
-
- val noRec = not o usesRec
-
- fun tailOnly (e, _) =
- case e of
- EPrim _ => true
- | ERel _ => true
- | ENamed _ => true
- | ECon (_, _, _, SOME e) => noRec e
- | ECon _ => true
- | EFfi _ => true
- | EFfiApp (_, _, es) => List.all noRec es
- | EApp (e1, e2) => noRec e2 andalso tailOnly e1
- | EAbs (_, _, _, e) => noRec e
- | ECApp (e1, _) => tailOnly e1
- | ECAbs (_, _, e) => noRec e
-
- | EKAbs (_, e) => noRec e
- | EKApp (e1, _) => tailOnly e1
-
- | ERecord xes => List.all (noRec o #2) xes
- | EField (e1, _, _) => noRec e1
- | EConcat (e1, _, e2, _) => noRec e1 andalso noRec e2
- | ECut (e1, _, _) => noRec e1
- | ECutMulti (e1, _, _) => noRec e1
-
- | ECase (e1, pes, _) => noRec e1 andalso List.all (tailOnly o #2) pes
-
- | EWrite e1 => noRec e1
-
- | EClosure (_, es) => List.all noRec es
-
- | ELet (_, _, e1, e2) => noRec e1 andalso tailOnly e2
-
- | EServerCall (_, es, (EAbs (_, _, _, e), _), _, _) =>
- List.all noRec es andalso tailOnly e
- | EServerCall (_, es, e, _, _) => List.all noRec es andalso noRec e
-
- | ETailCall _ => raise Fail "Rpcify: ETailCall too early"
-
- fun tailOnlyF e =
- case #1 e of
- EAbs (_, _, _, e) => tailOnlyF e
- | ECAbs (_, _, e) => tailOnlyF e
- | EKAbs (_, e) => tailOnlyF e
- | _ => tailOnly e
-
- val nonTail = foldl (fn ((_, n, _, e, _), nonTail) =>
- if tailOnlyF e then
- nonTail
- else
- IS.add (nonTail, n)) IS.empty vis
+ val rpc = foldl (fn ((_, n, _, _, _), rpc) =>
+ IS.add (rpc, n)) (#rpc st) vis
+
+ val (cpsed, vis') =
+ foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) =>
+ let
+ fun getArgs (t, acc) =
+ case #1 t of
+ TFun (dom, ran) =>
+ getArgs (ran, dom :: acc)
+ | _ => (rev acc, t)
+ val (ts, ran) = getArgs (t, [])
+ val ran = case #1 ran of
+ CApp (_, ran) => ran
+ | _ => raise Fail "Rpcify: Tail function not transactional"
+ val len = length ts
+
+ val loc = #2 e
+ val args = ListUtil.mapi
+ (fn (i, _) =>
+ (ERel (len - i - 1), loc))
+ ts
+ val k = (EFfi ("Basis", "return"), loc)
+ val trans = (CFfi ("Basis", "transaction"), loc)
+ val k = (ECApp (k, trans), loc)
+ val k = (ECApp (k, ran), loc)
+ val k = (EApp (k, (EFfi ("Basis", "transaction_monad"),
+ loc)), loc)
+ val re = (ETailCall (n, args, k, ran, ran), loc)
+ val (re, _) = foldr (fn (dom, (re, ran)) =>
+ ((EAbs ("x", dom, ran, re),
+ loc),
+ (TFun (dom, ran), loc)))
+ (re, ran) ts
+
+ val be = multiLiftExpInExp (len + 1) e
+ val be = ListUtil.foldli
+ (fn (i, _, be) =>
+ (EApp (be, (ERel (len - i), loc)), loc))
+ be ts
+ val ne = (EFfi ("Basis", "bind"), loc)
+ val ne = (ECApp (ne, trans), loc)
+ val ne = (ECApp (ne, ran), loc)
+ val unit = (TRecord (CRecord ((KType, loc), []),
+ loc), loc)
+ val ne = (ECApp (ne, unit), loc)
+ val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"),
+ loc)), loc)
+ val ne = (EApp (ne, be), loc)
+ val ne = (EApp (ne, (ERel 0, loc)), loc)
+ val tunit = (CApp (trans, unit), loc)
+ val kt = (TFun (ran, tunit), loc)
+ val ne = (EAbs ("k", kt, tunit, ne), loc)
+ val (ne, res) = foldr (fn (dom, (ne, ran)) =>
+ ((EAbs ("x", dom, ran, ne), loc),
+ (TFun (dom, ran), loc)))
+ (ne, (TFun (kt, tunit), loc)) ts
+ in
+ (IM.insert (cpsed, n, #1 re),
+ (x, n, res, ne, s) :: vis')
+ end)
+ (#cpsed st, []) vis
in
- if IS.isEmpty nonTail then
- (d, {exported = #exported st,
- export_decls = #export_decls st,
- cpsed = #cpsed st,
- rpc = IS.union (#rpc st, all)})
- else
- let
- val rpc = foldl (fn ((_, n, _, _, _), rpc) =>
- IS.add (rpc, n)) (#rpc st) vis
-
- val (cpsed, vis') =
- foldl (fn (vi as (x, n, t, e, s), (cpsed, vis')) =>
- if IS.member (nonTail, n) then
- let
- fun getArgs (t, acc) =
- case #1 t of
- TFun (dom, ran) =>
- getArgs (ran, dom :: acc)
- | _ => (rev acc, t)
- val (ts, ran) = getArgs (t, [])
- val ran = case #1 ran of
- CApp (_, ran) => ran
- | _ => raise Fail "Rpcify: Tail function not transactional"
- val len = length ts
-
- val loc = #2 e
- val args = ListUtil.mapi
- (fn (i, _) =>
- (ERel (len - i - 1), loc))
- ts
- val k = (EFfi ("Basis", "return"), loc)
- val trans = (CFfi ("Basis", "transaction"), loc)
- val k = (ECApp (k, trans), loc)
- val k = (ECApp (k, ran), loc)
- val k = (EApp (k, (EFfi ("Basis", "transaction_monad"),
- loc)), loc)
- val re = (ETailCall (n, args, k, ran, ran), loc)
- val (re, _) = foldr (fn (dom, (re, ran)) =>
- ((EAbs ("x", dom, ran, re),
- loc),
- (TFun (dom, ran), loc)))
- (re, ran) ts
-
- val be = multiLiftExpInExp (len + 1) e
- val be = ListUtil.foldli
- (fn (i, _, be) =>
- (EApp (be, (ERel (len - i), loc)), loc))
- be ts
- val ne = (EFfi ("Basis", "bind"), loc)
- val ne = (ECApp (ne, trans), loc)
- val ne = (ECApp (ne, ran), loc)
- val unit = (TRecord (CRecord ((KType, loc), []),
- loc), loc)
- val ne = (ECApp (ne, unit), loc)
- val ne = (EApp (ne, (EFfi ("Basis", "transaction_monad"),
- loc)), loc)
- val ne = (EApp (ne, be), loc)
- val ne = (EApp (ne, (ERel 0, loc)), loc)
- val tunit = (CApp (trans, unit), loc)
- val kt = (TFun (ran, tunit), loc)
- val ne = (EAbs ("k", kt, tunit, ne), loc)
- val (ne, res) = foldr (fn (dom, (ne, ran)) =>
- ((EAbs ("x", dom, ran, ne), loc),
- (TFun (dom, ran), loc)))
- (ne, (TFun (kt, tunit), loc)) ts
- in
- (IM.insert (cpsed, n, #1 re),
- (x, n, res, ne, s) :: vis')
- end
- else
- (cpsed, vi :: vis'))
- (#cpsed st, []) vis
- in
- ((DValRec (rev vis'), ErrorMsg.dummySpan),
- {exported = #exported st,
- export_decls = #export_decls st,
- cpsed = cpsed,
- rpc = rpc})
- end
+ ((DValRec (rev vis'), ErrorMsg.dummySpan),
+ {exported = #exported st,
+ export_decls = #export_decls st,
+ cpsed = cpsed,
+ rpc = rpc})
end
else
(d, st)