summaryrefslogtreecommitdiff
path: root/src/elab_util.sml
diff options
context:
space:
mode:
Diffstat (limited to 'src/elab_util.sml')
-rw-r--r--src/elab_util.sml111
1 files changed, 65 insertions, 46 deletions
diff --git a/src/elab_util.sml b/src/elab_util.sml
index 51a203f2..036aa867 100644
--- a/src/elab_util.sml
+++ b/src/elab_util.sml
@@ -568,15 +568,17 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
S.map2 (con ctx c,
fn c' =>
(SgiCon (x, n, k', c'), loc)))
- | SgiDatatype (x, n, xs, xncs) =>
- S.map2 (ListUtil.mapfold (fn (x, n, c) =>
- case c of
- NONE => S.return2 (x, n, c)
- | SOME c =>
- S.map2 (con ctx c,
- fn c' => (x, n, SOME c'))) xncs,
- fn xncs' =>
- (SgiDatatype (x, n, xs, xncs'), loc))
+ | SgiDatatype dts =>
+ S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) =>
+ S.map2 (ListUtil.mapfold (fn (x, n, c) =>
+ case c of
+ NONE => S.return2 (x, n, c)
+ | SOME c =>
+ S.map2 (con ctx c,
+ fn c' => (x, n, SOME c'))) xncs,
+ fn xncs' => (x, n, xs, xncs'))) dts,
+ fn dts' =>
+ (SgiDatatype dts', loc))
| SgiDatatypeImp (x, n, m1, ms, s, xs, xncs) =>
S.map2 (ListUtil.mapfold (fn (x, n, c) =>
case c of
@@ -627,8 +629,15 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
bind (ctx, NamedC (x, n, k, NONE))
| SgiCon (x, n, k, c) =>
bind (ctx, NamedC (x, n, k, SOME c))
- | SgiDatatype (x, n, _, xncs) =>
- bind (ctx, NamedC (x, n, (KType, loc), NONE))
+ | SgiDatatype dts =>
+ foldl (fn ((x, n, ks, _), ctx) =>
+ let
+ val k' = (KType, loc)
+ val k = foldl (fn (_, k) => (KArrow (k', k), loc))
+ k' ks
+ in
+ bind (ctx, NamedC (x, n, k, NONE))
+ end) ctx dts
| SgiDatatypeImp (x, n, m1, ms, s, _, _) =>
bind (ctx, NamedC (x, n, (KType, loc),
SOME (CModProj (m1, ms, s), loc)))
@@ -753,29 +762,34 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
(case #1 d of
DCon (x, n, k, c) =>
bind (ctx, NamedC (x, n, k, SOME c))
- | DDatatype (x, n, xs, xncs) =>
+ | DDatatype dts =>
let
- val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE))
+ fun doOne ((x, n, xs, xncs), ctx) =
+ let
+ val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE))
+ in
+ foldl (fn ((x, _, co), ctx) =>
+ let
+ val t =
+ case co of
+ NONE => CNamed n
+ | SOME t => TFun (t, (CNamed n, loc))
+
+ val k = (KType, loc)
+ val t = (t, loc)
+ val t = foldr (fn (x, t) =>
+ (TCFun (Explicit,
+ x,
+ k,
+ t), loc))
+ t xs
+ in
+ bind (ctx, NamedE (x, t))
+ end)
+ ctx xncs
+ end
in
- foldl (fn ((x, _, co), ctx) =>
- let
- val t =
- case co of
- NONE => CNamed n
- | SOME t => TFun (t, (CNamed n, loc))
-
- val k = (KType, loc)
- val t = (t, loc)
- val t = foldr (fn (x, t) =>
- (TCFun (Explicit,
- x,
- k,
- t), loc))
- t xs
- in
- bind (ctx, NamedE (x, t))
- end)
- ctx xncs
+ foldl doOne ctx dts
end
| DDatatypeImp (x, n, m, ms, x', _, _) =>
bind (ctx, NamedC (x, n, (KType, loc),
@@ -851,15 +865,18 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
S.map2 (mfc ctx c,
fn c' =>
(DCon (x, n, k', c'), loc)))
- | DDatatype (x, n, xs, xncs) =>
- S.map2 (ListUtil.mapfold (fn (x, n, c) =>
- case c of
- NONE => S.return2 (x, n, c)
- | SOME c =>
- S.map2 (mfc ctx c,
- fn c' => (x, n, SOME c'))) xncs,
- fn xncs' =>
- (DDatatype (x, n, xs, xncs'), loc))
+ | DDatatype dts =>
+ S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) =>
+ S.map2 (ListUtil.mapfold (fn (x, n, c) =>
+ case c of
+ NONE => S.return2 (x, n, c)
+ | SOME c =>
+ S.map2 (mfc ctx c,
+ fn c' => (x, n, SOME c'))) xncs,
+ fn xncs' =>
+ (x, n, xs, xncs'))) dts,
+ fn dts' =>
+ (DDatatype dts', loc))
| DDatatypeImp (x, n, m1, ms, s, xs, xncs) =>
S.map2 (ListUtil.mapfold (fn (x, n, c) =>
case c of
@@ -1059,9 +1076,10 @@ fun maxName ds = foldl (fn (d, count) => Int.max (maxNameDecl d, count)) 0 ds
and maxNameDecl (d, _) =
case d of
DCon (_, n, _, _) => n
- | DDatatype (_, n, _, ns) =>
+ | DDatatype dts =>
+ foldl (fn ((_, n, _, ns), max) =>
foldl (fn ((_, n', _), m) => Int.max (n', m))
- n ns
+ (Int.max (n, max)) ns) 0 dts
| DDatatypeImp (_, n1, n2, _, _, _, ns) =>
foldl (fn ((_, n', _), m) => Int.max (n', m))
(Int.max (n1, n2)) ns
@@ -1101,9 +1119,10 @@ and maxNameSgi (sgi, _) =
case sgi of
SgiConAbs (_, n, _) => n
| SgiCon (_, n, _, _) => n
- | SgiDatatype (_, n, _, ns) =>
- foldl (fn ((_, n', _), m) => Int.max (n', m))
- n ns
+ | SgiDatatype dts =>
+ foldl (fn ((_, n, _, ns), max) =>
+ foldl (fn ((_, n', _), m) => Int.max (n', m))
+ (Int.max (n, max)) ns) 0 dts
| SgiDatatypeImp (_, n1, n2, _, _, _, ns) =>
foldl (fn ((_, n', _), m) => Int.max (n', m))
(Int.max (n1, n2)) ns