summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/elab.sml13
-rw-r--r--src/elab_env.sig4
-rw-r--r--src/elab_env.sml15
-rw-r--r--src/elab_print.sig1
-rw-r--r--src/elab_print.sml45
-rw-r--r--src/elab_util.sml11
-rw-r--r--src/elaborate.sml95
-rw-r--r--src/explify.sml2
-rw-r--r--src/source_print.sml2
-rw-r--r--tests/case.lac6
10 files changed, 189 insertions, 5 deletions
diff --git a/src/elab.sml b/src/elab.sml
index b258d7e5..48790a15 100644
--- a/src/elab.sml
+++ b/src/elab.sml
@@ -71,6 +71,17 @@ datatype con' =
withtype con = con' located
+datatype patCon =
+ PConVar of int
+ | PConProj of int * string list * string
+
+datatype pat' =
+ PWild
+ | PVar of string
+ | PCon of patCon * pat option
+
+withtype pat = pat' located
+
datatype exp' =
EPrim of Prim.t
| ERel of int
@@ -86,6 +97,8 @@ datatype exp' =
| ECut of exp * con * { field : con, rest : con }
| EFold of kind
+ | ECase of exp * (pat * exp) list * con
+
| EError
withtype exp = exp' located
diff --git a/src/elab_env.sig b/src/elab_env.sig
index 0afa3114..ff45f056 100644
--- a/src/elab_env.sig
+++ b/src/elab_env.sig
@@ -54,9 +54,11 @@ signature ELAB_ENV = sig
val pushDatatype : env -> int -> (string * int * Elab.con option) list -> env
type datatyp
val lookupDatatype : env -> int -> datatyp
- val lookupConstructor : datatyp -> int -> string * Elab.con option
+ val lookupDatatypeConstructor : datatyp -> int -> string * Elab.con option
val constructors : datatyp -> (string * int * Elab.con option) list
+ val lookupConstructor : env -> string -> (int * Elab.con option * int) option
+
val pushERel : env -> string -> Elab.con -> env
val lookupERel : env -> int -> string * Elab.con
diff --git a/src/elab_env.sml b/src/elab_env.sml
index 1fe5dd5a..5b716730 100644
--- a/src/elab_env.sml
+++ b/src/elab_env.sml
@@ -81,6 +81,7 @@ type env = {
namedC : (string * kind * con option) IM.map,
datatypes : datatyp IM.map,
+ constructors : (int * con option * int) SM.map,
renameE : con var' SM.map,
relE : (string * con) list,
@@ -109,6 +110,7 @@ val empty = {
namedC = IM.empty,
datatypes = IM.empty,
+ constructors = SM.empty,
renameE = SM.empty,
relE = [],
@@ -131,6 +133,7 @@ fun pushCRel (env : env) x k =
namedC = IM.map (fn (x, k, co) => (x, k, Option.map lift co)) (#namedC env),
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = #renameE env,
relE = map (fn (x, c) => (x, lift c)) (#relE env),
@@ -154,6 +157,7 @@ fun pushCNamedAs (env : env) x n k co =
namedC = IM.insert (#namedC env, n, (x, k, co)),
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = #renameE env,
relE = #relE env,
@@ -192,6 +196,9 @@ fun pushDatatype (env : env) n xncs =
datatypes = IM.insert (#datatypes env, n,
foldl (fn ((x, n, to), cons) =>
IM.insert (cons, n, (x, to))) IM.empty xncs),
+ constructors = foldl (fn ((x, n', to), cmap) =>
+ SM.insert (cmap, x, (n', to, n)))
+ (#constructors env) xncs,
renameE = #renameE env,
relE = #relE env,
@@ -208,11 +215,13 @@ fun lookupDatatype (env : env) n =
NONE => raise UnboundNamed n
| SOME x => x
-fun lookupConstructor dt n =
+fun lookupDatatypeConstructor dt n =
case IM.find (dt, n) of
NONE => raise UnboundNamed n
| SOME x => x
+fun lookupConstructor (env : env) s = SM.find (#constructors env, s)
+
fun constructors dt = IM.foldri (fn (n, (x, to), ls) => (x, n, to) :: ls) [] dt
fun pushERel (env : env) x t =
@@ -225,6 +234,7 @@ fun pushERel (env : env) x t =
namedC = #namedC env,
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = SM.insert (renameE, x, Rel' (0, t)),
relE = (x, t) :: #relE env,
@@ -247,6 +257,7 @@ fun pushENamedAs (env : env) x n t =
namedC = #namedC env,
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = SM.insert (#renameE env, x, Named' (n, t)),
relE = #relE env,
@@ -283,6 +294,7 @@ fun pushSgnNamedAs (env : env) x n sgis =
namedC = #namedC env,
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = #renameE env,
relE = #relE env,
@@ -315,6 +327,7 @@ fun pushStrNamedAs (env : env) x n sgis =
namedC = #namedC env,
datatypes = #datatypes env,
+ constructors = #constructors env,
renameE = #renameE env,
relE = #relE env,
diff --git a/src/elab_print.sig b/src/elab_print.sig
index 9ab9eae7..ead61e68 100644
--- a/src/elab_print.sig
+++ b/src/elab_print.sig
@@ -31,6 +31,7 @@ signature ELAB_PRINT = sig
val p_kind : Elab.kind Print.printer
val p_explicitness : Elab.explicitness Print.printer
val p_con : ElabEnv.env -> Elab.con Print.printer
+ val p_pat : ElabEnv.env -> Elab.pat Print.printer
val p_exp : ElabEnv.env -> Elab.exp Print.printer
val p_decl : ElabEnv.env -> Elab.decl Print.printer
val p_sgn_item : ElabEnv.env -> Elab.sgn_item Print.printer
diff --git a/src/elab_print.sml b/src/elab_print.sml
index 693bc443..7e13c116 100644
--- a/src/elab_print.sml
+++ b/src/elab_print.sml
@@ -190,6 +190,38 @@ and p_name env (all as (c, _)) =
CName s => string s
| _ => p_con env all
+fun p_patCon env pc =
+ case pc of
+ PConVar n =>
+ ((if !debug then
+ string (#1 (E.lookupENamed env n) ^ "__" ^ Int.toString n)
+ else
+ string (#1 (E.lookupENamed env n)))
+ handle E.UnboundRel _ => string ("UNBOUND_NAMED" ^ Int.toString n))
+ | PConProj (m1, ms, x) =>
+ let
+ val m1x = #1 (E.lookupStrNamed env m1)
+ handle E.UnboundNamed _ => "UNBOUND_STR_" ^ Int.toString m1
+
+ val m1s = if !debug then
+ m1x ^ "__" ^ Int.toString m1
+ else
+ m1x
+ in
+ p_list_sep (string ".") string (m1x :: ms @ [x])
+ end
+
+fun p_pat' par env (p, _) =
+ case p of
+ PWild => string "_"
+ | PVar s => string s
+ | PCon (pc, NONE) => p_patCon env pc
+ | PCon (pc, SOME p) => parenIf par (box [p_patCon env pc,
+ space,
+ p_pat' true env p])
+
+val p_pat = p_pat' false
+
fun p_exp' par env (e, _) =
case e of
EPrim p => Prim.p_t p
@@ -297,6 +329,19 @@ fun p_exp' par env (e, _) =
p_con' true env c])
| EFold _ => string "fold"
+ | ECase (e, pes, _) => parenIf par (box [string "case",
+ space,
+ p_exp env e,
+ space,
+ string "of",
+ space,
+ p_list_sep (box [space, string "|", space])
+ (fn (p, e) => box [p_pat env p,
+ space,
+ string "=>",
+ space,
+ p_exp env e]) pes])
+
| EError => string "<ERROR>"
and p_exp env = p_exp' false env
diff --git a/src/elab_util.sml b/src/elab_util.sml
index 05f80ad1..ac2e1a99 100644
--- a/src/elab_util.sml
+++ b/src/elab_util.sml
@@ -308,6 +308,17 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} =
fn k' =>
(EFold k', loc))
+ | ECase (e, pes, t) =>
+ S.bind2 (mfe ctx e,
+ fn e' =>
+ S.bind2 (ListUtil.mapfold (fn (p, e) =>
+ S.map2 (mfe ctx e,
+ fn e' => (p, e'))) pes,
+ fn pes' =>
+ S.map2 (mfc ctx t,
+ fn t' =>
+ (ECase (e', pes', t'), loc))))
+
| EError => S.return2 eAll
in
mfe
diff --git a/src/elaborate.sml b/src/elaborate.sml
index d19dcfce..e59cb9d2 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -809,6 +809,11 @@ datatype exp_error =
| Unif of string * L'.con
| WrongForm of string * L'.exp * L'.con
| IncompatibleCons of L'.con * L'.con
+ | DuplicatePatternVariable of ErrorMsg.span * string
+ | PatUnify of L'.pat * L'.con * L'.con * cunify_error
+ | UnboundConstructor of ErrorMsg.span * string
+ | PatHasArg of ErrorMsg.span
+ | PatHasNoArg of ErrorMsg.span
fun expError env err =
case err of
@@ -833,6 +838,20 @@ fun expError env err =
(ErrorMsg.errorAt (#2 c1) "Incompatible constructors";
eprefaces' [("Con 1", p_con env c1),
("Con 2", p_con env c2)])
+ | DuplicatePatternVariable (loc, s) =>
+ ErrorMsg.errorAt loc ("Duplicate pattern variable " ^ s)
+ | PatUnify (p, c1, c2, uerr) =>
+ (ErrorMsg.errorAt (#2 p) "Unification failure for pattern";
+ eprefaces' [("Pattern", p_pat env p),
+ ("Have con", p_con env c1),
+ ("Need con", p_con env c2)];
+ cunifyError env uerr)
+ | UnboundConstructor (loc, s) =>
+ ErrorMsg.errorAt loc ("Unbound constructor " ^ s ^ " in pattern")
+ | PatHasArg loc =>
+ ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument"
+ | PatHasNoArg loc =>
+ ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument"
fun checkCon (env, denv) e c1 c2 =
unifyCons (env, denv) c1 c2
@@ -840,6 +859,12 @@ fun checkCon (env, denv) e c1 c2 =
(expError env (Unify (e, c1, c2, err));
[])
+fun checkPatCon (env, denv) p c1 c2 =
+ unifyCons (env, denv) c1 c2
+ handle CUnify (c1, c2, err) =>
+ (expError env (PatUnify (p, c1, c2, err));
+ [])
+
fun primType env p =
case p of
P.Int _ => !int
@@ -903,6 +928,8 @@ fun typeof env (e, loc) =
| L'.ECut (_, _, {rest, ...}) => (L'.TRecord rest, loc)
| L'.EFold dom => foldType (dom, loc)
+ | L'.ECase (_, _, t) => t
+
| L'.EError => cerror
fun elabHead (env, denv) (e as (_, loc)) t =
@@ -927,6 +954,52 @@ fun elabHead (env, denv) (e as (_, loc)) t =
unravel (t, e)
end
+fun elabPat (pAll as (p, loc), (env, bound)) =
+ let
+ val perror = (L'.PWild, loc)
+ val terror = (L'.CError, loc)
+ val pterror = (perror, terror)
+ val rerror = (pterror, (env, bound))
+
+ fun pcon (pc, po, to, dn) =
+
+ case (po, to) of
+ (NONE, SOME _) => (expError env (PatHasNoArg loc);
+ rerror)
+ | (SOME _, NONE) => (expError env (PatHasArg loc);
+ rerror)
+ | (NONE, NONE) => (((L'.PCon (pc, NONE), loc), (L'.CNamed dn, loc)),
+ (env, bound))
+ | (SOME p, SOME t) =>
+ let
+ val ((p', pt), (env, bound)) = elabPat (p, (env, bound))
+ in
+ (((L'.PCon (pc, SOME p'), loc), (L'.CNamed dn, loc)),
+ (env, bound))
+ end
+ in
+ case p of
+ L.PWild => (((L'.PWild, loc), cunif (loc, (L'.KType, loc))),
+ (env, bound))
+ | L.PVar x =>
+ let
+ val t = if SS.member (bound, x) then
+ (expError env (DuplicatePatternVariable (loc, x));
+ terror)
+ else
+ cunif (loc, (L'.KType, loc))
+ in
+ (((L'.PVar x, loc), t),
+ (E.pushERel env x t, SS.add (bound, x)))
+ end
+ | L.PCon ([], x, po) =>
+ (case E.lookupConstructor env x of
+ NONE => (expError env (UnboundConstructor (loc, x));
+ rerror)
+ | SOME (n, to, dn) => pcon (L'.PConVar n, po, to, dn))
+ | L.PCon _ => raise Fail "uhoh"
+ end
+
fun elabExp (env, denv) (eAll as (e, loc)) =
let
@@ -1138,7 +1211,25 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
((L'.EFold dom, loc), foldType (dom, loc), [])
end
- | L.ECase _ => raise Fail "Elaborate ECase"
+ | L.ECase (e, pes) =>
+ let
+ val (e', et, gs1) = elabExp (env, denv) e
+ val result = cunif (loc, (L'.KType, loc))
+ val (pes', gs) = ListUtil.foldlMap
+ (fn ((p, e), gs) =>
+ let
+ val ((p', pt), (env, _)) = elabPat (p, (env, SS.empty))
+
+ val gs1 = checkPatCon (env, denv) p' pt et
+ val (e', et, gs2) = elabExp (env, denv) e
+ val gs3 = checkCon (env, denv) e' et result
+ in
+ ((p', e'), gs1 @ gs2 @ gs3 @ gs)
+ end)
+ gs1 pes
+ in
+ ((L'.ECase (e', pes', result), loc), result, gs)
+ end
end
@@ -1961,6 +2052,8 @@ fun elabDecl ((d, loc), (env, denv, gs)) =
((x, n', to), (SS.add (used, x), env, gs'))
end)
(SS.empty, env, []) xcs
+
+ val env = E.pushDatatype env n xcs
in
([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs))
end
diff --git a/src/explify.sml b/src/explify.sml
index c193a631..da94ba94 100644
--- a/src/explify.sml
+++ b/src/explify.sml
@@ -89,6 +89,8 @@ fun explifyExp (e, loc) =
{field = explifyCon field, rest = explifyCon rest}), loc)
| L.EFold k => (L'.EFold (explifyKind k), loc)
+ | L.ECase _ => raise Fail "Explify ECase"
+
| L.EError => raise Fail ("explifyExp: EError at " ^ EM.spanToString loc)
fun explifySgi (sgi, loc) =
diff --git a/src/source_print.sml b/src/source_print.sml
index 79f3c254..4bd7e28e 100644
--- a/src/source_print.sml
+++ b/src/source_print.sml
@@ -252,7 +252,7 @@ fun p_exp' par (e, _) =
| ECase (e, pes) => parenIf par (box [string "case",
space,
- p_exp' false e,
+ p_exp e,
space,
string "of",
space,
diff --git a/tests/case.lac b/tests/case.lac
index dc3fe03b..b131b27b 100644
--- a/tests/case.lac
+++ b/tests/case.lac
@@ -8,5 +8,9 @@ val out = fn x : u => case x of C y => y | D => A
datatype nat = O | S of nat
-val is_two = fn x : int_list =>
+val is_two = fn x : nat =>
case x of S (S O) => A | _ => B
+
+val zero_is_two = is_two O
+val one_is_two = is_two (S O)
+val two_is_two = is_two (S (S O))