From cabd451f495af6f122b77c61903cc17ee7832d71 Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Sat, 23 May 2009 09:45:02 -0400 Subject: Switch to Maranget's pattern exhaustiveness algorithm --- src/elab_err.sig | 2 +- src/elab_err.sml | 7 +- src/elaborate.sml | 260 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 252 insertions(+), 17 deletions(-) diff --git a/src/elab_err.sig b/src/elab_err.sig index c0a90e19..10cda7d3 100644 --- a/src/elab_err.sig +++ b/src/elab_err.sig @@ -71,7 +71,7 @@ signature ELAB_ERR = sig | UnboundConstructor of ErrorMsg.span * string list * string | PatHasArg of ErrorMsg.span | PatHasNoArg of ErrorMsg.span - | Inexhaustive of ErrorMsg.span + | Inexhaustive of ErrorMsg.span * Elab.pat | DuplicatePatField of ErrorMsg.span * string | Unresolvable of ErrorMsg.span * Elab.con | OutOfContext of ErrorMsg.span * (Elab.exp * Elab.con) option diff --git a/src/elab_err.sml b/src/elab_err.sml index 9eafa7df..dc34560b 100644 --- a/src/elab_err.sml +++ b/src/elab_err.sml @@ -161,7 +161,7 @@ datatype exp_error = | UnboundConstructor of ErrorMsg.span * string list * string | PatHasArg of ErrorMsg.span | PatHasNoArg of ErrorMsg.span - | Inexhaustive of ErrorMsg.span + | Inexhaustive of ErrorMsg.span * pat | DuplicatePatField of ErrorMsg.span * string | Unresolvable of ErrorMsg.span * con | OutOfContext of ErrorMsg.span * (exp * con) option @@ -207,8 +207,9 @@ fun expError env err = ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument" | PatHasNoArg loc => ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument" - | Inexhaustive loc => - ErrorMsg.errorAt loc "Inexhaustive 'case'" + | Inexhaustive (loc, p) => + (ErrorMsg.errorAt loc "Inexhaustive 'case'"; + eprefaces' [("Missed case", p_pat env p)]) | DuplicatePatField (loc, s) => ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern") | OutOfContext (loc, co) => diff --git a/src/elaborate.sml b/src/elaborate.sml index 8b23d91e..fb376df2 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -38,6 +38,7 @@ open ElabPrint open ElabErr + structure IS = IntBinarySet structure IM = IntBinaryMap structure SK = struct @@ -1291,7 +1292,238 @@ fun elabPat (pAll as (p, loc), (env, bound)) = end -datatype coverage = +(* This exhaustiveness checking follows Luc Maranget's paper "Warnings for pattern matching." *) +fun exhaustive (env, t, ps, loc) = + let + fun fail n = raise Fail ("Elaborate.exhaustive: Impossible " ^ Int.toString n) + + fun patConNum pc = + case pc of + L'.PConVar n => n + | L'.PConProj (m1, ms, x) => + let + val (str, sgn) = E.chaseMpath env (m1, ms) + in + case E.projectConstructor env {str = str, sgn = sgn, field = x} of + NONE => raise Fail "exhaustive: Can't project datatype" + | SOME (_, n, _, _, _) => n + end + + fun nameOfNum (t, n) = + case t of + L'.CModProj (m1, ms, x) => + let + val (str, sgn) = E.chaseMpath env (m1, ms) + in + case E.projectDatatype env {str = str, sgn = sgn, field = x} of + NONE => raise Fail "exhaustive: Can't project datatype" + | SOME (_, cons) => + case ListUtil.search (fn (name, n', _) => + if n' = n then + SOME name + else + NONE) cons of + NONE => fail 9 + | SOME name => L'.PConProj (m1, ms, name) + end + | _ => L'.PConVar n + + fun S (args, c, P) = + List.mapPartial + (fn [] => fail 1 + | p1 :: ps => + let + val loc = #2 p1 + + fun wild () = + SOME (map (fn _ => (L'.PWild, loc)) args @ ps) + in + case #1 p1 of + L'.PPrim _ => NONE + | L'.PCon (_, c', _, NONE) => + if patConNum c' = c then + SOME ps + else + NONE + | L'.PCon (_, c', _, SOME p) => + if patConNum c' = c then + SOME (p :: ps) + else + NONE + | L'.PRecord xpts => + SOME (map (fn x => + case ListUtil.search (fn (x', p, _) => + if x = x' then + SOME p + else + NONE) xpts of + NONE => (L'.PWild, loc) + | SOME p => p) args @ ps) + | L'.PWild => wild () + | L'.PVar _ => wild () + end) + P + + fun D P = + List.mapPartial + (fn [] => fail 2 + | (p1, _) :: ps => + case p1 of + L'.PWild => SOME ps + | L'.PVar _ => SOME ps + | L'.PPrim _ => NONE + | L'.PCon _ => NONE + | L'.PRecord _ => NONE) + P + + fun I (P, q) = + (*(prefaces "I" [("P", p_list (fn P' => box [PD.string "[", p_list (p_pat env) P', PD.string "]"]) P), + ("q", p_list (p_con env) q)];*) + case q of + [] => (case P of + [] => SOME [] + | _ => NONE) + | q1 :: qs => + let + val loc = #2 q1 + + fun unapp (t, acc) = + case t of + L'.CApp ((t, _), arg) => unapp (t, arg :: acc) + | _ => (t, rev acc) + + val (t1, args) = unapp (#1 (hnormCon env q1), []) + fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args + + fun dtype (dtO, names) = + let + val nameSet = IS.addList (IS.empty, names) + val nameSet = foldl (fn (ps, nameSet) => + case ps of + [] => fail 4 + | (L'.PCon (_, pc, _, _), _) :: _ => + (IS.delete (nameSet, patConNum pc) + handle NotFound => nameSet) + | _ => nameSet) + nameSet P + in + nameSet + end + + fun default () = (NONE, IS.singleton 0, []) + + val (dtO, unused, cons) = + case t1 of + L'.CNamed n => + let + val dt = E.lookupDatatype env n + val cons = E.constructors dt + in + (SOME dt, + dtype (SOME dt, map #2 cons), + map (fn (_, n, co) => + (n, + case co of + NONE => [] + | SOME t => [("", doSub t)])) cons) + end + | L'.CModProj (m1, ms, x) => + let + val (str, sgn) = E.chaseMpath env (m1, ms) + in + case E.projectDatatype env {str = str, sgn = sgn, field = x} of + NONE => default () + | SOME (_, cons) => + (NONE, + dtype (NONE, map #2 cons), + map (fn (s, _, co) => + (patConNum (L'.PConProj (m1, ms, s)), + case co of + NONE => [] + | SOME t => [("", doSub t)])) cons) + end + | L'.TRecord (L'.CRecord (_, xts), _) => + let + val xts = map (fn ((L'.CName x, _), co) => SOME (x, co) + | _ => NONE) xts + in + if List.all Option.isSome xts then + let + val xts = List.mapPartial (fn x => x) xts + val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) => + String.compare (x1, x2) = GREATER) xts + in + (NONE, IS.empty, [(0, xts)]) + end + else + default () + end + | _ => default () + in + if IS.isEmpty unused then + let + fun recurse cons = + case cons of + [] => NONE + | (name, args) :: cons => + case I (S (map #1 args, name, P), + map #2 args @ qs) of + NONE => recurse cons + | SOME ps => + let + val nargs = length args + val argPs = List.take (ps, nargs) + val restPs = List.drop (ps, nargs) + + val p = case name of + 0 => L'.PRecord (ListPair.map + (fn ((name, t), p) => (name, p, t)) + (args, argPs)) + | _ => L'.PCon (L'.Default, nameOfNum (t1, name), [], + case argPs of + [] => NONE + | [p] => SOME p + | _ => fail 3) + in + SOME ((p, loc) :: restPs) + end + in + recurse cons + end + else + case I (D P, qs) of + NONE => NONE + | SOME ps => + let + val p = case cons of + [] => L'.PWild + | (0, _) :: _ => L'.PWild + | _ => + case IS.find (fn _ => true) unused of + NONE => fail 6 + | SOME name => + case ListUtil.search (fn (name', args) => + if name = name' then + SOME (name', args) + else + NONE) cons of + SOME (n, []) => + L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE) + | SOME (n, [_]) => + L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc)) + | _ => fail 7 + in + SOME ((p, loc) :: ps) + end + end + in + case I (map (fn x => [x]) ps, [t]) of + NONE => NONE + | SOME [p] => SOME p + | _ => fail 7 + end + +(*datatype coverage = Wild | None | Datatype of coverage IM.map @@ -1360,16 +1592,16 @@ fun exhaustive (env, t, ps, loc) = | p :: ps => merge (coverage p, combinedCoverage ps) fun enumerateCases depth t = - if depth = 0 then + (TextIO.print "enum'\n"; if depth <= 0 then [Wild] else let - fun dtype cons = + val dtype = ListUtil.mapConcat (fn (_, n, to) => case to of NONE => [Datatype (IM.insert (IM.empty, n, Wild))] | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c))) - (enumerateCases (depth-1) t)) cons + (enumerateCases (depth-1) t)) in case #1 (hnormCon env t) of L'.CNamed n => @@ -1393,8 +1625,11 @@ fun exhaustive (env, t, ps, loc) = val this = enumerateCases depth t val rest = exponentiate rest in + TextIO.print ("Before (" ^ Int.toString (length this) + ^ ", " ^ Int.toString (length rest) ^ ")\n"); ListUtil.mapConcat (fn fmap => map (fn c => SM.insert (fmap, x, c)) this) rest + before TextIO.print "After\n" end | _ => raise Fail "exponentiate: Not CName" in @@ -1406,7 +1641,7 @@ fun exhaustive (env, t, ps, loc) = end | _ => [Wild]) | _ => [Wild] - end + end before TextIO.print "/enum'\n") fun coverageImp (c1, c2) = let @@ -1487,10 +1722,11 @@ fun exhaustive (env, t, ps, loc) = ("c", p_con env (c, ErrorMsg.dummySpan))]; raise Fail "isTotal: Not a datatype") end - | Record _ => List.all (fn c2 => coverageImp (c, c2)) (enumerateCases depth t) + | Record _ => List.all (fn c2 => coverageImp (c, c2)) + (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n") in isTotal (combinedCoverage ps, t) - end + end*) fun unmodCon env (c, loc) = case c of @@ -1835,10 +2071,9 @@ fun elabExp (env, denv) (eAll as (e, loc)) = end) gs1 pes in - if exhaustive (env, et, map #1 pes', loc) then - () - else - expError env (Inexhaustive loc); + case exhaustive (env, et, map #1 pes', loc) of + NONE => () + | SOME p => expError env (Inexhaustive (loc, p)); ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs) end @@ -1851,8 +2086,7 @@ fun elabExp (env, denv) (eAll as (e, loc)) = ((L'.ELet (eds, e), loc), t, gs1 @ gs2) end in - (*prefaces "elabExp" [("e", SourcePrint.p_exp eAll), - ("t", PD.string (LargeReal.toString (Time.toReal (Time.- (Time.now (), befor)))))];*) + (*prefaces "/elabExp" [("e", SourcePrint.p_exp eAll)];*) r end -- cgit v1.2.3