summaryrefslogtreecommitdiff
path: root/src/elab_util.sml
diff options
context:
space:
mode:
Diffstat (limited to 'src/elab_util.sml')
-rw-r--r--src/elab_util.sml154
1 files changed, 100 insertions, 54 deletions
diff --git a/src/elab_util.sml b/src/elab_util.sml
index f052a06d..be1c9459 100644
--- a/src/elab_util.sml
+++ b/src/elab_util.sml
@@ -43,44 +43,60 @@ structure S = Search
structure Kind = struct
-fun mapfold f =
+fun mapfoldB {kind, bind} =
let
- fun mfk k acc =
- S.bindP (mfk' k acc, f)
+ fun mfk ctx k acc =
+ S.bindP (mfk' ctx k acc, kind ctx)
- and mfk' (kAll as (k, loc)) =
+ and mfk' ctx (kAll as (k, loc)) =
case k of
KType => S.return2 kAll
| KArrow (k1, k2) =>
- S.bind2 (mfk k1,
+ S.bind2 (mfk ctx k1,
fn k1' =>
- S.map2 (mfk k2,
+ S.map2 (mfk ctx k2,
fn k2' =>
(KArrow (k1', k2'), loc)))
| KName => S.return2 kAll
| KRecord k =>
- S.map2 (mfk k,
+ S.map2 (mfk ctx k,
fn k' =>
(KRecord k', loc))
| KUnit => S.return2 kAll
| KTuple ks =>
- S.map2 (ListUtil.mapfold mfk ks,
+ S.map2 (ListUtil.mapfold (mfk ctx) ks,
fn ks' =>
(KTuple ks', loc))
| KError => S.return2 kAll
- | KUnif (_, _, ref (SOME k)) => mfk' k
+ | KUnif (_, _, ref (SOME k)) => mfk' ctx k
| KUnif _ => S.return2 kAll
+
+ | KRel _ => S.return2 kAll
+ | KFun (x, k) =>
+ S.map2 (mfk (bind (ctx, x)) k,
+ fn k' =>
+ (KFun (x, k'), loc))
in
mfk
end
+fun mapfold fk =
+ mapfoldB {kind = fn () => fk,
+ bind = fn ((), _) => ()} ()
+
+fun mapB {kind, bind} ctx k =
+ case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
+ bind = bind} ctx k () of
+ S.Continue (k, ()) => k
+ | S.Return _ => raise Fail "ElabUtil.Kind.mapB: Impossible"
+
fun exists f k =
case mapfold (fn k => fn () =>
if f k then
@@ -95,12 +111,13 @@ end
structure Con = struct
datatype binder =
- Rel of string * Elab.kind
- | Named of string * int * Elab.kind
+ RelK of string
+ | RelC of string * Elab.kind
+ | NamedC of string * int * Elab.kind
fun mapfoldB {kind = fk, con = fc, bind} =
let
- val mfk = Kind.mapfold fk
+ val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, s) => bind (ctx, RelK s)}
fun mfc ctx c acc =
S.bindP (mfc' ctx c acc, fc ctx)
@@ -114,9 +131,9 @@ fun mapfoldB {kind = fk, con = fc, bind} =
fn c2' =>
(TFun (c1', c2'), loc)))
| TCFun (e, x, k, c) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
- S.map2 (mfc (bind (ctx, Rel (x, k))) c,
+ S.map2 (mfc (bind (ctx, RelC (x, k))) c,
fn c' =>
(TCFun (e, x, k', c'), loc)))
| CDisjoint (ai, c1, c2, c3) =>
@@ -142,16 +159,16 @@ fun mapfoldB {kind = fk, con = fc, bind} =
fn c2' =>
(CApp (c1', c2'), loc)))
| CAbs (x, k, c) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
- S.map2 (mfc (bind (ctx, Rel (x, k))) c,
+ S.map2 (mfc (bind (ctx, RelC (x, k))) c,
fn c' =>
(CAbs (x, k', c'), loc)))
| CName _ => S.return2 cAll
| CRecord (k, xcs) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
S.map2 (ListUtil.mapfold (fn (x, c) =>
S.bind2 (mfc ctx x,
@@ -169,9 +186,9 @@ fun mapfoldB {kind = fk, con = fc, bind} =
fn c2' =>
(CConcat (c1', c2'), loc)))
| CMap (k1, k2) =>
- S.bind2 (mfk k1,
+ S.bind2 (mfk ctx k1,
fn k1' =>
- S.map2 (mfk k2,
+ S.map2 (mfk ctx k2,
fn k2' =>
(CMap (k1', k2'), loc)))
@@ -190,17 +207,32 @@ fun mapfoldB {kind = fk, con = fc, bind} =
| CError => S.return2 cAll
| CUnif (_, _, _, ref (SOME c)) => mfc' ctx c
| CUnif _ => S.return2 cAll
+
+ | CKAbs (x, c) =>
+ S.map2 (mfc (bind (ctx, RelK x)) c,
+ fn c' =>
+ (CKAbs (x, c'), loc))
+ | CKApp (c, k) =>
+ S.bind2 (mfc ctx c,
+ fn c' =>
+ S.map2 (mfk ctx k,
+ fn k' =>
+ (CKApp (c', k'), loc)))
+ | TKFun (x, c) =>
+ S.map2 (mfc (bind (ctx, RelK x)) c,
+ fn c' =>
+ (TKFun (x, c'), loc))
in
mfc
end
fun mapfold {kind = fk, con = fc} =
- mapfoldB {kind = fk,
+ mapfoldB {kind = fn () => fk,
con = fn () => fc,
bind = fn ((), _) => ()} ()
fun mapB {kind, con, bind} ctx c =
- case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()),
+ case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()),
bind = bind} ctx c () of
S.Continue (c, ()) => c
@@ -227,7 +259,7 @@ fun exists {kind, con} k =
| S.Continue _ => false
fun foldB {kind, con, bind} ctx st c =
- case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)),
+ case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)),
con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)),
bind = bind} ctx c st of
S.Continue (_, st) => st
@@ -238,20 +270,22 @@ end
structure Exp = struct
datatype binder =
- RelC of string * Elab.kind
+ RelK of string
+ | RelC of string * Elab.kind
| NamedC of string * int * Elab.kind
| RelE of string * Elab.con
| NamedE of string * Elab.con
fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
let
- val mfk = Kind.mapfold fk
+ val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)}
fun bind' (ctx, b) =
let
val b' = case b of
- Con.Rel x => RelC x
- | Con.Named x => NamedC x
+ Con.RelK x => RelK x
+ | Con.RelC x => RelC x
+ | Con.NamedC x => NamedC x
in
bind (ctx, b')
end
@@ -288,7 +322,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
fn c' =>
(ECApp (e', c'), loc)))
| ECAbs (expl, x, k, e) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
S.map2 (mfe (bind (ctx, RelC (x, k))) e,
fn e' =>
@@ -347,11 +381,6 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
fn rest' =>
(ECutMulti (e', c', {rest = rest'}), loc))))
- | EFold k =>
- S.map2 (mfk k,
- fn k' =>
- (EFold k', loc))
-
| ECase (e, pes, {disc, result}) =>
S.bind2 (mfe ctx e,
fn e' =>
@@ -406,6 +435,17 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
(ELet (des', e'), loc)))
end
+ | EKAbs (x, e) =>
+ S.map2 (mfe (bind (ctx, RelK x)) e,
+ fn e' =>
+ (EKAbs (x, e'), loc))
+ | EKApp (e, k) =>
+ S.bind2 (mfe ctx e,
+ fn e' =>
+ S.map2 (mfk ctx k,
+ fn k' =>
+ (EKApp (e', k'), loc)))
+
and mfed ctx (dAll as (d, loc)) =
case d of
EDVal vi =>
@@ -432,7 +472,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
end
fun mapfold {kind = fk, con = fc, exp = fe} =
- mapfoldB {kind = fk,
+ mapfoldB {kind = fn () => fk,
con = fn () => fc,
exp = fn () => fe,
bind = fn ((), _) => ()} ()
@@ -457,7 +497,7 @@ fun exists {kind, con, exp} k =
| S.Continue _ => false
fun mapB {kind, con, exp, bind} ctx e =
- case mapfoldB {kind = fn k => fn () => S.Continue (kind k, ()),
+ case mapfoldB {kind = fn ctx => fn k => fn () => S.Continue (kind ctx k, ()),
con = fn ctx => fn c => fn () => S.Continue (con ctx c, ()),
exp = fn ctx => fn e => fn () => S.Continue (exp ctx e, ()),
bind = bind} ctx e () of
@@ -465,7 +505,7 @@ fun mapB {kind, con, exp, bind} ctx e =
| S.Return _ => raise Fail "ElabUtil.Exp.mapB: Impossible"
fun foldB {kind, con, exp, bind} ctx st e =
- case mapfoldB {kind = fn k => fn st => S.Continue (k, kind (k, st)),
+ case mapfoldB {kind = fn ctx => fn k => fn st => S.Continue (k, kind (ctx, k, st)),
con = fn ctx => fn c => fn st => S.Continue (c, con (ctx, c, st)),
exp = fn ctx => fn e => fn st => S.Continue (e, exp (ctx, e, st)),
bind = bind} ctx e st of
@@ -477,7 +517,8 @@ end
structure Sgn = struct
datatype binder =
- RelC of string * Elab.kind
+ RelK of string
+ | RelC of string * Elab.kind
| NamedC of string * int * Elab.kind
| Str of string * Elab.sgn
| Sgn of string * Elab.sgn
@@ -487,14 +528,15 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
fun bind' (ctx, b) =
let
val b' = case b of
- Con.Rel x => RelC x
- | Con.Named x => NamedC x
+ Con.RelK x => RelK x
+ | Con.RelC x => RelC x
+ | Con.NamedC x => NamedC x
in
bind (ctx, b')
end
val con = Con.mapfoldB {kind = kind, con = con, bind = bind'}
- val kind = Kind.mapfold kind
+ val kind = Kind.mapfoldB {kind = kind, bind = fn (ctx, x) => bind (ctx, RelK x)}
fun sgi ctx si acc =
S.bindP (sgi' ctx si acc, sgn_item ctx)
@@ -502,11 +544,11 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
and sgi' ctx (siAll as (si, loc)) =
case si of
SgiConAbs (x, n, k) =>
- S.map2 (kind k,
+ S.map2 (kind ctx k,
fn k' =>
(SgiConAbs (x, n, k'), loc))
| SgiCon (x, n, k, c) =>
- S.bind2 (kind k,
+ S.bind2 (kind ctx k,
fn k' =>
S.map2 (con ctx c,
fn c' =>
@@ -548,11 +590,11 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
fn c2' =>
(SgiConstraint (c1', c2'), loc)))
| SgiClassAbs (x, n, k) =>
- S.map2 (kind k,
+ S.map2 (kind ctx k,
fn k' =>
(SgiClassAbs (x, n, k'), loc))
| SgiClass (x, n, k, c) =>
- S.bind2 (kind k,
+ S.bind2 (kind ctx k,
fn k' =>
S.map2 (con ctx c,
fn c' =>
@@ -608,7 +650,7 @@ fun mapfoldB {kind, con, sgn_item, sgn, bind} =
end
fun mapfold {kind, con, sgn_item, sgn} =
- mapfoldB {kind = kind,
+ mapfoldB {kind = fn () => kind,
con = fn () => con,
sgn_item = fn () => sgn_item,
sgn = fn () => sgn,
@@ -627,7 +669,8 @@ end
structure Decl = struct
datatype binder =
- RelC of string * Elab.kind
+ RelK of string
+ | RelC of string * Elab.kind
| NamedC of string * int * Elab.kind
| RelE of string * Elab.con
| NamedE of string * Elab.con
@@ -636,13 +679,14 @@ datatype binder =
fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = fst, decl = fd, bind} =
let
- val mfk = Kind.mapfold fk
+ val mfk = Kind.mapfoldB {kind = fk, bind = fn (ctx, x) => bind (ctx, RelK x)}
fun bind' (ctx, b) =
let
val b' = case b of
- Con.Rel x => RelC x
- | Con.Named x => NamedC x
+ Con.RelK x => RelK x
+ | Con.RelC x => RelC x
+ | Con.NamedC x => NamedC x
in
bind (ctx, b')
end
@@ -651,7 +695,8 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
fun bind' (ctx, b) =
let
val b' = case b of
- Exp.RelC x => RelC x
+ Exp.RelK x => RelK x
+ | Exp.RelC x => RelC x
| Exp.NamedC x => NamedC x
| Exp.RelE x => RelE x
| Exp.NamedE x => NamedE x
@@ -663,7 +708,8 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
fun bind' (ctx, b) =
let
val b' = case b of
- Sgn.RelC x => RelC x
+ Sgn.RelK x => RelK x
+ | Sgn.RelC x => RelC x
| Sgn.NamedC x => NamedC x
| Sgn.Sgn x => Sgn x
| Sgn.Str x => Str x
@@ -760,7 +806,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
and mfd' ctx (dAll as (d, loc)) =
case d of
DCon (x, n, k, c) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
S.map2 (mfc ctx c,
fn c' =>
@@ -825,7 +871,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
| DSequence _ => S.return2 dAll
| DClass (x, n, k, c) =>
- S.bind2 (mfk k,
+ S.bind2 (mfk ctx k,
fn k' =>
S.map2 (mfc ctx c,
fn c' =>
@@ -849,7 +895,7 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, sgn_item = fsgi, sgn = fsg, str = f
end
fun mapfold {kind, con, exp, sgn_item, sgn, str, decl} =
- mapfoldB {kind = kind,
+ mapfoldB {kind = fn () => kind,
con = fn () => con,
exp = fn () => exp,
sgn_item = fn () => sgn_item,
@@ -938,7 +984,7 @@ fun search {kind, con, exp, sgn_item, sgn, str, decl} k =
| S.Continue _ => NONE
fun foldMapB {kind, con, exp, sgn_item, sgn, str, decl, bind} ctx st d =
- case mapfoldB {kind = fn x => fn st => S.Continue (kind (x, st)),
+ case mapfoldB {kind = fn ctx => fn x => fn st => S.Continue (kind (ctx, x, st)),
con = fn ctx => fn x => fn st => S.Continue (con (ctx, x, st)),
exp = fn ctx => fn x => fn st => S.Continue (exp (ctx, x, st)),
sgn_item = fn ctx => fn x => fn st => S.Continue (sgn_item (ctx, x, st)),