summaryrefslogtreecommitdiff
path: root/src/elaborate.sml
diff options
context:
space:
mode:
authorGravatar Adam Chlipala <adamc@hcoop.net>2009-05-23 09:45:02 -0400
committerGravatar Adam Chlipala <adamc@hcoop.net>2009-05-23 09:45:02 -0400
commitcabd451f495af6f122b77c61903cc17ee7832d71 (patch)
tree61c0180868573d29dd2b29c2c065539368c79dc7 /src/elaborate.sml
parent32f6bd8f1bcf65a5db96160d63ef2050c9eb5e52 (diff)
Switch to Maranget's pattern exhaustiveness algorithm
Diffstat (limited to 'src/elaborate.sml')
-rw-r--r--src/elaborate.sml260
1 files changed, 247 insertions, 13 deletions
diff --git a/src/elaborate.sml b/src/elaborate.sml
index 8b23d91e..fb376df2 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -38,6 +38,7 @@
open ElabPrint
open ElabErr
+ structure IS = IntBinarySet
structure IM = IntBinaryMap
structure SK = struct
@@ -1291,7 +1292,238 @@ fun elabPat (pAll as (p, loc), (env, bound)) =
end
-datatype coverage =
+(* This exhaustiveness checking follows Luc Maranget's paper "Warnings for pattern matching." *)
+fun exhaustive (env, t, ps, loc) =
+ let
+ fun fail n = raise Fail ("Elaborate.exhaustive: Impossible " ^ Int.toString n)
+
+ fun patConNum pc =
+ case pc of
+ L'.PConVar n => n
+ | L'.PConProj (m1, ms, x) =>
+ let
+ val (str, sgn) = E.chaseMpath env (m1, ms)
+ in
+ case E.projectConstructor env {str = str, sgn = sgn, field = x} of
+ NONE => raise Fail "exhaustive: Can't project datatype"
+ | SOME (_, n, _, _, _) => n
+ end
+
+ fun nameOfNum (t, n) =
+ case t of
+ L'.CModProj (m1, ms, x) =>
+ let
+ val (str, sgn) = E.chaseMpath env (m1, ms)
+ in
+ case E.projectDatatype env {str = str, sgn = sgn, field = x} of
+ NONE => raise Fail "exhaustive: Can't project datatype"
+ | SOME (_, cons) =>
+ case ListUtil.search (fn (name, n', _) =>
+ if n' = n then
+ SOME name
+ else
+ NONE) cons of
+ NONE => fail 9
+ | SOME name => L'.PConProj (m1, ms, name)
+ end
+ | _ => L'.PConVar n
+
+ fun S (args, c, P) =
+ List.mapPartial
+ (fn [] => fail 1
+ | p1 :: ps =>
+ let
+ val loc = #2 p1
+
+ fun wild () =
+ SOME (map (fn _ => (L'.PWild, loc)) args @ ps)
+ in
+ case #1 p1 of
+ L'.PPrim _ => NONE
+ | L'.PCon (_, c', _, NONE) =>
+ if patConNum c' = c then
+ SOME ps
+ else
+ NONE
+ | L'.PCon (_, c', _, SOME p) =>
+ if patConNum c' = c then
+ SOME (p :: ps)
+ else
+ NONE
+ | L'.PRecord xpts =>
+ SOME (map (fn x =>
+ case ListUtil.search (fn (x', p, _) =>
+ if x = x' then
+ SOME p
+ else
+ NONE) xpts of
+ NONE => (L'.PWild, loc)
+ | SOME p => p) args @ ps)
+ | L'.PWild => wild ()
+ | L'.PVar _ => wild ()
+ end)
+ P
+
+ fun D P =
+ List.mapPartial
+ (fn [] => fail 2
+ | (p1, _) :: ps =>
+ case p1 of
+ L'.PWild => SOME ps
+ | L'.PVar _ => SOME ps
+ | L'.PPrim _ => NONE
+ | L'.PCon _ => NONE
+ | L'.PRecord _ => NONE)
+ P
+
+ fun I (P, q) =
+ (*(prefaces "I" [("P", p_list (fn P' => box [PD.string "[", p_list (p_pat env) P', PD.string "]"]) P),
+ ("q", p_list (p_con env) q)];*)
+ case q of
+ [] => (case P of
+ [] => SOME []
+ | _ => NONE)
+ | q1 :: qs =>
+ let
+ val loc = #2 q1
+
+ fun unapp (t, acc) =
+ case t of
+ L'.CApp ((t, _), arg) => unapp (t, arg :: acc)
+ | _ => (t, rev acc)
+
+ val (t1, args) = unapp (#1 (hnormCon env q1), [])
+ fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args
+
+ fun dtype (dtO, names) =
+ let
+ val nameSet = IS.addList (IS.empty, names)
+ val nameSet = foldl (fn (ps, nameSet) =>
+ case ps of
+ [] => fail 4
+ | (L'.PCon (_, pc, _, _), _) :: _ =>
+ (IS.delete (nameSet, patConNum pc)
+ handle NotFound => nameSet)
+ | _ => nameSet)
+ nameSet P
+ in
+ nameSet
+ end
+
+ fun default () = (NONE, IS.singleton 0, [])
+
+ val (dtO, unused, cons) =
+ case t1 of
+ L'.CNamed n =>
+ let
+ val dt = E.lookupDatatype env n
+ val cons = E.constructors dt
+ in
+ (SOME dt,
+ dtype (SOME dt, map #2 cons),
+ map (fn (_, n, co) =>
+ (n,
+ case co of
+ NONE => []
+ | SOME t => [("", doSub t)])) cons)
+ end
+ | L'.CModProj (m1, ms, x) =>
+ let
+ val (str, sgn) = E.chaseMpath env (m1, ms)
+ in
+ case E.projectDatatype env {str = str, sgn = sgn, field = x} of
+ NONE => default ()
+ | SOME (_, cons) =>
+ (NONE,
+ dtype (NONE, map #2 cons),
+ map (fn (s, _, co) =>
+ (patConNum (L'.PConProj (m1, ms, s)),
+ case co of
+ NONE => []
+ | SOME t => [("", doSub t)])) cons)
+ end
+ | L'.TRecord (L'.CRecord (_, xts), _) =>
+ let
+ val xts = map (fn ((L'.CName x, _), co) => SOME (x, co)
+ | _ => NONE) xts
+ in
+ if List.all Option.isSome xts then
+ let
+ val xts = List.mapPartial (fn x => x) xts
+ val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) =>
+ String.compare (x1, x2) = GREATER) xts
+ in
+ (NONE, IS.empty, [(0, xts)])
+ end
+ else
+ default ()
+ end
+ | _ => default ()
+ in
+ if IS.isEmpty unused then
+ let
+ fun recurse cons =
+ case cons of
+ [] => NONE
+ | (name, args) :: cons =>
+ case I (S (map #1 args, name, P),
+ map #2 args @ qs) of
+ NONE => recurse cons
+ | SOME ps =>
+ let
+ val nargs = length args
+ val argPs = List.take (ps, nargs)
+ val restPs = List.drop (ps, nargs)
+
+ val p = case name of
+ 0 => L'.PRecord (ListPair.map
+ (fn ((name, t), p) => (name, p, t))
+ (args, argPs))
+ | _ => L'.PCon (L'.Default, nameOfNum (t1, name), [],
+ case argPs of
+ [] => NONE
+ | [p] => SOME p
+ | _ => fail 3)
+ in
+ SOME ((p, loc) :: restPs)
+ end
+ in
+ recurse cons
+ end
+ else
+ case I (D P, qs) of
+ NONE => NONE
+ | SOME ps =>
+ let
+ val p = case cons of
+ [] => L'.PWild
+ | (0, _) :: _ => L'.PWild
+ | _ =>
+ case IS.find (fn _ => true) unused of
+ NONE => fail 6
+ | SOME name =>
+ case ListUtil.search (fn (name', args) =>
+ if name = name' then
+ SOME (name', args)
+ else
+ NONE) cons of
+ SOME (n, []) =>
+ L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE)
+ | SOME (n, [_]) =>
+ L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc))
+ | _ => fail 7
+ in
+ SOME ((p, loc) :: ps)
+ end
+ end
+ in
+ case I (map (fn x => [x]) ps, [t]) of
+ NONE => NONE
+ | SOME [p] => SOME p
+ | _ => fail 7
+ end
+
+(*datatype coverage =
Wild
| None
| Datatype of coverage IM.map
@@ -1360,16 +1592,16 @@ fun exhaustive (env, t, ps, loc) =
| p :: ps => merge (coverage p, combinedCoverage ps)
fun enumerateCases depth t =
- if depth = 0 then
+ (TextIO.print "enum'\n"; if depth <= 0 then
[Wild]
else
let
- fun dtype cons =
+ val dtype =
ListUtil.mapConcat (fn (_, n, to) =>
case to of
NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
| SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
- (enumerateCases (depth-1) t)) cons
+ (enumerateCases (depth-1) t))
in
case #1 (hnormCon env t) of
L'.CNamed n =>
@@ -1393,8 +1625,11 @@ fun exhaustive (env, t, ps, loc) =
val this = enumerateCases depth t
val rest = exponentiate rest
in
+ TextIO.print ("Before (" ^ Int.toString (length this)
+ ^ ", " ^ Int.toString (length rest) ^ ")\n");
ListUtil.mapConcat (fn fmap =>
map (fn c => SM.insert (fmap, x, c)) this) rest
+ before TextIO.print "After\n"
end
| _ => raise Fail "exponentiate: Not CName"
in
@@ -1406,7 +1641,7 @@ fun exhaustive (env, t, ps, loc) =
end
| _ => [Wild])
| _ => [Wild]
- end
+ end before TextIO.print "/enum'\n")
fun coverageImp (c1, c2) =
let
@@ -1487,10 +1722,11 @@ fun exhaustive (env, t, ps, loc) =
("c", p_con env (c, ErrorMsg.dummySpan))];
raise Fail "isTotal: Not a datatype")
end
- | Record _ => List.all (fn c2 => coverageImp (c, c2)) (enumerateCases depth t)
+ | Record _ => List.all (fn c2 => coverageImp (c, c2))
+ (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n")
in
isTotal (combinedCoverage ps, t)
- end
+ end*)
fun unmodCon env (c, loc) =
case c of
@@ -1835,10 +2071,9 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
end)
gs1 pes
in
- if exhaustive (env, et, map #1 pes', loc) then
- ()
- else
- expError env (Inexhaustive loc);
+ case exhaustive (env, et, map #1 pes', loc) of
+ NONE => ()
+ | SOME p => expError env (Inexhaustive (loc, p));
((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs)
end
@@ -1851,8 +2086,7 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
((L'.ELet (eds, e), loc), t, gs1 @ gs2)
end
in
- (*prefaces "elabExp" [("e", SourcePrint.p_exp eAll),
- ("t", PD.string (LargeReal.toString (Time.toReal (Time.- (Time.now (), befor)))))];*)
+ (*prefaces "/elabExp" [("e", SourcePrint.p_exp eAll)];*)
r
end