diff options
Diffstat (limited to 'src/unpoly.sml')
-rw-r--r-- | src/unpoly.sml | 172 |
1 files changed, 106 insertions, 66 deletions
diff --git a/src/unpoly.sml b/src/unpoly.sml index 17878508..56406636 100644 --- a/src/unpoly.sml +++ b/src/unpoly.sml @@ -72,8 +72,19 @@ fun unpolyNamed (xn, rep) = end | _ => e} +structure M = BinaryMapFn(struct + type ord_key = con list + val compare = Order.joinL U.Con.compare + end) + +type func = { + kinds : kind list, + defs : (string * int * con * exp * string) list, + replacements : int M.map +} + type state = { - funcs : (kind list * (string * int * con * exp * string) list) IM.map, + funcs : func IM.map, decls : decl list, nextName : int } @@ -86,8 +97,6 @@ fun exp (e, st : state) = case e of ECApp _ => let - (*val () = Print.prefaces "exp" [("e", CorePrint.p_exp CoreEnv.empty (e, ErrorMsg.dummySpan))]*) - fun unravel (e, cargs) = case e of ECApp ((e, _), c) => unravel (e, c :: cargs) @@ -102,72 +111,101 @@ fun exp (e, st : state) = else case IM.find (#funcs st, n) of NONE => (e, st) - | SOME (ks, vis) => - let - val (vis, nextName) = ListUtil.foldlMap - (fn ((x, n, t, e, s), nextName) => - ((x, nextName, n, t, e, s), nextName + 1)) - (#nextName st) vis - - fun specialize (x, n, n_old, t, e, s) = - let - fun trim (t, e, cargs) = - case (t, e, cargs) of - ((TCFun (_, _, t), _), - (ECAbs (_, _, e), _), - carg :: cargs) => - let - val t = subConInCon (length cargs, carg) t - val e = subConInExp (length cargs, carg) e - in - trim (t, e, cargs) - end - | (_, _, []) => - let - val e = foldl (fn ((_, n, n_old, _, _, _), e) => - unpolyNamed (n_old, ENamed n) e) - e vis - in - SOME (t, e) - end - | _ => NONE - in - (*Print.prefaces "specialize" - [("t", CorePrint.p_con CoreEnv.empty t), - ("e", CorePrint.p_exp CoreEnv.empty e), - ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*) - Option.map (fn (t, e) => (x, n, n_old, t, e, s)) - (trim (t, e, cargs)) - end - - val vis = List.map specialize vis - in - if List.exists (not o Option.isSome) vis orelse length cargs > length ks then - (e, st) - else - let - val vis = List.mapPartial (fn x => x) vis - val vis = map (fn (x, n, n_old, t, e, s) => - (x ^ "_unpoly", n, n_old, t, e, s)) vis - val vis' = map (fn (x, n, _, t, e, s) => - (x, n, t, e, s)) vis - - val ks' = List.drop (ks, length cargs) - in - case List.find (fn (_, _, n_old, _, _, _) => n_old = n) vis of - NONE => raise Fail "Unpoly: Inconsistent 'val rec' record" - | SOME (_, n, _, _, _, _) => - (ENamed n, - {funcs = foldl (fn (vi, funcs) => - IM.insert (funcs, #2 vi, (ks', vis'))) - (#funcs st) vis', + | SOME {kinds = ks, defs = vis, replacements} => + case M.find (replacements, cargs) of + SOME n => (ENamed n, st) + | NONE => + let + val old_vis = vis + val (vis, (thisName, nextName)) = + ListUtil.foldlMap + (fn ((x, n', t, e, s), (thisName, nextName)) => + ((x, nextName, n', t, e, s), + (if n' = n then nextName else thisName, + nextName + 1))) + (0, #nextName st) vis + + fun specialize (x, n, n_old, t, e, s) = + let + fun trim (t, e, cargs) = + case (t, e, cargs) of + ((TCFun (_, _, t), _), + (ECAbs (_, _, e), _), + carg :: cargs) => + let + val t = subConInCon (length cargs, carg) t + val e = subConInExp (length cargs, carg) e + in + trim (t, e, cargs) + end + | (_, _, []) => + (*let + val e = foldl (fn ((_, n, n_old, _, _, _), e) => + unpolyNamed (n_old, ENamed n) e) + e vis + in*) + SOME (t, e) + (*end*) + | _ => NONE + in + (*Print.prefaces "specialize" + [("t", CorePrint.p_con CoreEnv.empty t), + ("e", CorePrint.p_exp CoreEnv.empty e), + ("|cargs|", Print.PD.string (Int.toString (length cargs)))];*) + Option.map (fn (t, e) => (x, n, n_old, t, e, s)) + (trim (t, e, cargs)) + end + + val vis = List.map specialize vis + in + if List.exists (not o Option.isSome) vis orelse length cargs > length ks then + (e, st) + else + let + val vis = List.mapPartial (fn x => x) vis + + val vis = map (fn (x, n, n_old, t, e, s) => + (x ^ "_unpoly", n, n_old, t, e, s)) vis + val vis' = map (fn (x, n, _, t, e, s) => + (x, n, t, e, s)) vis + + val funcs = IM.insert (#funcs st, n, + {kinds = ks, + defs = old_vis, + replacements = M.insert (replacements, + cargs, + thisName)}) + + val ks' = List.drop (ks, length cargs) + + val st = {funcs = foldl (fn (vi, funcs) => + IM.insert (funcs, #2 vi, + {kinds = ks', + defs = vis', + replacements = M.empty})) + funcs vis', + decls = #decls st, + nextName = nextName} + + val (vis', st) = ListUtil.foldlMap (fn ((x, n, t, e, s), st) => + let + val (e, st) = polyExp (e, st) + in + ((x, n, t, e, s), st) + end) + st vis' + in + (ENamed thisName, + {funcs = #funcs st, decls = (DValRec vis', ErrorMsg.dummySpan) :: #decls st, - nextName = nextName}) - end - end + nextName = #nextName st}) + end + end end | _ => (e, st) +and polyExp (x, st) = U.Exp.foldMap {kind = kind, con = con, exp = exp} st x + fun decl (d, st : state) = case d of DValRec (vis as ((x, n, t, e, s) :: rest)) => @@ -232,7 +270,9 @@ fun decl (d, st : state) = (d, st) else (d, {funcs = foldl (fn (vi, funcs) => - IM.insert (funcs, #2 vi, (cargs, vis))) + IM.insert (funcs, #2 vi, {kinds = cargs, + defs = vis, + replacements = M.empty})) (#funcs st) vis, decls = #decls st, nextName = #nextName st}) |