From b9b67597324deb6e6dfc8ef33c60c110abc2af7b Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Fri, 8 Aug 2008 17:55:51 -0400 Subject: Specialization of single-parameter datatypes --- src/core_util.sml | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 2 deletions(-) (limited to 'src/core_util.sml') diff --git a/src/core_util.sml b/src/core_util.sml index b7a16dc2..3fc57739 100644 --- a/src/core_util.sml +++ b/src/core_util.sml @@ -39,6 +39,28 @@ structure S = Search structure Kind = struct +open Order + +fun compare ((k1, _), (k2, _)) = + case (k1, k2) of + (KType, KType) => EQUAL + | (KType, _) => LESS + | (_, KType) => GREATER + + | (KArrow (d1, r1), KArrow (d2, r2)) => join (compare (d1, d2), fn () => compare (r1, r2)) + | (KArrow _, _) => LESS + | (_, KArrow _) => GREATER + + | (KName, KName) => EQUAL + | (KName, _) => LESS + | (_, KName) => GREATER + + | (KRecord k1, KRecord k2) => compare (k1, k2) + | (KRecord _, _) => LESS + | (_, KRecord _) => GREATER + + | (KUnit, KUnit) => EQUAL + fun mapfold f = let fun mfk k acc = @@ -85,6 +107,76 @@ end structure Con = struct +open Order + +fun compare ((c1, _), (c2, _)) = + case (c1, c2) of + (TFun (d1, r1), TFun (d2, r2)) => join (compare (d1, d2), fn () => compare (r1, r2)) + | (TFun _, _) => LESS + | (_, TFun _) => GREATER + + | (TCFun (x1, k1, r1), TCFun (x2, k2, r2)) => + join (String.compare (x1, x2), + fn () => join (Kind.compare (k1, k2), + fn () => compare (r1, r2))) + | (TCFun _, _) => LESS + | (_, TCFun _) => GREATER + + | (TRecord c1, TRecord c2) => compare (c1, c2) + | (TRecord _, _) => LESS + | (_, TRecord _) => GREATER + + | (CRel n1, CRel n2) => Int.compare (n1, n2) + | (CRel _, _) => LESS + | (_, CRel _) => GREATER + + | (CNamed n1, CNamed n2) => Int.compare (n1, n2) + | (CNamed _, _) => LESS + | (_, CNamed _) => GREATER + + | (CFfi (m1, s1), CFfi (m2, s2)) => join (String.compare (m1, m2), + fn () => String.compare (s1, s2)) + | (CFfi _, _) => LESS + | (_, CFfi _) => GREATER + + | (CApp (f1, x1), CApp (f2, x2)) => join (compare (f1, f2), + fn () => compare (x1, x2)) + | (CApp _, _) => LESS + | (_, CApp _) => GREATER + + | (CAbs (x1, k1, b1), CAbs (x2, k2, b2)) => + join (String.compare (x1, x2), + fn () => join (Kind.compare (k1, k2), + fn () => compare (b1, b2))) + | (CAbs _, _) => LESS + | (_, CAbs _) => GREATER + + | (CName s1, CName s2) => String.compare (s1, s2) + | (CName _, _) => LESS + | (_, CName _) => GREATER + + | (CRecord (k1, xvs1), CRecord (k2, xvs2)) => + join (Kind.compare (k1, k2), + fn () => joinL (fn ((x1, v1), (x2, v2)) => + join (compare (x1, x2), + fn () => compare (v1, v2))) (xvs1, xvs2)) + | (CRecord _, _) => LESS + | (_, CRecord _) => GREATER + + | (CConcat (f1, s1), CConcat (f2, s2)) => + join (compare (f1, f2), + fn () => compare (s1, s2)) + | (CConcat _, _) => LESS + | (_, CConcat _) => GREATER + + | (CFold (d1, r1), CFold (d2, r2)) => + join (Kind.compare (d1, r2), + fn () => Kind.compare (r1, r2)) + | (CFold _, _) => LESS + | (_, CFold _) => GREATER + + | (CUnit, CUnit) => EQUAL + datatype binder = Rel of string * kind | Named of string * int * kind * con option @@ -201,6 +293,12 @@ fun exists {kind, con} k = S.Return _ => true | S.Continue _ => false +fun foldMap {kind, con} s c = + case mapfold {kind = fn k => fn s => S.Continue (kind (k, s)), + con = fn c => fn s => S.Continue (con (c, s))} c s of + S.Continue v => v + | S.Return _ => raise Fail "CoreUtil.Con.foldMap: Impossible" + end structure Exp = struct @@ -317,8 +415,22 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = S.bind2 (mfe ctx e, fn e' => S.bind2 (ListUtil.mapfold (fn (p, e) => - S.map2 (mfe ctx e, - fn e' => (p, e'))) pes, + let + fun pb ((p, _), ctx) = + case p of + PWild => ctx + | PVar (x, t) => bind (ctx, RelE (x, t)) + | PPrim _ => ctx + | PCon (_, _, _, NONE) => ctx + | PCon (_, _, _, SOME p) => pb (p, ctx) + | PRecord xps => foldl (fn ((_, p, _), ctx) => + pb (p, ctx)) ctx xps + in + S.bind2 (mfp ctx p, + fn p' => + S.map2 (mfe (pb (p', ctx)) e, + fn e' => (p', e'))) + end) pes, fn pes' => S.bind2 (mfc ctx disc, fn disc' => @@ -335,6 +447,45 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = S.map2 (ListUtil.mapfold (mfe ctx) es, fn es' => (EClosure (n, es'), loc)) + + and mfp ctx (pAll as (p, loc)) = + case p of + PWild => S.return2 pAll + | PVar (x, t) => + S.map2 (mfc ctx t, + fn t' => + (PVar (x, t'), loc)) + | PPrim _ => S.return2 pAll + | PCon (dk, pc, args, po) => + S.bind2 (mfpc ctx pc, + fn pc' => + S.bind2 (ListUtil.mapfold (mfc ctx) args, + fn args' => + S.map2 ((case po of + NONE => S.return2 NONE + | SOME p => S.map2 (mfp ctx p, SOME)), + fn po' => + (PCon (dk, pc', args', po'), loc)))) + | PRecord xps => + S.map2 (ListUtil.mapfold (fn (x, p, c) => + S.bind2 (mfp ctx p, + fn p' => + S.map2 (mfc ctx c, + fn c' => + (x, p', c')))) xps, + fn xps' => + (PRecord xps', loc)) + + and mfpc ctx pc = + case pc of + PConVar _ => S.return2 pc + | PConFfi {mod = m, datatyp, params, con, arg, kind} => + S.map2 ((case arg of + NONE => S.return2 NONE + | SOME c => S.map2 (mfc ctx c, SOME)), + fn arg' => + PConFfi {mod = m, datatyp = datatyp, params = params, + con = con, arg = arg', kind = kind}) in mfe end -- cgit v1.2.3