summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Adam Chlipala <adamc@hcoop.net>2008-07-31 13:08:57 -0400
committerGravatar Adam Chlipala <adamc@hcoop.net>2008-07-31 13:08:57 -0400
commitd668886a45158cf3a292fdef3fa81498efd77652 (patch)
tree8b94f33c7f9d49dfb2a8b7b65cee62097fcf9630
parent183c43eb783edd68f76f941fa61b6ef1f8752a56 (diff)
Elaborating record patterns
-rw-r--r--src/elab.sml1
-rw-r--r--src/elab_print.sml13
-rw-r--r--src/elaborate.sml125
-rw-r--r--src/lacweb.grm4
-rw-r--r--src/source_print.sml6
-rw-r--r--tests/rpat.lac13
6 files changed, 150 insertions, 12 deletions
diff --git a/src/elab.sml b/src/elab.sml
index 34e0c91c..af21def0 100644
--- a/src/elab.sml
+++ b/src/elab.sml
@@ -80,6 +80,7 @@ datatype pat' =
| PVar of string
| PPrim of Prim.t
| PCon of patCon * pat option
+ | PRecord of (string * pat) list * con option
withtype pat = pat' located
diff --git a/src/elab_print.sml b/src/elab_print.sml
index 0e7fe7d7..d0ff8d5f 100644
--- a/src/elab_print.sml
+++ b/src/elab_print.sml
@@ -220,8 +220,19 @@ fun p_pat' par env (p, _) =
| PCon (pc, SOME p) => parenIf par (box [p_patCon env pc,
space,
p_pat' true env p])
+ | PRecord (xps, flex) =>
+ let
+ val pps = map (fn (x, p) => box [string x, space, string "=", space, p_pat env p]) xps
+ in
+ box [string "{",
+ p_list_sep (box [string ",", space]) (fn x => x)
+ (case flex of
+ NONE => pps
+ | SOME _ => pps @ [string "..."]),
+ string "}"]
+ end
-val p_pat = p_pat' false
+and p_pat x = p_pat' false x
fun p_exp' par env (e, _) =
case e of
diff --git a/src/elaborate.sml b/src/elaborate.sml
index 10c5d214..0a71aa8c 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -38,10 +38,14 @@ open Print
open ElabPrint
structure IM = IntBinaryMap
-structure SS = BinarySetFn(struct
- type ord_key = string
- val compare = String.compare
- end)
+
+structure SK = struct
+type ord_key = string
+val compare = String.compare
+end
+
+structure SS = BinarySetFn(SK)
+structure SM = BinaryMapFn(SK)
fun elabExplicitness e =
case e of
@@ -816,6 +820,7 @@ datatype exp_error =
| PatHasArg of ErrorMsg.span
| PatHasNoArg of ErrorMsg.span
| Inexhaustive of ErrorMsg.span
+ | DuplicatePatField of ErrorMsg.span * string
fun expError env err =
case err of
@@ -856,6 +861,8 @@ fun expError env err =
ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
| Inexhaustive loc =>
ErrorMsg.errorAt loc "Inexhaustive 'case'"
+ | DuplicatePatField (loc, s) =>
+ ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
fun checkCon (env, denv) e c1 c2 =
unifyCons (env, denv) c1 c2
@@ -1021,13 +1028,45 @@ fun elabPat (pAll as (p, loc), (env, denv, bound)) =
| SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn)
end)
- | L.PRecord _ => raise Fail "Elaborate PRecord"
+ | L.PRecord (xps, flex) =>
+ let
+ val (xpts, (env, bound, _)) =
+ ListUtil.foldlMap (fn ((x, p), (env, bound, fbound)) =>
+ let
+ val ((p', t), (env, bound)) = elabPat (p, (env, denv, bound))
+ in
+ if SS.member (fbound, x) then
+ expError env (DuplicatePatField (loc, x))
+ else
+ ();
+ ((x, p', t), (env, bound, SS.add (fbound, x)))
+ end)
+ (env, bound, SS.empty) xps
+
+ val k = (L'.KType, loc)
+ val c = (L'.CRecord (k, map (fn (x, _, t) => ((L'.CName x, loc), t)) xpts), loc)
+ val (flex, c) =
+ if flex then
+ let
+ val flex = cunif (loc, (L'.KRecord k, loc))
+ in
+ (SOME flex, (L'.CConcat (c, flex), loc))
+ end
+ else
+ (NONE, c)
+ in
+ (((L'.PRecord (map (fn (x, p', _) => (x, p')) xpts, flex), loc),
+ (L'.TRecord c, loc)),
+ (env, bound))
+ end
+
end
datatype coverage =
Wild
| None
| Datatype of coverage IM.map
+ | Record of coverage SM.map list
fun exhaustive (env, denv, t, ps) =
let
@@ -1050,7 +1089,8 @@ fun exhaustive (env, denv, t, ps) =
| L'.PPrim _ => None
| L'.PCon (pc, NONE) => Datatype (IM.insert (IM.empty, pcCoverage pc, Wild))
| L'.PCon (pc, SOME p) => Datatype (IM.insert (IM.empty, pcCoverage pc, coverage p))
-
+ | L'.PRecord (xps, _) => Record [foldl (fn ((x, p), fmap) =>
+ SM.insert (fmap, x, coverage p)) SM.empty xps]
fun merge (c1, c2) =
case (c1, c2) of
(None, _) => c2
@@ -1061,12 +1101,84 @@ fun exhaustive (env, denv, t, ps) =
| (Datatype cm1, Datatype cm2) => Datatype (IM.unionWith merge (cm1, cm2))
+ | (Record fm1, Record fm2) => Record (fm1 @ fm2)
+
+ | _ => None
+
fun combinedCoverage ps =
case ps of
[] => raise Fail "Empty pattern list for coverage checking"
| [p] => coverage p
| p :: ps => merge (coverage p, combinedCoverage ps)
+ fun enumerateCases t =
+ let
+ fun dtype cons =
+ ListUtil.mapConcat (fn (_, n, to) =>
+ case to of
+ NONE => [Datatype (IM.insert (IM.empty, n, Wild))]
+ | SOME t => map (fn c => Datatype (IM.insert (IM.empty, n, c)))
+ (enumerateCases t)) cons
+ in
+ case #1 (#1 (hnormCon (env, denv) t)) of
+ L'.CNamed n =>
+ (let
+ val dt = E.lookupDatatype env n
+ val cons = E.constructors dt
+ in
+ dtype cons
+ end handle E.UnboundNamed _ => [Wild])
+ | L'.TRecord c =>
+ (case #1 (#1 (hnormCon (env, denv) c)) of
+ L'.CRecord (_, xts) =>
+ let
+ val xts = map (fn (x, t) => (#1 (hnormCon (env, denv) x), t)) xts
+
+ fun exponentiate fs =
+ case fs of
+ [] => [SM.empty]
+ | ((L'.CName x, _), t) :: rest =>
+ let
+ val this = enumerateCases t
+ val rest = exponentiate rest
+ in
+ ListUtil.mapConcat (fn fmap =>
+ map (fn c => SM.insert (fmap, x, c)) this) rest
+ end
+ | _ => raise Fail "exponentiate: Not CName"
+ in
+ if List.exists (fn ((L'.CName _, _), _) => false
+ | (c, _) => true) xts then
+ [Wild]
+ else
+ map (fn ls => Record [ls]) (exponentiate xts)
+ end
+ | _ => [Wild])
+ | _ => [Wild]
+ end
+
+ fun coverageImp (c1, c2) =
+ case (c1, c2) of
+ (Wild, _) => true
+
+ | (Datatype cmap1, Datatype cmap2) =>
+ List.all (fn (n, c2) =>
+ case IM.find (cmap1, n) of
+ NONE => false
+ | SOME c1 => coverageImp (c1, c2)) (IM.listItemsi cmap2)
+
+ | (Record fmaps1, Record fmaps2) =>
+ List.all (fn fmap2 =>
+ List.exists (fn fmap1 =>
+ List.all (fn (x, c2) =>
+ case SM.find (fmap1, x) of
+ NONE => true
+ | SOME c1 => coverageImp (c1, c2))
+ (SM.listItemsi fmap2))
+ fmaps1) fmaps2
+
+ | _ => false
+
fun isTotal (c, t) =
case c of
None => (false, [])
@@ -1109,6 +1221,7 @@ fun exhaustive (env, denv, t, ps) =
| L'.CError => (true, gs)
| _ => raise Fail "isTotal: Not a datatype"
end
+ | Record _ => (List.all (fn c2 => coverageImp (c, c2)) (enumerateCases t), [])
in
isTotal (combinedCoverage ps, t)
end
diff --git a/src/lacweb.grm b/src/lacweb.grm
index a1067aa6..cc68d380 100644
--- a/src/lacweb.grm
+++ b/src/lacweb.grm
@@ -356,9 +356,9 @@ pterm : SYMBOL (PVar SYMBOL, s (SYMBOLleft, SYMBOLright
| UNIT (PRecord ([], false), s (UNITleft, UNITright))
| LBRACE rpat RBRACE (PRecord rpat, s (LBRACEleft, RBRACEright))
-rpat : STRING EQ pat ([(STRING, pat)], false)
+rpat : CSYMBOL EQ pat ([(CSYMBOL, pat)], false)
| DOTDOTDOT ([], true)
- | STRING EQ pat COMMA rpat ((STRING, pat) :: #1 rpat, #2 rpat)
+ | CSYMBOL EQ pat COMMA rpat ((CSYMBOL, pat) :: #1 rpat, #2 rpat)
rexp : ([])
| ident EQ eexp ([(ident, eexp)])
diff --git a/src/source_print.sml b/src/source_print.sml
index 93416fd3..960f3ac5 100644
--- a/src/source_print.sml
+++ b/src/source_print.sml
@@ -173,14 +173,14 @@ fun p_pat' par (p, _) =
p_pat' true p])
| PRecord (xps, flex) =>
let
- val pps = map (fn (x, p) => box [string "x", space, string "=", space, p_pat p]) xps
+ val pps = map (fn (x, p) => box [string x, space, string "=", space, p_pat p]) xps
in
box [string "{",
p_list_sep (box [string ",", space]) (fn x => x)
(if flex then
- pps
+ pps @ [string "..."]
else
- pps @ [string "..."]),
+ pps),
string "}"]
end
diff --git a/tests/rpat.lac b/tests/rpat.lac
new file mode 100644
index 00000000..6c4f9c5e
--- /dev/null
+++ b/tests/rpat.lac
@@ -0,0 +1,13 @@
+val f = fn x : {A : int} => case x of {A = _} => 0
+val f = fn x : {A : int} => case x of {A = _, ...} => 0
+val f = fn x : {A : int, B : int} => case x of {A = _, ...} => 0
+val f = fn x : {A : int, B : int} => case x of {A = 1, B = 2} => 0 | {A = _, ...} => 1
+
+datatype t = A | B
+
+val f = fn x => case x of {A = A, B = 2} => 0 | {A = A, ...} => 0 | {A = B, ...} => 0
+
+val f = fn x => case x of {A = {A = A, ...}, B = B} => 0
+ | {B = A, ...} => 1
+ | {A = {A = B, B = A}, B = B} => 2
+ | {A = {A = B, B = B}, B = B} => 3