From d668886a45158cf3a292fdef3fa81498efd77652 Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Thu, 31 Jul 2008 13:08:57 -0400 Subject: Elaborating record patterns --- src/elaborate.sml | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 119 insertions(+), 6 deletions(-) (limited to 'src/elaborate.sml') diff --git a/src/elaborate.sml b/src/elaborate.sml index 10c5d214..0a71aa8c 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -38,10 +38,14 @@ open Print open ElabPrint structure IM = IntBinaryMap -structure SS = BinarySetFn(struct - type ord_key = string - val compare = String.compare - end) + +structure SK = struct +type ord_key = string +val compare = String.compare +end + +structure SS = BinarySetFn(SK) +structure SM = BinaryMapFn(SK) fun elabExplicitness e = case e of @@ -816,6 +820,7 @@ datatype exp_error = | PatHasArg of ErrorMsg.span | PatHasNoArg of ErrorMsg.span | Inexhaustive of ErrorMsg.span + | DuplicatePatField of ErrorMsg.span * string fun expError env err = case err of @@ -856,6 +861,8 @@ fun expError env err = ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument" | Inexhaustive loc => ErrorMsg.errorAt loc "Inexhaustive 'case'" + | DuplicatePatField (loc, s) => + ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern") fun checkCon (env, denv) e c1 c2 = unifyCons (env, denv) c1 c2 @@ -1021,13 +1028,45 @@ fun elabPat (pAll as (p, loc), (env, denv, bound)) = | SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn) end) - | L.PRecord _ => raise Fail "Elaborate PRecord" + | L.PRecord (xps, flex) => + let + val (xpts, (env, bound, _)) = + ListUtil.foldlMap (fn ((x, p), (env, bound, fbound)) => + let + val ((p', t), (env, bound)) = elabPat (p, (env, denv, bound)) + in + if SS.member (fbound, x) then + expError env (DuplicatePatField (loc, x)) + else + (); + ((x, p', t), (env, bound, SS.add (fbound, x))) + end) + (env, bound, SS.empty) xps + + val k = (L'.KType, loc) + val c = (L'.CRecord (k, map (fn (x, _, t) => ((L'.CName x, loc), t)) xpts), loc) + val (flex, c) = + if flex then + let + val flex = cunif (loc, (L'.KRecord k, loc)) + in + (SOME flex, (L'.CConcat (c, flex), loc)) + end + else + (NONE, c) + in + (((L'.PRecord (map (fn (x, p', _) => (x, p')) xpts, flex), loc), + (L'.TRecord c, loc)), + (env, bound)) + end + end datatype coverage = Wild | None | Datatype of coverage IM.map + | Record of coverage SM.map list fun exhaustive (env, denv, t, ps) = let @@ -1050,7 +1089,8 @@ fun exhaustive (env, denv, t, ps) = | L'.PPrim _ => None | L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild)) | L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p)) - + | L'.PRecord (xps, _) => Record [foldl (fn ((x, p), fmap) => + SM.insert (fmap, x, coverage p)) SM.empty xps] fun merge (c1, c2) = case (c1, c2) of (None, _) => c2 @@ -1061,12 +1101,84 @@ fun exhaustive (env, denv, t, ps) = | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2)) + | (Record fm1, Record fm2) => Record (fm1 @ fm2) + + | _ => None + fun combinedCoverage ps = case ps of [] => raise Fail "Empty pattern list for coverage checking" | [p] => coverage p | p :: ps => merge (coverage p, combinedCoverage ps) + fun enumerateCases t = + let + fun dtype cons = + ListUtil.mapConcat (fn (_, n, to) => + case to of + NONE => [Datatype (IM.insert (IM.empty, n, Wild))] + | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c))) + (enumerateCases t)) cons + in + case #1 (#1 (hnormCon (env, denv) t)) of + L'.CNamed n => + (let + val dt = E.lookupDatatype env n + val cons = E.constructors dt + in + dtype cons + end handle E.UnboundNamed _ => [Wild]) + | L'.TRecord c => + (case #1 (#1 (hnormCon (env, denv) c)) of + L'.CRecord (_, xts) => + let + val xts = map (fn (x, t) => (#1 (hnormCon (env, denv) x), t)) xts + + fun exponentiate fs = + case fs of + [] => [SM.empty] + | ((L'.CName x, _), t) :: rest => + let + val this = enumerateCases t + val rest = exponentiate rest + in + ListUtil.mapConcat (fn fmap => + map (fn c => SM.insert (fmap, x, c)) this) rest + end + | _ => raise Fail "exponentiate: Not CName" + in + if List.exists (fn ((L'.CName _, _), _) => false + | (c, _) => true) xts then + [Wild] + else + map (fn ls => Record [ls]) (exponentiate xts) + end + | _ => [Wild]) + | _ => [Wild] + end + + fun coverageImp (c1, c2) = + case (c1, c2) of + (Wild, _) => true + + | (Datatype cmap1, Datatype cmap2) => + List.all (fn (n, c2) => + case IM.find (cmap1, n) of + NONE => false + | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2) + + | (Record fmaps1, Record fmaps2) => + List.all (fn fmap2 => + List.exists (fn fmap1 => + List.all (fn (x, c2) => + case SM.find (fmap1, x) of + NONE => true + | SOME c1 => coverageImp (c1, c2)) + (SM.listItemsi fmap2)) + fmaps1) fmaps2 + + | _ => false + fun isTotal (c, t) = case c of None => (false, []) @@ -1109,6 +1221,7 @@ fun exhaustive (env, denv, t, ps) = | L'.CError => (true, gs) | _ => raise Fail "isTotal: Not a datatype" end + | Record _ => (List.all (fn c2 => coverageImp (c, c2)) (enumerateCases t), []) in isTotal (combinedCoverage ps, t) end -- cgit v1.2.3