summaryrefslogtreecommitdiff
path: root/src/elaborate.sml
diff options
context:
space:
mode:
Diffstat (limited to 'src/elaborate.sml')
-rw-r--r--src/elaborate.sml125
1 files changed, 119 insertions, 6 deletions
diff --git a/src/elaborate.sml b/src/elaborate.sml
index 10c5d214..0a71aa8c 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -38,10 +38,14 @@ open Print
open ElabPrint
structure IM = IntBinaryMap
-structure SS = BinarySetFn(struct
- type ord_key = string
- val compare = String.compare
- end)
+
+structure SK = struct
+type ord_key = string
+val compare = String.compare
+end
+
+structure SS = BinarySetFn(SK)
+structure SM = BinaryMapFn(SK)
fun elabExplicitness e =
case e of
@@ -816,6 +820,7 @@ datatype exp_error =
| PatHasArg of ErrorMsg.span
| PatHasNoArg of ErrorMsg.span
| Inexhaustive of ErrorMsg.span
+ | DuplicatePatField of ErrorMsg.span * string
fun expError env err =
case err of
@@ -856,6 +861,8 @@ fun expError env err =
ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
| Inexhaustive loc =>
ErrorMsg.errorAt loc "Inexhaustive 'case'"
+ | DuplicatePatField (loc, s) =>
+ ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
fun checkCon (env, denv) e c1 c2 =
unifyCons (env, denv) c1 c2
@@ -1021,13 +1028,45 @@ fun elabPat (pAll as (p, loc), (env, denv, bound)) =
| SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn)
end)
- | L.PRecord _ => raise Fail "Elaborate PRecord"
+ | L.PRecord (xps, flex) =>
+ let
+ val (xpts, (env, bound, _)) =
+ ListUtil.foldlMap (fn ((x, p), (env, bound, fbound)) =>
+ let
+ val ((p', t), (env, bound)) = elabPat (p, (env, denv, bound))
+ in
+ if SS.member (fbound, x) then
+ expError env (DuplicatePatField (loc, x))
+ else
+ ();
+ ((x, p', t), (env, bound, SS.add (fbound, x)))
+ end)
+ (env, bound, SS.empty) xps
+
+ val k = (L'.KType, loc)
+ val c = (L'.CRecord (k, map (fn (x, _, t) => ((L'.CName x, loc), t)) xpts), loc)
+ val (flex, c) =
+ if flex then
+ let
+ val flex = cunif (loc, (L'.KRecord k, loc))
+ in
+ (SOME flex, (L'.CConcat (c, flex), loc))
+ end
+ else
+ (NONE, c)
+ in
+ (((L'.PRecord (map (fn (x, p', _) => (x, p')) xpts, flex), loc),
+ (L'.TRecord c, loc)),
+ (env, bound))
+ end
+
end
datatype coverage =
Wild
| None
| Datatype of coverage IM.map
+ | Record of coverage SM.map list
fun exhaustive (env, denv, t, ps) =
let
@@ -1050,7 +1089,8 @@ fun exhaustive (env, denv, t, ps) =
| L'.PPrim _ => None
| L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
| L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
-
+ | L'.PRecord (xps, _) => Record [foldl (fn ((x, p), fmap) =>
+ SM.insert (fmap, x, coverage p)) SM.empty xps]
fun merge (c1, c2) =
case (c1, c2) of
(None, _) => c2
@@ -1061,12 +1101,84 @@ fun exhaustive (env, denv, t, ps) =
| (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
+ | (Record fm1, Record fm2) => Record (fm1 @ fm2)
+
+ | _ => None
+
fun combinedCoverage ps =
case ps of
[] => raise Fail "Empty pattern list for coverage checking"
| [p] => coverage p
| p :: ps => merge (coverage p, combinedCoverage ps)
+ fun enumerateCases t =
+ let
+ fun dtype cons =
+ ListUtil.mapConcat (fn (_, n, to) =>
+ case to of
+ NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
+ | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
+ (enumerateCases t)) cons
+ in
+ case #1 (#1 (hnormCon (env, denv) t)) of
+ L'.CNamed n =>
+ (let
+ val dt = E.lookupDatatype env n
+ val cons = E.constructors dt
+ in
+ dtype cons
+ end handle E.UnboundNamed _ => [Wild])
+ | L'.TRecord c =>
+ (case #1 (#1 (hnormCon (env, denv) c)) of
+ L'.CRecord (_, xts) =>
+ let
+ val xts = map (fn (x, t) => (#1 (hnormCon (env, denv) x), t)) xts
+
+ fun exponentiate fs =
+ case fs of
+ [] => [SM.empty]
+ | ((L'.CName x, _), t) :: rest =>
+ let
+ val this = enumerateCases t
+ val rest = exponentiate rest
+ in
+ ListUtil.mapConcat (fn fmap =>
+ map (fn c => SM.insert (fmap, x, c)) this) rest
+ end
+ | _ => raise Fail "exponentiate: Not CName"
+ in
+ if List.exists (fn ((L'.CName _, _), _) => false
+ | (c, _) => true) xts then
+ [Wild]
+ else
+ map (fn ls => Record [ls]) (exponentiate xts)
+ end
+ | _ => [Wild])
+ | _ => [Wild]
+ end
+
+ fun coverageImp (c1, c2) =
+ case (c1, c2) of
+ (Wild, _) => true
+
+ | (Datatype cmap1, Datatype cmap2) =>
+ List.all (fn (n, c2) =>
+ case IM.find (cmap1, n) of
+ NONE => false
+ | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
+
+ | (Record fmaps1, Record fmaps2) =>
+ List.all (fn fmap2 =>
+ List.exists (fn fmap1 =>
+ List.all (fn (x, c2) =>
+ case SM.find (fmap1, x) of
+ NONE => true
+ | SOME c1 => coverageImp (c1, c2))
+ (SM.listItemsi fmap2))
+ fmaps1) fmaps2
+
+ | _ => false
+
fun isTotal (c, t) =
case c of
None => (false, [])
@@ -1109,6 +1221,7 @@ fun exhaustive (env, denv, t, ps) =
| L'.CError => (true, gs)
| _ => raise Fail "isTotal: Not a datatype"
end
+ | Record _ => (List.all (fn c2 => coverageImp (c, c2)) (enumerateCases t), [])
in
isTotal (combinedCoverage ps, t)
end