summaryrefslogtreecommitdiff
path: root/src/unpoly.sml
diff options
context:
space:
mode:
Diffstat (limited to 'src/unpoly.sml')
-rw-r--r--src/unpoly.sml172
1 files changed, 106 insertions, 66 deletions
diff --git a/src/unpoly.sml b/src/unpoly.sml
index 17878508..56406636 100644
--- a/src/unpoly.sml
+++ b/src/unpoly.sml
@@ -72,8 +72,19 @@ fun unpolyNamed (xn, rep) =
end
| _ => e}
+structure M = BinaryMapFn(struct
+ type ord_key = con list
+ val compare = Order.joinL U.Con.compare
+ end)
+
+type func = {
+ kinds : kind list,
+ defs : (string * int * con * exp * string) list,
+ replacements : int M.map
+}
+
type state = {
- funcs : (kind list * (string * int * con * exp * string) list) IM.map,
+ funcs : func IM.map,
decls : decl list,
nextName : int
}
@@ -86,8 +97,6 @@ fun exp (e, st : state) =
case e of
ECApp _ =>
let
- (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*)
-
fun unravel (e, cargs) =
case e of
ECApp ((e, _), c) => unravel (e, c :: cargs)
@@ -102,72 +111,101 @@ fun exp (e, st : state) =
else
case IM.find (#funcs st, n) of
NONE => (e, st)
- | SOME (ks, vis) =>
- let
- val (vis, nextName) = ListUtil.foldlMap
- (fn ((x, n, t, e, s), nextName) =>
- ((x, nextName, n, t, e, s), nextName + 1))
- (#nextName st) vis
-
- fun specialize (x, n, n_old, t, e, s) =
- let
- fun trim (t, e, cargs) =
- case (t, e, cargs) of
- ((TCFun (_, _, t), _),
- (ECAbs (_, _, e), _),
- carg :: cargs) =>
- let
- val t = subConInCon (length cargs, carg) t
- val e = subConInExp (length cargs, carg) e
- in
- trim (t, e, cargs)
- end
- | (_, _, []) =>
- let
- val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
- unpolyNamed (n_old, ENamed n) e)
- e vis
- in
- SOME (t, e)
- end
- | _ => NONE
- in
- (*Print.prefaces "specialize"
- [("t", CorePrint.p_con CoreEnv.empty t),
- ("e", CorePrint.p_exp CoreEnv.empty e),
- ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
- Option.map (fn (t, e) => (x, n, n_old, t, e, s))
- (trim (t, e, cargs))
- end
-
- val vis = List.map specialize vis
- in
- if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
- (e, st)
- else
- let
- val vis = List.mapPartial (fn x => x) vis
- val vis = map (fn (x, n, n_old, t, e, s) =>
- (x ^ "_unpoly", n, n_old, t, e, s)) vis
- val vis' = map (fn (x, n, _, t, e, s) =>
- (x, n, t, e, s)) vis
-
- val ks' = List.drop (ks, length cargs)
- in
- case List.find (fn (_, _, n_old, _, _, _) => n_old = n) vis of
- NONE => raise Fail "Unpoly: Inconsistent 'val rec' record"
- | SOME (_, n, _, _, _, _) =>
- (ENamed n,
- {funcs = foldl (fn (vi, funcs) =>
- IM.insert (funcs, #2 vi, (ks', vis')))
- (#funcs st) vis',
+ | SOME {kinds = ks, defs = vis, replacements} =>
+ case M.find (replacements, cargs) of
+ SOME n => (ENamed n, st)
+ | NONE =>
+ let
+ val old_vis = vis
+ val (vis, (thisName, nextName)) =
+ ListUtil.foldlMap
+ (fn ((x, n', t, e, s), (thisName, nextName)) =>
+ ((x, nextName, n', t, e, s),
+ (if n' = n then nextName else thisName,
+ nextName + 1)))
+ (0, #nextName st) vis
+
+ fun specialize (x, n, n_old, t, e, s) =
+ let
+ fun trim (t, e, cargs) =
+ case (t, e, cargs) of
+ ((TCFun (_, _, t), _),
+ (ECAbs (_, _, e), _),
+ carg :: cargs) =>
+ let
+ val t = subConInCon (length cargs, carg) t
+ val e = subConInExp (length cargs, carg) e
+ in
+ trim (t, e, cargs)
+ end
+ | (_, _, []) =>
+ (*let
+ val e = foldl (fn ((_, n, n_old, _, _, _), e) =>
+ unpolyNamed (n_old, ENamed n) e)
+ e vis
+ in*)
+ SOME (t, e)
+ (*end*)
+ | _ => NONE
+ in
+ (*Print.prefaces "specialize"
+ [("t", CorePrint.p_con CoreEnv.empty t),
+ ("e", CorePrint.p_exp CoreEnv.empty e),
+ ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*)
+ Option.map (fn (t, e) => (x, n, n_old, t, e, s))
+ (trim (t, e, cargs))
+ end
+
+ val vis = List.map specialize vis
+ in
+ if List.exists (not o Option.isSome) vis orelse length cargs > length ks then
+ (e, st)
+ else
+ let
+ val vis = List.mapPartial (fn x => x) vis
+
+ val vis = map (fn (x, n, n_old, t, e, s) =>
+ (x ^ "_unpoly", n, n_old, t, e, s)) vis
+ val vis' = map (fn (x, n, _, t, e, s) =>
+ (x, n, t, e, s)) vis
+
+ val funcs = IM.insert (#funcs st, n,
+ {kinds = ks,
+ defs = old_vis,
+ replacements = M.insert (replacements,
+ cargs,
+ thisName)})
+
+ val ks' = List.drop (ks, length cargs)
+
+ val st = {funcs = foldl (fn (vi, funcs) =>
+ IM.insert (funcs, #2 vi,
+ {kinds = ks',
+ defs = vis',
+ replacements = M.empty}))
+ funcs vis',
+ decls = #decls st,
+ nextName = nextName}
+
+ val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) =>
+ let
+ val (e, st) = polyExp (e, st)
+ in
+ ((x, n, t, e, s), st)
+ end)
+ st vis'
+ in
+ (ENamed thisName,
+ {funcs = #funcs st,
decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st,
- nextName = nextName})
- end
- end
+ nextName = #nextName st})
+ end
+ end
end
| _ => (e, st)
+and polyExp (x, st) = U.Exp.foldMap {kind = kind, con = con, exp = exp} st x
+
fun decl (d, st : state) =
case d of
DValRec (vis as ((x, n, t, e, s) :: rest)) =>
@@ -232,7 +270,9 @@ fun decl (d, st : state) =
(d, st)
else
(d, {funcs = foldl (fn (vi, funcs) =>
- IM.insert (funcs, #2 vi, (cargs, vis)))
+ IM.insert (funcs, #2 vi, {kinds = cargs,
+ defs = vis,
+ replacements = M.empty}))
(#funcs st) vis,
decls = #decls st,
nextName = #nextName st})