From bd0d3db78996b00e153252f03b02551ac3fde4cf Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Sat, 30 May 2009 14:44:29 -0400 Subject: Defer pattern-matching exhaustiveness checks and normalize pattern types more thoroughly --- lib/ur/list.ur | 13 +++ lib/ur/list.urs | 5 +- src/elaborate.sml | 279 ++++++++++-------------------------------------------- 3 files changed, 67 insertions(+), 230 deletions(-) diff --git a/lib/ur/list.ur b/lib/ur/list.ur index 8493f2f5..0776ff30 100644 --- a/lib/ur/list.ur +++ b/lib/ur/list.ur @@ -74,6 +74,19 @@ fun mapM [m ::: (Type -> Type)] (_ : monad m) [a] [b] f = mapM' [] end +fun mapXM [m ::: (Type -> Type)] (_ : monad m) [a] [ctx ::: {Unit}] f = + let + fun mapXM' ls = + case ls of + [] => return + | x :: ls => + this <- f x; + rest <- mapXM' ls; + return {this}{rest} + in + mapXM' + end + fun filter [a] f = let fun fil acc ls = diff --git a/lib/ur/list.urs b/lib/ur/list.urs index 28f08317..92589508 100644 --- a/lib/ur/list.urs +++ b/lib/ur/list.urs @@ -15,7 +15,10 @@ val mapPartial : a ::: Type -> b ::: Type -> (a -> option b) -> t a -> t b val mapX : a ::: Type -> ctx ::: {Unit} -> (a -> xml ctx [] []) -> t a -> xml ctx [] [] val mapM : m ::: (Type -> Type) -> monad m -> a ::: Type -> b ::: Type - -> (a -> m b) -> list a -> m (list b) + -> (a -> m b) -> t a -> m (t b) + +val mapXM : m ::: (Type -> Type) -> monad m -> a ::: Type -> ctx ::: {Unit} + -> (a -> m (xml ctx [] [])) -> t a -> m (xml ctx [] []) val filter : a ::: Type -> (a -> bool) -> t a -> t a diff --git a/src/elaborate.sml b/src/elaborate.sml index b4ce1861..961c2b2e 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -625,6 +625,8 @@ val mayDelay = ref false val delayedUnifs = ref ([] : (ErrorMsg.span * E.env * L'.kind * record_summary * record_summary) list) + val delayedExhaustives = ref ([] : (E.env * L'.con * L'.pat list * ErrorMsg.span) list) + fun unifyRecordCons env (loc, c1, c2) = let fun rkindof c = @@ -1398,11 +1400,12 @@ fun exhaustive (env, t, ps, loc) = val loc = #2 q1 fun unapp (t, acc) = - case t of - L'.CApp ((t, _), arg) => unapp (t, arg :: acc) + case #1 t of + L'.CApp (t, arg) => unapp (t, arg :: acc) | _ => (t, rev acc) - val (t1, args) = unapp (#1 (hnormCon env q1), []) + val (t1, args) = unapp (hnormCon env q1, []) + val t1 = hnormCon env t1 fun doSub t = foldl (fn (arg, t) => subConInCon (0, arg) t) t args fun dtype (dtO, names) = @@ -1423,7 +1426,7 @@ fun exhaustive (env, t, ps, loc) = fun default () = (NONE, IS.singleton 0, []) val (dtO, unused, cons) = - case t1 of + case #1 t1 of L'.CNamed n => let val dt = E.lookupDatatype env n @@ -1452,22 +1455,25 @@ fun exhaustive (env, t, ps, loc) = NONE => [] | SOME t => [("", doSub t)])) cons) end - | L'.TRecord (L'.CRecord (_, xts), _) => - let - val xts = map (fn ((L'.CName x, _), co) => SOME (x, co) - | _ => NONE) xts - in - if List.all Option.isSome xts then - let - val xts = List.mapPartial (fn x => x) xts - val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) => - String.compare (x1, x2) = GREATER) xts - in - (NONE, IS.empty, [(0, xts)]) - end - else - default () - end + | L'.TRecord t => + (case #1 (hnormCon env t) of + L'.CRecord (_, xts) => + let + val xts = map (fn ((L'.CName x, _), co) => SOME (x, co) + | _ => NONE) xts + in + if List.all Option.isSome xts then + let + val xts = List.mapPartial (fn x => x) xts + val xts = ListMergeSort.sort (fn ((x1, _), (x2, _)) => + String.compare (x1, x2) = GREATER) xts + in + (NONE, IS.empty, [(0, xts)]) + end + else + default () + end + | _ => default ()) | _ => default () in if IS.isEmpty unused then @@ -1489,7 +1495,7 @@ fun exhaustive (env, t, ps, loc) = 0 => L'.PRecord (ListPair.map (fn ((name, t), p) => (name, p, t)) (args, argPs)) - | _ => L'.PCon (L'.Default, nameOfNum (t1, name), [], + | _ => L'.PCon (L'.Default, nameOfNum (#1 t1, name), [], case argPs of [] => NONE | [p] => SOME p @@ -1518,9 +1524,9 @@ fun exhaustive (env, t, ps, loc) = else NONE) cons of SOME (n, []) => - L'.PCon (L'.Default, nameOfNum (t1, n), [], NONE) + L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], NONE) | SOME (n, [_]) => - L'.PCon (L'.Default, nameOfNum (t1, n), [], SOME (L'.PWild, loc)) + L'.PCon (L'.Default, nameOfNum (#1 t1, n), [], SOME (L'.PWild, loc)) | _ => fail 7 in SOME ((p, loc) :: ps) @@ -1533,211 +1539,6 @@ fun exhaustive (env, t, ps, loc) = | _ => fail 7 end -(*datatype coverage = - Wild - | None - | Datatype of coverage IM.map - | Record of coverage SM.map list - -fun c2s c = - case c of - Wild => "Wild" - | None => "None" - | Datatype _ => "Datatype" - | Record _ => "Record" - -fun exhaustive (env, t, ps, loc) = - let - fun depth (p, _) = - case p of - L'.PWild => 0 - | L'.PVar _ => 0 - | L'.PPrim _ => 0 - | L'.PCon (_, _, _, NONE) => 1 - | L'.PCon (_, _, _, SOME p) => 1 + depth p - | L'.PRecord xps => foldl (fn ((_, p, _), n) => Int.max (depth p, n)) 0 xps - - val depth = 1 + foldl (fn (p, n) => Int.max (depth p, n)) 0 ps - - fun pcCoverage pc = - case pc of - L'.PConVar n => n - | L'.PConProj (m1, ms, x) => - let - val (str, sgn) = E.chaseMpath env (m1, ms) - in - case E.projectConstructor env {str = str, sgn = sgn, field = x} of - NONE => raise Fail "exhaustive: Can't project constructor" - | SOME (_, n, _, _, _) => n - end - - fun coverage (p, _) = - case p of - L'.PWild => Wild - | L'.PVar _ => Wild - | L'.PPrim _ => None - | L'.PCon (_, pc, _, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild)) - | L'.PCon (_, pc, _, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p)) - | L'.PRecord xps => Record [foldl (fn ((x, p, _), fmap) => - SM.insert (fmap, x, coverage p)) SM.empty xps] - - fun merge (c1, c2) = - case (c1, c2) of - (None, _) => c2 - | (_, None) => c1 - - | (Wild, _) => Wild - | (_, Wild) => Wild - - | (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2)) - - | (Record fm1, Record fm2) => Record (fm1 @ fm2) - - | _ => None - - fun combinedCoverage ps = - case ps of - [] => raise Fail "Empty pattern list for coverage checking" - | [p] => coverage p - | p :: ps => merge (coverage p, combinedCoverage ps) - - fun enumerateCases depth t = - (TextIO.print "enum'\n"; if depth <= 0 then - [Wild] - else - let - val dtype = - ListUtil.mapConcat (fn (_, n, to) => - case to of - NONE => [Datatype (IM.insert (IM.empty, n, Wild))] - | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c))) - (enumerateCases (depth-1) t)) - in - case #1 (hnormCon env t) of - L'.CNamed n => - (let - val dt = E.lookupDatatype env n - val cons = E.constructors dt - in - dtype cons - end handle E.UnboundNamed _ => [Wild]) - | L'.TRecord c => - (case #1 (hnormCon env c) of - L'.CRecord (_, xts) => - let - val xts = map (fn (x, t) => (hnormCon env x, t)) xts - - fun exponentiate fs = - case fs of - [] => [SM.empty] - | ((L'.CName x, _), t) :: rest => - let - val this = enumerateCases depth t - val rest = exponentiate rest - in - TextIO.print ("Before (" ^ Int.toString (length this) - ^ ", " ^ Int.toString (length rest) ^ ")\n"); - ListUtil.mapConcat (fn fmap => - map (fn c => SM.insert (fmap, x, c)) this) rest - before TextIO.print "After\n" - end - | _ => raise Fail "exponentiate: Not CName" - in - if List.exists (fn ((L'.CName _, _), _) => false - | (c, _) => true) xts then - [Wild] - else - map (fn ls => Record [ls]) (exponentiate xts) - end - | _ => [Wild]) - | _ => [Wild] - end before TextIO.print "/enum'\n") - - fun coverageImp (c1, c2) = - let - val r = - case (c1, c2) of - (Wild, _) => true - - | (Datatype cmap1, Datatype cmap2) => - List.all (fn (n, c2) => - case IM.find (cmap1, n) of - NONE => false - | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2) - | (Datatype cmap1, Wild) => - List.all (fn (n, c1) => coverageImp (c1, Wild)) (IM.listItemsi cmap1) - - | (Record fmaps1, Record fmaps2) => - List.all (fn fmap2 => - List.exists (fn fmap1 => - List.all (fn (x, c2) => - case SM.find (fmap1, x) of - NONE => true - | SOME c1 => coverageImp (c1, c2)) - (SM.listItemsi fmap2)) - fmaps1) fmaps2 - - | (Record fmaps1, Wild) => - List.exists (fn fmap1 => - List.all (fn (x, c1) => coverageImp (c1, Wild)) - (SM.listItemsi fmap1)) fmaps1 - - | _ => false - in - (*TextIO.print ("coverageImp(" ^ c2s c1 ^ ", " ^ c2s c2 ^ ") = " ^ Bool.toString r ^ "\n");*) - r - end - - fun isTotal (c, t) = - case c of - None => false - | Wild => true - | Datatype cm => - let - val (t, _) = hnormCon env t - - val dtype = - List.all (fn (_, n, to) => - case IM.find (cm, n) of - NONE => false - | SOME c' => - case to of - NONE => true - | SOME t' => isTotal (c', t')) - - fun unapp t = - case t of - L'.CApp ((t, _), _) => unapp t - | _ => t - in - case unapp t of - L'.CNamed n => - let - val dt = E.lookupDatatype env n - val cons = E.constructors dt - in - dtype cons - end - | L'.CModProj (m1, ms, x) => - let - val (str, sgn) = E.chaseMpath env (m1, ms) - in - case E.projectDatatype env {str = str, sgn = sgn, field = x} of - NONE => raise Fail "isTotal: Can't project datatype" - | SOME (_, cons) => dtype cons - end - | L'.CError => true - | c => - (prefaces "Not a datatype" [("loc", PD.string (ErrorMsg.spanToString loc)), - ("c", p_con env (c, ErrorMsg.dummySpan))]; - raise Fail "isTotal: Not a datatype") - end - | Record _ => List.all (fn c2 => coverageImp (c, c2)) - (TextIO.print "enum\n"; enumerateCases depth t before TextIO.print "/enum\n") - in - isTotal (combinedCoverage ps, t) - end*) - fun unmodCon env (c, loc) = case c of L'.CNamed n => @@ -2083,7 +1884,10 @@ fun elabExp (env, denv) (eAll as (e, loc)) = in case exhaustive (env, et, map #1 pes', loc) of NONE => () - | SOME p => expError env (Inexhaustive (loc, p)); + | SOME p => if !mayDelay then + delayedExhaustives := (env, et, map #1 pes', loc) :: !delayedExhaustives + else + expError env (Inexhaustive (loc, p)); ((L'.ECase (e', pes', {disc = et, result = result}), loc), result, gs) end @@ -2113,6 +1917,13 @@ and elabEdecl denv (dAll as (d, loc), (env, gs)) = val pt = normClassConstraint env pt in + case exhaustive (env, et, [p'], loc) of + NONE => () + | SOME p => if !mayDelay then + delayedExhaustives := (env, et, [p'], loc) :: !delayedExhaustives + else + expError env (Inexhaustive (loc, p)); + ((L'.EDVal (p', pt, e'), loc), (env', gs1 @ gs)) end | L.EDValRec vis => @@ -3956,6 +3767,7 @@ fun elabFile basis topStr topSgn env file = let val () = mayDelay := true val () = delayedUnifs := [] + val () = delayedExhaustives := [] val (sgn, gs) = elabSgn (env, D.empty) (L.SgnConst basis, ErrorMsg.dummySpan) val () = case gs of @@ -4153,6 +3965,15 @@ fun elabFile basis topStr topSgn env file = else app (fn f => f ()) (!checks); + if ErrorMsg.anyErrors () then + () + else + app (fn all as (_, _, _, loc) => + case exhaustive all of + NONE => () + | SOME p => expError env (Inexhaustive (loc, p))) + (!delayedExhaustives); + (*preface ("file", p_file env' file);*) (L'.DFfiStr ("Basis", basis_n, sgn), ErrorMsg.dummySpan) -- cgit v1.2.3