diff options
author | Adam Chlipala <adamc@hcoop.net> | 2008-07-31 10:06:27 -0400 |
---|---|---|
committer | Adam Chlipala <adamc@hcoop.net> | 2008-07-31 10:06:27 -0400 |
commit | f4351288c5b57b130c0a75e5e84a445ca513527f (patch) | |
tree | c0e69cdf2d843fbf3c5d2853ce2effe487090970 | |
parent | aa1b3a24913edd0dc97af0d1fc9e3dc0026a2460 (diff) |
Elaborating some basic pattern matching
-rw-r--r-- | src/elab.sml | 13 | ||||
-rw-r--r-- | src/elab_env.sig | 4 | ||||
-rw-r--r-- | src/elab_env.sml | 15 | ||||
-rw-r--r-- | src/elab_print.sig | 1 | ||||
-rw-r--r-- | src/elab_print.sml | 45 | ||||
-rw-r--r-- | src/elab_util.sml | 11 | ||||
-rw-r--r-- | src/elaborate.sml | 95 | ||||
-rw-r--r-- | src/explify.sml | 2 | ||||
-rw-r--r-- | src/source_print.sml | 2 | ||||
-rw-r--r-- | tests/case.lac | 6 |
10 files changed, 189 insertions, 5 deletions
diff --git a/src/elab.sml b/src/elab.sml index b258d7e5..48790a15 100644 --- a/src/elab.sml +++ b/src/elab.sml @@ -71,6 +71,17 @@ datatype con' = withtype con = con' located +datatype patCon = + PConVar of int + | PConProj of int * string list * string + +datatype pat' = + PWild + | PVar of string + | PCon of patCon * pat option + +withtype pat = pat' located + datatype exp' = EPrim of Prim.t | ERel of int @@ -86,6 +97,8 @@ datatype exp' = | ECut of exp * con * { field : con, rest : con } | EFold of kind + | ECase of exp * (pat * exp) list * con + | EError withtype exp = exp' located diff --git a/src/elab_env.sig b/src/elab_env.sig index 0afa3114..ff45f056 100644 --- a/src/elab_env.sig +++ b/src/elab_env.sig @@ -54,9 +54,11 @@ signature ELAB_ENV = sig val pushDatatype : env -> int -> (string * int * Elab.con option) list -> env type datatyp val lookupDatatype : env -> int -> datatyp - val lookupConstructor : datatyp -> int -> string * Elab.con option + val lookupDatatypeConstructor : datatyp -> int -> string * Elab.con option val constructors : datatyp -> (string * int * Elab.con option) list + val lookupConstructor : env -> string -> (int * Elab.con option * int) option + val pushERel : env -> string -> Elab.con -> env val lookupERel : env -> int -> string * Elab.con diff --git a/src/elab_env.sml b/src/elab_env.sml index 1fe5dd5a..5b716730 100644 --- a/src/elab_env.sml +++ b/src/elab_env.sml @@ -81,6 +81,7 @@ type env = { namedC : (string * kind * con option) IM.map, datatypes : datatyp IM.map, + constructors : (int * con option * int) SM.map, renameE : con var' SM.map, relE : (string * con) list, @@ -109,6 +110,7 @@ val empty = { namedC = IM.empty, datatypes = IM.empty, + constructors = SM.empty, renameE = SM.empty, relE = [], @@ -131,6 +133,7 @@ fun pushCRel (env : env) x k = namedC = IM.map (fn (x, k, co) => (x, k, Option.map lift co)) (#namedC env), datatypes = #datatypes env, + constructors = #constructors env, renameE = #renameE env, relE = map (fn (x, c) => (x, lift c)) (#relE env), @@ -154,6 +157,7 @@ fun pushCNamedAs (env : env) x n k co = namedC = IM.insert (#namedC env, n, (x, k, co)), datatypes = #datatypes env, + constructors = #constructors env, renameE = #renameE env, relE = #relE env, @@ -192,6 +196,9 @@ fun pushDatatype (env : env) n xncs = datatypes = IM.insert (#datatypes env, n, foldl (fn ((x, n, to), cons) => IM.insert (cons, n, (x, to))) IM.empty xncs), + constructors = foldl (fn ((x, n', to), cmap) => + SM.insert (cmap, x, (n', to, n))) + (#constructors env) xncs, renameE = #renameE env, relE = #relE env, @@ -208,11 +215,13 @@ fun lookupDatatype (env : env) n = NONE => raise UnboundNamed n | SOME x => x -fun lookupConstructor dt n = +fun lookupDatatypeConstructor dt n = case IM.find (dt, n) of NONE => raise UnboundNamed n | SOME x => x +fun lookupConstructor (env : env) s = SM.find (#constructors env, s) + fun constructors dt = IM.foldri (fn (n, (x, to), ls) => (x, n, to) :: ls) [] dt fun pushERel (env : env) x t = @@ -225,6 +234,7 @@ fun pushERel (env : env) x t = namedC = #namedC env, datatypes = #datatypes env, + constructors = #constructors env, renameE = SM.insert (renameE, x, Rel' (0, t)), relE = (x, t) :: #relE env, @@ -247,6 +257,7 @@ fun pushENamedAs (env : env) x n t = namedC = #namedC env, datatypes = #datatypes env, + constructors = #constructors env, renameE = SM.insert (#renameE env, x, Named' (n, t)), relE = #relE env, @@ -283,6 +294,7 @@ fun pushSgnNamedAs (env : env) x n sgis = namedC = #namedC env, datatypes = #datatypes env, + constructors = #constructors env, renameE = #renameE env, relE = #relE env, @@ -315,6 +327,7 @@ fun pushStrNamedAs (env : env) x n sgis = namedC = #namedC env, datatypes = #datatypes env, + constructors = #constructors env, renameE = #renameE env, relE = #relE env, diff --git a/src/elab_print.sig b/src/elab_print.sig index 9ab9eae7..ead61e68 100644 --- a/src/elab_print.sig +++ b/src/elab_print.sig @@ -31,6 +31,7 @@ signature ELAB_PRINT = sig val p_kind : Elab.kind Print.printer val p_explicitness : Elab.explicitness Print.printer val p_con : ElabEnv.env -> Elab.con Print.printer + val p_pat : ElabEnv.env -> Elab.pat Print.printer val p_exp : ElabEnv.env -> Elab.exp Print.printer val p_decl : ElabEnv.env -> Elab.decl Print.printer val p_sgn_item : ElabEnv.env -> Elab.sgn_item Print.printer diff --git a/src/elab_print.sml b/src/elab_print.sml index 693bc443..7e13c116 100644 --- a/src/elab_print.sml +++ b/src/elab_print.sml @@ -190,6 +190,38 @@ and p_name env (all as (c, _)) = CName s => string s | _ => p_con env all +fun p_patCon env pc = + case pc of + PConVar n => + ((if !debug then + string (#1 (E.lookupENamed env n) ^ "__" ^ Int.toString n) + else + string (#1 (E.lookupENamed env n))) + handle E.UnboundRel _ => string ("UNBOUND_NAMED" ^ Int.toString n)) + | PConProj (m1, ms, x) => + let + val m1x = #1 (E.lookupStrNamed env m1) + handle E.UnboundNamed _ => "UNBOUND_STR_" ^ Int.toString m1 + + val m1s = if !debug then + m1x ^ "__" ^ Int.toString m1 + else + m1x + in + p_list_sep (string ".") string (m1x :: ms @ [x]) + end + +fun p_pat' par env (p, _) = + case p of + PWild => string "_" + | PVar s => string s + | PCon (pc, NONE) => p_patCon env pc + | PCon (pc, SOME p) => parenIf par (box [p_patCon env pc, + space, + p_pat' true env p]) + +val p_pat = p_pat' false + fun p_exp' par env (e, _) = case e of EPrim p => Prim.p_t p @@ -297,6 +329,19 @@ fun p_exp' par env (e, _) = p_con' true env c]) | EFold _ => string "fold" + | ECase (e, pes, _) => parenIf par (box [string "case", + space, + p_exp env e, + space, + string "of", + space, + p_list_sep (box [space, string "|", space]) + (fn (p, e) => box [p_pat env p, + space, + string "=>", + space, + p_exp env e]) pes]) + | EError => string "<ERROR>" and p_exp env = p_exp' false env diff --git a/src/elab_util.sml b/src/elab_util.sml index 05f80ad1..ac2e1a99 100644 --- a/src/elab_util.sml +++ b/src/elab_util.sml @@ -308,6 +308,17 @@ fun mapfoldB {kind = fk, con = fc, exp = fe, bind} = fn k' => (EFold k', loc)) + | ECase (e, pes, t) => + S.bind2 (mfe ctx e, + fn e' => + S.bind2 (ListUtil.mapfold (fn (p, e) => + S.map2 (mfe ctx e, + fn e' => (p, e'))) pes, + fn pes' => + S.map2 (mfc ctx t, + fn t' => + (ECase (e', pes', t'), loc)))) + | EError => S.return2 eAll in mfe diff --git a/src/elaborate.sml b/src/elaborate.sml index d19dcfce..e59cb9d2 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -809,6 +809,11 @@ datatype exp_error = | Unif of string * L'.con | WrongForm of string * L'.exp * L'.con | IncompatibleCons of L'.con * L'.con + | DuplicatePatternVariable of ErrorMsg.span * string + | PatUnify of L'.pat * L'.con * L'.con * cunify_error + | UnboundConstructor of ErrorMsg.span * string + | PatHasArg of ErrorMsg.span + | PatHasNoArg of ErrorMsg.span fun expError env err = case err of @@ -833,6 +838,20 @@ fun expError env err = (ErrorMsg.errorAt (#2 c1) "Incompatible constructors"; eprefaces' [("Con 1", p_con env c1), ("Con 2", p_con env c2)]) + | DuplicatePatternVariable (loc, s) => + ErrorMsg.errorAt loc ("Duplicate pattern variable " ^ s) + | PatUnify (p, c1, c2, uerr) => + (ErrorMsg.errorAt (#2 p) "Unification failure for pattern"; + eprefaces' [("Pattern", p_pat env p), + ("Have con", p_con env c1), + ("Need con", p_con env c2)]; + cunifyError env uerr) + | UnboundConstructor (loc, s) => + ErrorMsg.errorAt loc ("Unbound constructor " ^ s ^ " in pattern") + | PatHasArg loc => + ErrorMsg.errorAt loc "Constructor expects no argument but is used with argument" + | PatHasNoArg loc => + ErrorMsg.errorAt loc "Constructor expects argument but is used with no argument" fun checkCon (env, denv) e c1 c2 = unifyCons (env, denv) c1 c2 @@ -840,6 +859,12 @@ fun checkCon (env, denv) e c1 c2 = (expError env (Unify (e, c1, c2, err)); []) +fun checkPatCon (env, denv) p c1 c2 = + unifyCons (env, denv) c1 c2 + handle CUnify (c1, c2, err) => + (expError env (PatUnify (p, c1, c2, err)); + []) + fun primType env p = case p of P.Int _ => !int @@ -903,6 +928,8 @@ fun typeof env (e, loc) = | L'.ECut (_, _, {rest, ...}) => (L'.TRecord rest, loc) | L'.EFold dom => foldType (dom, loc) + | L'.ECase (_, _, t) => t + | L'.EError => cerror fun elabHead (env, denv) (e as (_, loc)) t = @@ -927,6 +954,52 @@ fun elabHead (env, denv) (e as (_, loc)) t = unravel (t, e) end +fun elabPat (pAll as (p, loc), (env, bound)) = + let + val perror = (L'.PWild, loc) + val terror = (L'.CError, loc) + val pterror = (perror, terror) + val rerror = (pterror, (env, bound)) + + fun pcon (pc, po, to, dn) = + + case (po, to) of + (NONE, SOME _) => (expError env (PatHasNoArg loc); + rerror) + | (SOME _, NONE) => (expError env (PatHasArg loc); + rerror) + | (NONE, NONE) => (((L'.PCon (pc, NONE), loc), (L'.CNamed dn, loc)), + (env, bound)) + | (SOME p, SOME t) => + let + val ((p', pt), (env, bound)) = elabPat (p, (env, bound)) + in + (((L'.PCon (pc, SOME p'), loc), (L'.CNamed dn, loc)), + (env, bound)) + end + in + case p of + L.PWild => (((L'.PWild, loc), cunif (loc, (L'.KType, loc))), + (env, bound)) + | L.PVar x => + let + val t = if SS.member (bound, x) then + (expError env (DuplicatePatternVariable (loc, x)); + terror) + else + cunif (loc, (L'.KType, loc)) + in + (((L'.PVar x, loc), t), + (E.pushERel env x t, SS.add (bound, x))) + end + | L.PCon ([], x, po) => + (case E.lookupConstructor env x of + NONE => (expError env (UnboundConstructor (loc, x)); + rerror) + | SOME (n, to, dn) => pcon (L'.PConVar n, po, to, dn)) + | L.PCon _ => raise Fail "uhoh" + end + fun elabExp (env, denv) (eAll as (e, loc)) = let @@ -1138,7 +1211,25 @@ fun elabExp (env, denv) (eAll as (e, loc)) = ((L'.EFold dom, loc), foldType (dom, loc), []) end - | L.ECase _ => raise Fail "Elaborate ECase" + | L.ECase (e, pes) => + let + val (e', et, gs1) = elabExp (env, denv) e + val result = cunif (loc, (L'.KType, loc)) + val (pes', gs) = ListUtil.foldlMap + (fn ((p, e), gs) => + let + val ((p', pt), (env, _)) = elabPat (p, (env, SS.empty)) + + val gs1 = checkPatCon (env, denv) p' pt et + val (e', et, gs2) = elabExp (env, denv) e + val gs3 = checkCon (env, denv) e' et result + in + ((p', e'), gs1 @ gs2 @ gs3 @ gs) + end) + gs1 pes + in + ((L'.ECase (e', pes', result), loc), result, gs) + end end @@ -1961,6 +2052,8 @@ fun elabDecl ((d, loc), (env, denv, gs)) = ((x, n', to), (SS.add (used, x), env, gs')) end) (SS.empty, env, []) xcs + + val env = E.pushDatatype env n xcs in ([(L'.DDatatype (x, n, xcs), loc)], (env, denv, gs)) end diff --git a/src/explify.sml b/src/explify.sml index c193a631..da94ba94 100644 --- a/src/explify.sml +++ b/src/explify.sml @@ -89,6 +89,8 @@ fun explifyExp (e, loc) = {field = explifyCon field, rest = explifyCon rest}), loc) | L.EFold k => (L'.EFold (explifyKind k), loc) + | L.ECase _ => raise Fail "Explify ECase" + | L.EError => raise Fail ("explifyExp: EError at " ^ EM.spanToString loc) fun explifySgi (sgi, loc) = diff --git a/src/source_print.sml b/src/source_print.sml index 79f3c254..4bd7e28e 100644 --- a/src/source_print.sml +++ b/src/source_print.sml @@ -252,7 +252,7 @@ fun p_exp' par (e, _) = | ECase (e, pes) => parenIf par (box [string "case", space, - p_exp' false e, + p_exp e, space, string "of", space, diff --git a/tests/case.lac b/tests/case.lac index dc3fe03b..b131b27b 100644 --- a/tests/case.lac +++ b/tests/case.lac @@ -8,5 +8,9 @@ val out = fn x : u => case x of C y => y | D => A datatype nat = O | S of nat -val is_two = fn x : int_list => +val is_two = fn x : nat => case x of S (S O) => A | _ => B + +val zero_is_two = is_two O +val one_is_two = is_two (S O) +val two_is_two = is_two (S (S O)) |