diff options
Diffstat (limited to 'src/elab_util.sml')
-rw-r--r-- | src/elab_util.sml | 154 |
1 files changed, 100 insertions, 54 deletions
diff --git a/src/elab_util.sml b/src/elab_util.sml index f052a06d..be1c9459 100644 --- a/src/elab_util.sml +++ b/src/elab_util.sml @@ -43,44 +43,60 @@ structure S = Search structure Kind = struct -fun mapfold f = +fun mapfoldB {kind, bind} = let - fun mfk k acc = - S.bindP (mfk' k acc, f) + fun mfk ctx k acc = + S.bindP (mfk' ctx k acc, kind ctx) - and mfk' (kAll as (k, loc)) = + and mfk' ctx (kAll as (k, loc)) = case k of KType => S.return2 kAll | KArrow (k1, k2) => - S.bind2 (mfk k1, + S.bind2 (mfk ctx k1, fn k1' => - S.map2 (mfk k2, + S.map2 (mfk ctx k2, fn k2' => (KArrow (k1', k2'), loc))) | KName => S.return2 kAll | KRecord k => - S.map2 (mfk k, + S.map2 (mfk ctx k, fn k' => (KRecord k', loc)) | KUnit => S.return2 kAll | KTuple ks => - S.map2 (ListUtil.mapfold mfk ks, + S.map2 (ListUtil.mapfold (mfk ctx) ks, fn ks' => (KTuple ks', loc)) | KError => S.return2 kAll - | KUnif (_, _, ref (SOME k)) => mfk' k + | KUnif (_, _, ref (SOME k)) => mfk' ctx k | KUnif _ => S.return2 kAll + + | KRel _ => S.return2 kAll + | KFun (x, k) => + S.map2 (mfk (bind (ctx, x)) k, + fn k' => + (KFun (x, k'), loc)) in mfk end +fun mapfold fk = + mapfoldB {kind = fn () => fk, + bind = fn ((), _) => ()} () + +fun mapB {kind, bind} ctx k = + case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()), + bind = bind} ctx k () of + S.Continue (k, ()) => k + | S.Return _ => raise Fail "ElabUtil.Kind.mapB: Impossible" + fun exists f k = case mapfold (fn k => fn () => if f k then @@ -95,12 +111,13 @@ end structure Con = struct datatype binder = - Rel of string * Elab.kind - | Named of string * int * Elab.kind + RelK of string + | RelC of string * Elab.kind + | NamedC of string * int * Elab.kind fun mapfoldB {kind = fk, con = fc, bind} = let - val mfk = Kind.mapfold fk + val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, s) => bind (ctx, RelK s)} fun mfc ctx c acc = S.bindP (mfc' ctx c acc, fc ctx) @@ -114,9 +131,9 @@ fun mapfoldB {kind = fk, con = fc, bind} = fn c2' => (TFun (c1', c2'), loc))) | TCFun (e, x, k, c) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => - S.map2 (mfc (bind (ctx, Rel (x, k))) c, + S.map2 (mfc (bind (ctx, RelC (x, k))) c, fn c' => (TCFun (e, x, k', c'), loc))) | CDisjoint (ai, c1, c2, c3) => @@ -142,16 +159,16 @@ fun mapfoldB {kind = fk, con = fc, bind} = fn c2' => (CApp (c1', c2'), loc))) | CAbs (x, k, c) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => - S.map2 (mfc (bind (ctx, Rel (x, k))) c, + S.map2 (mfc (bind (ctx, RelC (x, k))) c, fn c' => (CAbs (x, k', c'), loc))) | CName _ => S.return2 cAll | CRecord (k, xcs) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => S.map2 (ListUtil.mapfold (fn (x, c) => S.bind2 (mfc ctx x, @@ -169,9 +186,9 @@ fun mapfoldB {kind = fk, con = fc, bind} = fn c2' => (CConcat (c1', c2'), loc))) | CMap (k1, k2) => - S.bind2 (mfk k1, + S.bind2 (mfk ctx k1, fn k1' => - S.map2 (mfk k2, + S.map2 (mfk ctx k2, fn k2' => (CMap (k1', k2'), loc))) @@ -190,17 +207,32 @@ fun mapfoldB {kind = fk, con = fc, bind} = | CError => S.return2 cAll | CUnif (_, _, _, ref (SOME c)) => mfc' ctx c | CUnif _ => S.return2 cAll + + | CKAbs (x, c) => + S.map2 (mfc (bind (ctx, RelK x)) c, + fn c' => + (CKAbs (x, c'), loc)) + | CKApp (c, k) => + S.bind2 (mfc ctx c, + fn c' => + S.map2 (mfk ctx k, + fn k' => + (CKApp (c', k'), loc))) + | TKFun (x, c) => + S.map2 (mfc (bind (ctx, RelK x)) c, + fn c' => + (TKFun (x, c'), loc)) in mfc end fun mapfold {kind = fk, con = fc} = - mapfoldB {kind = fk, + mapfoldB {kind = fn () => fk, con = fn () => fc, bind = fn ((), _) => ()} () fun mapB {kind, con, bind} ctx c = - case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()), + case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()), con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()), bind = bind} ctx c () of S.Continue (c, ()) => c @@ -227,7 +259,7 @@ fun exists {kind, con} k = | S.Continue _ => false fun foldB {kind, con, bind} ctx st c = - case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)), + case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)), con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)), bind = bind} ctx c st of S.Continue (_, st) => st @@ -238,20 +270,22 @@ end structure Exp = struct datatype binder = - RelC of string * Elab.kind + RelK of string + | RelC of string * Elab.kind | NamedC of string * int * Elab.kind | RelE of string * Elab.con | NamedE of string * Elab.con fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = let - val mfk = Kind.mapfold fk + val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)} fun bind' (ctx, b) = let val b' = case b of - Con.Rel x => RelC x - | Con.Named x => NamedC x + Con.RelK x => RelK x + | Con.RelC x => RelC x + | Con.NamedC x => NamedC x in bind (ctx, b') end @@ -288,7 +322,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = fn c' => (ECApp (e', c'), loc))) | ECAbs (expl, x, k, e) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => S.map2 (mfe (bind (ctx, RelC (x, k))) e, fn e' => @@ -347,11 +381,6 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = fn rest' => (ECutMulti (e', c', {rest = rest'}), loc)))) - | EFold k => - S.map2 (mfk k, - fn k' => - (EFold k', loc)) - | ECase (e, pes, {disc, result}) => S.bind2 (mfe ctx e, fn e' => @@ -406,6 +435,17 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = (ELet (des', e'), loc))) end + | EKAbs (x, e) => + S.map2 (mfe (bind (ctx, RelK x)) e, + fn e' => + (EKAbs (x, e'), loc)) + | EKApp (e, k) => + S.bind2 (mfe ctx e, + fn e' => + S.map2 (mfk ctx k, + fn k' => + (EKApp (e', k'), loc))) + and mfed ctx (dAll as (d, loc)) = case d of EDVal vi => @@ -432,7 +472,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = end fun mapfold {kind = fk, con = fc, exp = fe} = - mapfoldB {kind = fk, + mapfoldB {kind = fn () => fk, con = fn () => fc, exp = fn () => fe, bind = fn ((), _) => ()} () @@ -457,7 +497,7 @@ fun exists {kind, con, exp} k = | S.Continue _ => false fun mapB {kind, con, exp, bind} ctx e = - case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()), + case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()), con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()), exp = fn ctx => fn e => fn () => S.Continue (exp ctx e, ()), bind = bind} ctx e () of @@ -465,7 +505,7 @@ fun mapB {kind, con, exp, bind} ctx e = | S.Return _ => raise Fail "ElabUtil.Exp.mapB: Impossible" fun foldB {kind, con, exp, bind} ctx st e = - case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)), + case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)), con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)), exp = fn ctx => fn e => fn st => S.Continue (e, exp (ctx, e, st)), bind = bind} ctx e st of @@ -477,7 +517,8 @@ end structure Sgn = struct datatype binder = - RelC of string * Elab.kind + RelK of string + | RelC of string * Elab.kind | NamedC of string * int * Elab.kind | Str of string * Elab.sgn | Sgn of string * Elab.sgn @@ -487,14 +528,15 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = fun bind' (ctx, b) = let val b' = case b of - Con.Rel x => RelC x - | Con.Named x => NamedC x + Con.RelK x => RelK x + | Con.RelC x => RelC x + | Con.NamedC x => NamedC x in bind (ctx, b') end val con = Con.mapfoldB {kind = kind, con = con, bind = bind'} - val kind = Kind.mapfold kind + val kind = Kind.mapfoldB {kind = kind, bind = fn (ctx, x) => bind (ctx, RelK x)} fun sgi ctx si acc = S.bindP (sgi' ctx si acc, sgn_item ctx) @@ -502,11 +544,11 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = and sgi' ctx (siAll as (si, loc)) = case si of SgiConAbs (x, n, k) => - S.map2 (kind k, + S.map2 (kind ctx k, fn k' => (SgiConAbs (x, n, k'), loc)) | SgiCon (x, n, k, c) => - S.bind2 (kind k, + S.bind2 (kind ctx k, fn k' => S.map2 (con ctx c, fn c' => @@ -548,11 +590,11 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = fn c2' => (SgiConstraint (c1', c2'), loc))) | SgiClassAbs (x, n, k) => - S.map2 (kind k, + S.map2 (kind ctx k, fn k' => (SgiClassAbs (x, n, k'), loc)) | SgiClass (x, n, k, c) => - S.bind2 (kind k, + S.bind2 (kind ctx k, fn k' => S.map2 (con ctx c, fn c' => @@ -608,7 +650,7 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} = end fun mapfold {kind, con, sgn_item, sgn} = - mapfoldB {kind = kind, + mapfoldB {kind = fn () => kind, con = fn () => con, sgn_item = fn () => sgn_item, sgn = fn () => sgn, @@ -627,7 +669,8 @@ end structure Decl = struct datatype binder = - RelC of string * Elab.kind + RelK of string + | RelC of string * Elab.kind | NamedC of string * int * Elab.kind | RelE of string * Elab.con | NamedE of string * Elab.con @@ -636,13 +679,14 @@ datatype binder = fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = fst, decl = fd, bind} = let - val mfk = Kind.mapfold fk + val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)} fun bind' (ctx, b) = let val b' = case b of - Con.Rel x => RelC x - | Con.Named x => NamedC x + Con.RelK x => RelK x + | Con.RelC x => RelC x + | Con.NamedC x => NamedC x in bind (ctx, b') end @@ -651,7 +695,8 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f fun bind' (ctx, b) = let val b' = case b of - Exp.RelC x => RelC x + Exp.RelK x => RelK x + | Exp.RelC x => RelC x | Exp.NamedC x => NamedC x | Exp.RelE x => RelE x | Exp.NamedE x => NamedE x @@ -663,7 +708,8 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f fun bind' (ctx, b) = let val b' = case b of - Sgn.RelC x => RelC x + Sgn.RelK x => RelK x + | Sgn.RelC x => RelC x | Sgn.NamedC x => NamedC x | Sgn.Sgn x => Sgn x | Sgn.Str x => Str x @@ -760,7 +806,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f and mfd' ctx (dAll as (d, loc)) = case d of DCon (x, n, k, c) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => S.map2 (mfc ctx c, fn c' => @@ -825,7 +871,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f | DSequence _ => S.return2 dAll | DClass (x, n, k, c) => - S.bind2 (mfk k, + S.bind2 (mfk ctx k, fn k' => S.map2 (mfc ctx c, fn c' => @@ -849,7 +895,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f end fun mapfold {kind, con, exp, sgn_item, sgn, str, decl} = - mapfoldB {kind = kind, + mapfoldB {kind = fn () => kind, con = fn () => con, exp = fn () => exp, sgn_item = fn () => sgn_item, @@ -938,7 +984,7 @@ fun search {kind, con, exp, sgn_item, sgn, str, decl} k = | S.Continue _ => NONE fun foldMapB {kind, con, exp, sgn_item, sgn, str, decl, bind} ctx st d = - case mapfoldB {kind = fn x => fn st => S.Continue (kind (x, st)), + case mapfoldB {kind = fn ctx => fn x => fn st => S.Continue (kind (ctx, x, st)), con = fn ctx => fn x => fn st => S.Continue (con (ctx, x, st)), exp = fn ctx => fn x => fn st => S.Continue (exp (ctx, x, st)), sgn_item = fn ctx => fn x => fn st => S.Continue (sgn_item (ctx, x, st)), |