summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/elab_env.sig2
-rw-r--r--src/elab_env.sml19
-rw-r--r--src/elaborate.sml86
-rw-r--r--src/lacweb.grm10
-rw-r--r--src/lacweb.lex1
-rw-r--r--src/source.sml1
-rw-r--r--src/source_print.sml16
-rw-r--r--tests/caseMod.lac19
8 files changed, 126 insertions, 28 deletions
diff --git a/src/elab_env.sig b/src/elab_env.sig
index ff45f056..229436ec 100644
--- a/src/elab_env.sig
+++ b/src/elab_env.sig
@@ -88,6 +88,8 @@ signature ELAB_ENV = sig
val projectCon : env -> { sgn : Elab.sgn, str : Elab.str, field : string } -> (Elab.kind * Elab.con option) option
val projectDatatype : env -> { sgn : Elab.sgn, str : Elab.str, field : string }
-> (string * int * Elab.con option) list option
+ val projectConstructor : env -> { sgn : Elab.sgn, str : Elab.str, field : string }
+ -> (int * Elab.con option * Elab.con) option
val projectVal : env -> { sgn : Elab.sgn, str : Elab.str, field : string } -> Elab.con option
val projectSgn : env -> { sgn : Elab.sgn, str : Elab.str, field : string } -> Elab.sgn option
val projectStr : env -> { sgn : Elab.sgn, str : Elab.str, field : string } -> Elab.sgn option
diff --git a/src/elab_env.sml b/src/elab_env.sml
index 5b716730..720b19da 100644
--- a/src/elab_env.sml
+++ b/src/elab_env.sml
@@ -570,6 +570,25 @@ fun projectDatatype env {sgn, str, field} =
| SOME (xncs, subs) => SOME (map (fn (x, n, to) => (x, n, Option.map (sgnSubCon (str, subs)) to)) xncs))
| _ => NONE
+fun projectConstructor env {sgn, str, field} =
+ case #1 (hnormSgn env sgn) of
+ SgnConst sgis =>
+ let
+ fun consider (n, xncs) =
+ ListUtil.search (fn (x, n', to) =>
+ if x <> field then
+ NONE
+ else
+ SOME (n', to, (CNamed n, #2 str))) xncs
+ in
+ case sgnSeek (fn SgiDatatype (_, n, xncs) => consider (n, xncs)
+ | SgiDatatypeImp (_, n, _, _, _, xncs) => consider (n, xncs)
+ | _ => NONE) sgis of
+ NONE => NONE
+ | SOME ((n, to, t), subs) => SOME (n, Option.map (sgnSubCon (str, subs)) to, sgnSubCon (str, subs) t)
+ end
+ | _ => NONE
+
fun projectVal env {sgn, str, field} =
case #1 (hnormSgn env sgn) of
SgnConst sgis =>
diff --git a/src/elaborate.sml b/src/elaborate.sml
index e15ef185..10c5d214 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -812,7 +812,7 @@ datatype exp_error =
| IncompatibleCons of L'.con * L'.con
| DuplicatePatternVariable of ErrorMsg.span * string
| PatUnify of L'.pat * L'.con * L'.con * cunify_error
- | UnboundConstructor of ErrorMsg.span * string
+ | UnboundConstructor of ErrorMsg.span * string list * string
| PatHasArg of ErrorMsg.span
| PatHasNoArg of ErrorMsg.span
| Inexhaustive of ErrorMsg.span
@@ -848,8 +848,8 @@ fun expError env err =
("Have con", p_con env c1),
("Need con", p_con env c2)];
cunifyError env uerr)
- | UnboundConstructor (loc, s) =>
- ErrorMsg.errorAt loc ("Unbound constructor " ^ s ^ " in pattern")
+ | UnboundConstructor (loc, ms, s) =>
+ ErrorMsg.errorAt loc ("Unbound constructor " ^ String.concatWith "." (ms @ [s]) ^ " in pattern")
| PatHasArg loc =>
ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
| PatHasNoArg loc =>
@@ -958,7 +958,7 @@ fun elabHead (env, denv) (e as (_, loc)) t =
unravel (t, e)
end
-fun elabPat (pAll as (p, loc), (env, bound)) =
+fun elabPat (pAll as (p, loc), (env, denv, bound)) =
let
val perror = (L'.PWild, loc)
val terror = (L'.CError, loc)
@@ -972,13 +972,13 @@ fun elabPat (pAll as (p, loc), (env, bound)) =
rerror)
| (SOME _, NONE) => (expError env (PatHasArg loc);
rerror)
- | (NONE, NONE) => (((L'.PCon (pc, NONE), loc), (L'.CNamed dn, loc)),
+ | (NONE, NONE) => (((L'.PCon (pc, NONE), loc), dn),
(env, bound))
| (SOME p, SOME t) =>
let
- val ((p', pt), (env, bound)) = elabPat (p, (env, bound))
+ val ((p', pt), (env, bound)) = elabPat (p, (env, denv, bound))
in
- (((L'.PCon (pc, SOME p'), loc), (L'.CNamed dn, loc)),
+ (((L'.PCon (pc, SOME p'), loc), dn),
(env, bound))
end
in
@@ -1000,10 +1000,28 @@ fun elabPat (pAll as (p, loc), (env, bound)) =
(env, bound))
| L.PCon ([], x, po) =>
(case E.lookupConstructor env x of
- NONE => (expError env (UnboundConstructor (loc, x));
+ NONE => (expError env (UnboundConstructor (loc, [], x));
rerror)
- | SOME (n, to, dn) => pcon (L'.PConVar n, po, to, dn))
- | L.PCon _ => raise Fail "uhoh"
+ | SOME (n, to, dn) => pcon (L'.PConVar n, po, to, (L'.CNamed dn, loc)))
+ | L.PCon (m1 :: ms, x, po) =>
+ (case E.lookupStr env m1 of
+ NONE => (expError env (UnboundStrInExp (loc, m1));
+ rerror)
+ | SOME (n, sgn) =>
+ let
+ val (str, sgn) = foldl (fn (m, (str, sgn)) =>
+ case E.projectStr env {sgn = sgn, str = str, field = m} of
+ NONE => raise Fail "typeof: Unknown substructure"
+ | SOME sgn => ((L'.StrProj (str, m), loc), sgn))
+ ((L'.StrVar n, loc), sgn) ms
+ in
+ case E.projectConstructor env {str = str, sgn = sgn, field = x} of
+ NONE => (expError env (UnboundConstructor (loc, m1 :: ms, x));
+ rerror)
+ | SOME (_, to, dn) => pcon (L'.PConProj (n, ms, x), po, to, dn)
+ end)
+
+ | L.PRecord _ => raise Fail "Elaborate PRecord"
end
datatype coverage =
@@ -1016,7 +1034,14 @@ fun exhaustive (env, denv, t, ps) =
fun pcCoverage pc =
case pc of
L'.PConVar n => n
- | _ => raise Fail "uh oh^2"
+ | L'.PConProj (m1, ms, x) =>
+ let
+ val (str, sgn) = E.chaseMpath env (m1, ms)
+ in
+ case E.projectConstructor env {str = str, sgn = sgn, field = x} of
+ NONE => raise Fail "exhaustive: Can't project constructor"
+ | SOME (n, _, _) => n
+ end
fun coverage (p, _) =
case p of
@@ -1049,6 +1074,21 @@ fun exhaustive (env, denv, t, ps) =
| Datatype cm =>
let
val ((t, _), gs) = hnormCon (env, denv) t
+
+ fun dtype cons =
+ foldl (fn ((_, n, to), (total, gs)) =>
+ case IM.find (cm, n) of
+ NONE => (false, gs)
+ | SOME c' =>
+ case to of
+ NONE => (total, gs)
+ | SOME t' =>
+ let
+ val (total, gs') = isTotal (c', t')
+ in
+ (total, gs' @ gs)
+ end)
+ (true, gs) cons
in
case t of
L'.CNamed n =>
@@ -1056,19 +1096,15 @@ fun exhaustive (env, denv, t, ps) =
val dt = E.lookupDatatype env n
val cons = E.constructors dt
in
- foldl (fn ((_, n, to), (total, gs)) =>
- case IM.find (cm, n) of
- NONE => (false, gs)
- | SOME c' =>
- case to of
- NONE => (total, gs)
- | SOME t' =>
- let
- val (total, gs') = isTotal (c', t')
- in
- (total, gs' @ gs)
- end)
- (true, gs) cons
+ dtype cons
+ end
+ | L'.CModProj (m1, ms, x) =>
+ let
+ val (str, sgn) = E.chaseMpath env (m1, ms)
+ in
+ case E.projectDatatype env {str = str, sgn = sgn, field = x} of
+ NONE => raise Fail "isTotal: Can't project datatype"
+ | SOME cons => dtype cons
end
| L'.CError => (true, gs)
| _ => raise Fail "isTotal: Not a datatype"
@@ -1295,7 +1331,7 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
val (pes', gs) = ListUtil.foldlMap
(fn ((p, e), gs) =>
let
- val ((p', pt), (env, _)) = elabPat (p, (env, SS.empty))
+ val ((p', pt), (env, _)) = elabPat (p, (env, denv, SS.empty))
val gs1 = checkPatCon (env, denv) p' pt et
val (e', et, gs2) = elabExp (env, denv) e
diff --git a/src/lacweb.grm b/src/lacweb.grm
index 817171a6..a1067aa6 100644
--- a/src/lacweb.grm
+++ b/src/lacweb.grm
@@ -43,7 +43,7 @@ fun uppercaseFirst "" = ""
| SYMBOL of string | CSYMBOL of string
| LPAREN | RPAREN | LBRACK | RBRACK | LBRACE | RBRACE
| EQ | COMMA | COLON | DCOLON | TCOLON | DOT | HASH | UNDER | UNDERUNDER | BAR
- | DIVIDE | GT
+ | DIVIDE | GT | DOTDOTDOT
| CON | LTYPE | VAL | REC | AND | FOLD | UNIT | KUNIT
| DATATYPE | OF
| TYPE | NAME
@@ -104,6 +104,7 @@ fun uppercaseFirst "" = ""
| branchs of (pat * exp) list
| pat of pat
| pterm of pat
+ | rpat of (string * pat) list * bool
| attrs of (con * exp) list
| attr of con * exp
@@ -351,6 +352,13 @@ pterm : SYMBOL (PVar SYMBOL, s (SYMBOLleft, SYMBOLright
| INT (PPrim (Prim.Int INT), s (INTleft, INTright))
| STRING (PPrim (Prim.String STRING), s (STRINGleft, STRINGright))
| LPAREN pat RPAREN (pat)
+ | LBRACE RBRACE (PRecord ([], false), s (LBRACEleft, RBRACEright))
+ | UNIT (PRecord ([], false), s (UNITleft, UNITright))
+ | LBRACE rpat RBRACE (PRecord rpat, s (LBRACEleft, RBRACEright))
+
+rpat : STRING EQ pat ([(STRING, pat)], false)
+ | DOTDOTDOT ([], true)
+ | STRING EQ pat COMMA rpat ((STRING, pat) :: #1 rpat, #2 rpat)
rexp : ([])
| ident EQ eexp ([(ident, eexp)])
diff --git a/src/lacweb.lex b/src/lacweb.lex
index 862d5d31..b62edcc6 100644
--- a/src/lacweb.lex
+++ b/src/lacweb.lex
@@ -242,6 +242,7 @@ notags = [^<{\n]+;
<INITIAL> ":::" => (Tokens.TCOLON (pos yypos, pos yypos + size yytext));
<INITIAL> "::" => (Tokens.DCOLON (pos yypos, pos yypos + size yytext));
<INITIAL> ":" => (Tokens.COLON (pos yypos, pos yypos + size yytext));
+<INITIAL> "..." => (Tokens.DOTDOTDOT (pos yypos, pos yypos + size yytext));
<INITIAL> "." => (Tokens.DOT (pos yypos, pos yypos + size yytext));
<INITIAL> "$" => (Tokens.DOLLAR (pos yypos, pos yypos + size yytext));
<INITIAL> "#" => (Tokens.HASH (pos yypos, pos yypos + size yytext));
diff --git a/src/source.sml b/src/source.sml
index 3dbada25..d58feb94 100644
--- a/src/source.sml
+++ b/src/source.sml
@@ -94,6 +94,7 @@ datatype pat' =
| PVar of string
| PPrim of Prim.t
| PCon of string list * string * pat option
+ | PRecord of (string * pat) list * bool
withtype pat = pat' located
diff --git a/src/source_print.sml b/src/source_print.sml
index 68ef3508..93416fd3 100644
--- a/src/source_print.sml
+++ b/src/source_print.sml
@@ -171,8 +171,20 @@ fun p_pat' par (p, _) =
| PCon (ms, x, SOME p) => parenIf par (box [p_list_sep (string ".") string (ms @ [x]),
space,
p_pat' true p])
-
-val p_pat = p_pat' false
+ | PRecord (xps, flex) =>
+ let
+ val pps = map (fn (x, p) => box [string "x", space, string "=", space, p_pat p]) xps
+ in
+ box [string "{",
+ p_list_sep (box [string ",", space]) (fn x => x)
+ (if flex then
+ pps
+ else
+ pps @ [string "..."]),
+ string "}"]
+ end
+
+and p_pat x = p_pat' false x
fun p_exp' par (e, _) =
case e of
diff --git a/tests/caseMod.lac b/tests/caseMod.lac
new file mode 100644
index 00000000..2c6fbc80
--- /dev/null
+++ b/tests/caseMod.lac
@@ -0,0 +1,19 @@
+structure M = struct
+ datatype t = A | B
+end
+
+val f = fn x : M.t => case x of M.A => M.B | M.B => M.A
+
+datatype t = datatype M.t
+
+val g = fn x : t => case x of M.A => B | B => M.A
+
+structure N = struct
+ datatype t = C of t | D
+end
+
+val h = fn x : N.t => case x of N.C x => x | N.D => M.A
+
+datatype u = datatype N.t
+
+val i = fn x : u => case x of N.C x => x | D => M.A