summaryrefslogtreecommitdiff
path: root/src/elaborate.sml
diff options
context:
space:
mode:
authorGravatar Adam Chlipala <adamc@hcoop.net>2009-05-30 14:44:29 -0400
committerGravatar Adam Chlipala <adamc@hcoop.net>2009-05-30 14:44:29 -0400
commitbd0d3db78996b00e153252f03b02551ac3fde4cf (patch)
tree0fad0449847add724cfee07969e95597f7136ea8 /src/elaborate.sml
parent54276f5a38163eb7997c574810faed0cc6dea35c (diff)
Defer pattern-matching exhaustiveness checks and normalize pattern types more thoroughly
Diffstat (limited to 'src/elaborate.sml')
-rw-r--r--src/elaborate.sml279
1 files changed, 50 insertions, 229 deletions
diff --git a/src/elaborate.sml b/src/elaborate.sml
index b4ce1861..961c2b2e 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -625,6 +625,8 @@
val mayDelay = ref false
val delayedUnifs = ref ([] : (ErrorMsg.span * E.env * L'.kind * record_summary * record_summary) list)
+ val delayedExhaustives = ref ([] : (E.env * L'.con * L'.pat list * ErrorMsg.span) list)
+
fun unifyRecordCons env (loc, c1, c2) =
let
fun rkindof c =
@@ -1398,11 +1400,12 @@ fun exhaustive (env, t, ps, loc) =
val loc = #2 q1
fun unapp (t, acc) =
- case t of
- L'.CApp ((t, _), arg) => unapp (t, arg :: acc)
+ case #1 t of
+ L'.CApp (t, arg) => unapp (t, arg :: acc)
| _ => (t, rev acc)
- val (t1, args) = unapp (#1 (hnormCon env q1), [])
+ val (t1, args) = unapp (hnormCon env q1, [])
+ val t1 = hnormCon env t1
fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args
fun dtype (dtO, names) =
@@ -1423,7 +1426,7 @@ fun exhaustive (env, t, ps, loc) =
fun default () = (NONE, IS.singleton 0, [])
val (dtO, unused, cons) =
- case t1 of
+ case #1 t1 of
L'.CNamed n =>
let
val dt = E.lookupDatatype env n
@@ -1452,22 +1455,25 @@ fun exhaustive (env, t, ps, loc) =
NONE => []
| SOME t => [("", doSub t)])) cons)
end
- | L'.TRecord (L'.CRecord (_, xts), _) =>
- let
- val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
- | _ => NONE) xts
- in
- if List.all Option.isSome xts then
- let
- val xts = List.mapPartial (fn x => x) xts
- val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
- String.compare (x1, x2) = GREATER) xts
- in
- (NONE, IS.empty, [(0, xts)])
- end
- else
- default ()
- end
+ | L'.TRecord t =>
+ (case #1 (hnormCon env t) of
+ L'.CRecord (_, xts) =>
+ let
+ val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
+ | _ => NONE) xts
+ in
+ if List.all Option.isSome xts then
+ let
+ val xts = List.mapPartial (fn x => x) xts
+ val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
+ String.compare (x1, x2) = GREATER) xts
+ in
+ (NONE, IS.empty, [(0, xts)])
+ end
+ else
+ default ()
+ end
+ | _ => default ())
| _ => default ()
in
if IS.isEmpty unused then
@@ -1489,7 +1495,7 @@ fun exhaustive (env, t, ps, loc) =
0 => L'.PRecord (ListPair.map
(fn ((name, t), p) => (name, p, t))
(args, argPs))
- | _ => L'.PCon (L'.Default, nameOfNum (t1, name), [],
+ | _ => L'.PCon (L'.Default, nameOfNum (#1 t1, name), [],
case argPs of
[] => NONE
| [p] => SOME p
@@ -1518,9 +1524,9 @@ fun exhaustive (env, t, ps, loc) =
else
NONE) cons of
SOME (n, []) =>
- L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE)
+ L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], NONE)
| SOME (n, [_]) =>
- L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc))
+ L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], SOME (L'.PWild, loc))
| _ => fail 7
in
SOME ((p, loc) :: ps)
@@ -1533,211 +1539,6 @@ fun exhaustive (env, t, ps, loc) =
| _ => fail 7
end
-(*datatype coverage =
- Wild
- | None
- | Datatype of coverage IM.map
- | Record of coverage SM.map list
-
-fun c2s c =
- case c of
- Wild => "Wild"
- | None => "None"
- | Datatype _ => "Datatype"
- | Record _ => "Record"
-
-fun exhaustive (env, t, ps, loc) =
- let
- fun depth (p, _) =
- case p of
- L'.PWild => 0
- | L'.PVar _ => 0
- | L'.PPrim _ => 0
- | L'.PCon (_, _, _, NONE) => 1
- | L'.PCon (_, _, _, SOME p) => 1 + depth p
- | L'.PRecord xps => foldl (fn ((_, p, _), n) => Int.max (depth p, n)) 0 xps
-
- val depth = 1 + foldl (fn (p, n) => Int.max (depth p, n)) 0 ps
-
- fun pcCoverage pc =
- case pc of
- L'.PConVar n => n
- | L'.PConProj (m1, ms, x) =>
- let
- val (str, sgn) = E.chaseMpath env (m1, ms)
- in
- case E.projectConstructor env {str = str, sgn = sgn, field = x} of
- NONE => raise Fail "exhaustive: Can't project constructor"
- | SOME (_, n, _, _, _) => n
- end
-
- fun coverage (p, _) =
- case p of
- L'.PWild => Wild
- | L'.PVar _ => Wild
- | L'.PPrim _ => None
- | L'.PCon (_, pc, _, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
- | L'.PCon (_, pc, _, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
- | L'.PRecord xps => Record [foldl (fn ((x, p, _), fmap) =>
- SM.insert (fmap, x, coverage p)) SM.empty xps]
-
- fun merge (c1, c2) =
- case (c1, c2) of
- (None, _) => c2
- | (_, None) => c1
-
- | (Wild, _) => Wild
- | (_, Wild) => Wild
-
- | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
-
- | (Record fm1, Record fm2) => Record (fm1 @ fm2)
-
- | _ => None
-
- fun combinedCoverage ps =
- case ps of
- [] => raise Fail "Empty pattern list for coverage checking"
- | [p] => coverage p
- | p :: ps => merge (coverage p, combinedCoverage ps)
-
- fun enumerateCases depth t =
- (TextIO.print "enum'\n"; if depth <= 0 then
- [Wild]
- else
- let
- val dtype =
- ListUtil.mapConcat (fn (_, n, to) =>
- case to of
- NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
- | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
- (enumerateCases (depth-1) t))
- in
- case #1 (hnormCon env t) of
- L'.CNamed n =>
- (let
- val dt = E.lookupDatatype env n
- val cons = E.constructors dt
- in
- dtype cons
- end handle E.UnboundNamed _ => [Wild])
- | L'.TRecord c =>
- (case #1 (hnormCon env c) of
- L'.CRecord (_, xts) =>
- let
- val xts = map (fn (x, t) => (hnormCon env x, t)) xts
-
- fun exponentiate fs =
- case fs of
- [] => [SM.empty]
- | ((L'.CName x, _), t) :: rest =>
- let
- val this = enumerateCases depth t
- val rest = exponentiate rest
- in
- TextIO.print ("Before (" ^ Int.toString (length this)
- ^ ", " ^ Int.toString (length rest) ^ ")\n");
- ListUtil.mapConcat (fn fmap =>
- map (fn c => SM.insert (fmap, x, c)) this) rest
- before TextIO.print "After\n"
- end
- | _ => raise Fail "exponentiate: Not CName"
- in
- if List.exists (fn ((L'.CName _, _), _) => false
- | (c, _) => true) xts then
- [Wild]
- else
- map (fn ls => Record [ls]) (exponentiate xts)
- end
- | _ => [Wild])
- | _ => [Wild]
- end before TextIO.print "/enum'\n")
-
- fun coverageImp (c1, c2) =
- let
- val r =
- case (c1, c2) of
- (Wild, _) => true
-
- | (Datatype cmap1, Datatype cmap2) =>
- List.all (fn (n, c2) =>
- case IM.find (cmap1, n) of
- NONE => false
- | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
- | (Datatype cmap1, Wild) =>
- List.all (fn (n, c1) => coverageImp (c1, Wild)) (IM.listItemsi cmap1)
-
- | (Record fmaps1, Record fmaps2) =>
- List.all (fn fmap2 =>
- List.exists (fn fmap1 =>
- List.all (fn (x, c2) =>
- case SM.find (fmap1, x) of
- NONE => true
- | SOME c1 => coverageImp (c1, c2))
- (SM.listItemsi fmap2))
- fmaps1) fmaps2
-
- | (Record fmaps1, Wild) =>
- List.exists (fn fmap1 =>
- List.all (fn (x, c1) => coverageImp (c1, Wild))
- (SM.listItemsi fmap1)) fmaps1
-
- | _ => false
- in
- (*TextIO.print ("coverageImp(" ^ c2s c1 ^ ", " ^ c2s c2 ^ ") = " ^ Bool.toString r ^ "\n");*)
- r
- end
-
- fun isTotal (c, t) =
- case c of
- None => false
- | Wild => true
- | Datatype cm =>
- let
- val (t, _) = hnormCon env t
-
- val dtype =
- List.all (fn (_, n, to) =>
- case IM.find (cm, n) of
- NONE => false
- | SOME c' =>
- case to of
- NONE => true
- | SOME t' => isTotal (c', t'))
-
- fun unapp t =
- case t of
- L'.CApp ((t, _), _) => unapp t
- | _ => t
- in
- case unapp t of
- L'.CNamed n =>
- let
- val dt = E.lookupDatatype env n
- val cons = E.constructors dt
- in
- dtype cons
- end
- | L'.CModProj (m1, ms, x) =>
- let
- val (str, sgn) = E.chaseMpath env (m1, ms)
- in
- case E.projectDatatype env {str = str, sgn = sgn, field = x} of
- NONE => raise Fail "isTotal: Can't project datatype"
- | SOME (_, cons) => dtype cons
- end
- | L'.CError => true
- | c =>
- (prefaces "Not a datatype" [("loc", PD.string (ErrorMsg.spanToString loc)),
- ("c", p_con env (c, ErrorMsg.dummySpan))];
- raise Fail "isTotal: Not a datatype")
- end
- | Record _ => List.all (fn c2 => coverageImp (c, c2))
- (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n")
- in
- isTotal (combinedCoverage ps, t)
- end*)
-
fun unmodCon env (c, loc) =
case c of
L'.CNamed n =>
@@ -2083,7 +1884,10 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
in
case exhaustive (env, et, map #1 pes', loc) of
NONE => ()
- | SOME p => expError env (Inexhaustive (loc, p));
+ | SOME p => if !mayDelay then
+ delayedExhaustives := (env, et, map #1 pes', loc) :: !delayedExhaustives
+ else
+ expError env (Inexhaustive (loc, p));
((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs)
end
@@ -2113,6 +1917,13 @@ and elabEdecl denv (dAll as (d, loc), (env, gs)) =
val pt = normClassConstraint env pt
in
+ case exhaustive (env, et, [p'], loc) of
+ NONE => ()
+ | SOME p => if !mayDelay then
+ delayedExhaustives := (env, et, [p'], loc) :: !delayedExhaustives
+ else
+ expError env (Inexhaustive (loc, p));
+
((L'.EDVal (p', pt, e'), loc), (env', gs1 @ gs))
end
| L.EDValRec vis =>
@@ -3956,6 +3767,7 @@ fun elabFile basis topStr topSgn env file =
let
val () = mayDelay := true
val () = delayedUnifs := []
+ val () = delayedExhaustives := []
val (sgn, gs) = elabSgn (env, D.empty) (L.SgnConst basis, ErrorMsg.dummySpan)
val () = case gs of
@@ -4153,6 +3965,15 @@ fun elabFile basis topStr topSgn env file =
else
app (fn f => f ()) (!checks);
+ if ErrorMsg.anyErrors () then
+ ()
+ else
+ app (fn all as (_, _, _, loc) =>
+ case exhaustive all of
+ NONE => ()
+ | SOME p => expError env (Inexhaustive (loc, p)))
+ (!delayedExhaustives);
+
(*preface ("file", p_file env' file);*)
(L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan)