diff options
Diffstat (limited to 'src/elab_util.sml')
-rw-r--r-- | src/elab_util.sml | 111 |
1 files changed, 65 insertions, 46 deletions
diff --git a/src/elab_util.sml b/src/elab_util.sml index 51a203f2..036aa867 100644 --- a/src/elab_util.sml +++ b/src/elab_util.sml @@ -568,15 +568,17 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = S.map2 (con ctx c, fn c' => (SgiCon (x, n, k', c'), loc))) - | SgiDatatype (x, n, xs, xncs) => - S.map2 (ListUtil.mapfold (fn (x, n, c) => - case c of - NONE => S.return2 (x, n, c) - | SOME c => - S.map2 (con ctx c, - fn c' => (x, n, SOME c'))) xncs, - fn xncs' => - (SgiDatatype (x, n, xs, xncs'), loc)) + | SgiDatatype dts => + S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) => + S.map2 (ListUtil.mapfold (fn (x, n, c) => + case c of + NONE => S.return2 (x, n, c) + | SOME c => + S.map2 (con ctx c, + fn c' => (x, n, SOME c'))) xncs, + fn xncs' => (x, n, xs, xncs'))) dts, + fn dts' => + (SgiDatatype dts', loc)) | SgiDatatypeImp (x, n, m1, ms, s, xs, xncs) => S.map2 (ListUtil.mapfold (fn (x, n, c) => case c of @@ -627,8 +629,15 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = bind (ctx, NamedC (x, n, k, NONE)) | SgiCon (x, n, k, c) => bind (ctx, NamedC (x, n, k, SOME c)) - | SgiDatatype (x, n, _, xncs) => - bind (ctx, NamedC (x, n, (KType, loc), NONE)) + | SgiDatatype dts => + foldl (fn ((x, n, ks, _), ctx) => + let + val k' = (KType, loc) + val k = foldl (fn (_, k) => (KArrow (k', k), loc)) + k' ks + in + bind (ctx, NamedC (x, n, k, NONE)) + end) ctx dts | SgiDatatypeImp (x, n, m1, ms, s, _, _) => bind (ctx, NamedC (x, n, (KType, loc), SOME (CModProj (m1, ms, s), loc))) @@ -753,29 +762,34 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f (case #1 d of DCon (x, n, k, c) => bind (ctx, NamedC (x, n, k, SOME c)) - | DDatatype (x, n, xs, xncs) => + | DDatatype dts => let - val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE)) + fun doOne ((x, n, xs, xncs), ctx) = + let + val ctx = bind (ctx, NamedC (x, n, (KType, loc), NONE)) + in + foldl (fn ((x, _, co), ctx) => + let + val t = + case co of + NONE => CNamed n + | SOME t => TFun (t, (CNamed n, loc)) + + val k = (KType, loc) + val t = (t, loc) + val t = foldr (fn (x, t) => + (TCFun (Explicit, + x, + k, + t), loc)) + t xs + in + bind (ctx, NamedE (x, t)) + end) + ctx xncs + end in - foldl (fn ((x, _, co), ctx) => - let - val t = - case co of - NONE => CNamed n - | SOME t => TFun (t, (CNamed n, loc)) - - val k = (KType, loc) - val t = (t, loc) - val t = foldr (fn (x, t) => - (TCFun (Explicit, - x, - k, - t), loc)) - t xs - in - bind (ctx, NamedE (x, t)) - end) - ctx xncs + foldl doOne ctx dts end | DDatatypeImp (x, n, m, ms, x', _, _) => bind (ctx, NamedC (x, n, (KType, loc), @@ -851,15 +865,18 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f S.map2 (mfc ctx c, fn c' => (DCon (x, n, k', c'), loc))) - | DDatatype (x, n, xs, xncs) => - S.map2 (ListUtil.mapfold (fn (x, n, c) => - case c of - NONE => S.return2 (x, n, c) - | SOME c => - S.map2 (mfc ctx c, - fn c' => (x, n, SOME c'))) xncs, - fn xncs' => - (DDatatype (x, n, xs, xncs'), loc)) + | DDatatype dts => + S.map2 (ListUtil.mapfold (fn (x, n, xs, xncs) => + S.map2 (ListUtil.mapfold (fn (x, n, c) => + case c of + NONE => S.return2 (x, n, c) + | SOME c => + S.map2 (mfc ctx c, + fn c' => (x, n, SOME c'))) xncs, + fn xncs' => + (x, n, xs, xncs'))) dts, + fn dts' => + (DDatatype dts', loc)) | DDatatypeImp (x, n, m1, ms, s, xs, xncs) => S.map2 (ListUtil.mapfold (fn (x, n, c) => case c of @@ -1059,9 +1076,10 @@ fun maxName ds = foldl (fn (d, count) => Int.max (maxNameDecl d, count)) 0 ds and maxNameDecl (d, _) = case d of DCon (_, n, _, _) => n - | DDatatype (_, n, _, ns) => + | DDatatype dts => + foldl (fn ((_, n, _, ns), max) => foldl (fn ((_, n', _), m) => Int.max (n', m)) - n ns + (Int.max (n, max)) ns) 0 dts | DDatatypeImp (_, n1, n2, _, _, _, ns) => foldl (fn ((_, n', _), m) => Int.max (n', m)) (Int.max (n1, n2)) ns @@ -1101,9 +1119,10 @@ and maxNameSgi (sgi, _) = case sgi of SgiConAbs (_, n, _) => n | SgiCon (_, n, _, _) => n - | SgiDatatype (_, n, _, ns) => - foldl (fn ((_, n', _), m) => Int.max (n', m)) - n ns + | SgiDatatype dts => + foldl (fn ((_, n, _, ns), max) => + foldl (fn ((_, n', _), m) => Int.max (n', m)) + (Int.max (n, max)) ns) 0 dts | SgiDatatypeImp (_, n1, n2, _, _, _, ns) => foldl (fn ((_, n', _), m) => Int.max (n', m)) (Int.max (n1, n2)) ns |