From b4f1361d2dff2e180e4656efa491b275707cdf02 Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Sat, 16 Aug 2008 14:32:18 -0400 Subject: Initial type class support --- src/elaborate.sml | 235 ++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 177 insertions(+), 58 deletions(-) (limited to 'src/elaborate.sml') diff --git a/src/elaborate.sml b/src/elaborate.sml index a03904d0..e4369dd4 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -985,7 +985,8 @@ datatype exp_error = | PatHasNoArg of ErrorMsg.span | Inexhaustive of ErrorMsg.span | DuplicatePatField of ErrorMsg.span * string - | SqlInfer of ErrorMsg.span * L'.con + | Unresolvable of ErrorMsg.span * L'.con + | OutOfContext of ErrorMsg.span fun expError env err = case err of @@ -1028,9 +1029,11 @@ fun expError env err = ErrorMsg.errorAt loc "Inexhaustive 'case'" | DuplicatePatField (loc, s) => ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern") - | SqlInfer (loc, c) => - (ErrorMsg.errorAt loc "Can't infer SQL-ness of type"; - eprefaces' [("Type", p_con env c)]) + | OutOfContext loc => + ErrorMsg.errorAt loc "Type class wildcard occurs out of context" + | Unresolvable (loc, c) => + (ErrorMsg.errorAt loc "Can't resolve type class instance"; + eprefaces' [("Class constraint", p_con env c)]) fun checkCon (env, denv) e c1 c2 = unifyCons (env, denv) c1 c2 @@ -1419,50 +1422,23 @@ fun elabExp (env, denv) (eAll as (e, loc)) = ((L'.EModProj (n, ms, s), loc), t, []) end) - | L.EApp (e1, (L.ESqlInfer, _)) => + | L.EApp (e1, (L.EWild, _)) => let val (e1', t1, gs1) = elabExp (env, denv) e1 val (e1', t1, gs2) = elabHead (env, denv) e1' t1 val (t1, gs3) = hnormCon (env, denv) t1 in case t1 of - (L'.TFun ((L'.CApp ((L'.CModProj (basis, [], "sql_type"), _), - t), _), ran), _) => - if basis <> !basis_r then - raise Fail "Bad use of ESqlInfer [1]" - else - let - val (t, gs4) = hnormCon (env, denv) t - - fun error () = expError env (SqlInfer (loc, t)) - in - case t of - (L'.CModProj (basis, [], x), _) => - (if basis <> !basis_r then - error () - else - case x of - "bool" => () - | "int" => () - | "float" => () - | "string" => () - | _ => error (); - ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_" ^ x), loc)), loc), - ran, gs1 @ gs2 @ gs3 @ gs4)) - | (L'.CUnif (_, (L'.KType, _), _, r), _) => - let - val t = (L'.CModProj (basis, [], "int"), loc) - in - r := SOME t; - ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_int"), loc)), loc), - ran, gs1 @ gs2 @ gs3 @ gs4) - end - | _ => (error (); - (eerror, cerror, [])) - end - | _ => raise Fail "Bad use of ESqlInfer [2]" + (L'.TFun (dom, ran), _) => + (case E.resolveClass env dom of + NONE => (expError env (Unresolvable (loc, dom)); + (eerror, cerror, [])) + | SOME pf => ((L'.EApp (e1', pf), loc), ran, gs1 @ gs2 @ gs3)) + | _ => (expError env (OutOfContext loc); + (eerror, cerror, [])) end - | L.ESqlInfer => raise Fail "Bad use of ESqlInfer [3]" + | L.EWild => (expError env (OutOfContext loc); + (eerror, cerror, [])) | L.EApp (e1, e2) => let @@ -1961,6 +1937,26 @@ fun elabSgn_item ((sgi, loc), (env, denv, gs)) = ([(L'.SgiTable (!basis_r, x, n, c'), loc)], (env, denv, gs)) end + | L.SgiClassAbs x => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val (env, n) = E.pushCNamed env x k NONE + val env = E.pushClass env n + in + ([(L'.SgiClassAbs (x, n), loc)], (env, denv, [])) + end + + | L.SgiClass (x, c) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val (c', ck, gs) = elabCon (env, denv) c + val (env, n) = E.pushCNamed env x k (SOME c') + val env = E.pushClass env n + in + checkKind env c' ck k; + ([(L'.SgiClass (x, n, c'), loc)], (env, denv, [])) + end + and elabSgn (env, denv) (sgn, loc) = case sgn of L.SgnConst sgis => @@ -2027,7 +2023,19 @@ and elabSgn (env, denv) (sgn, loc) = sgnError env (DuplicateVal (loc, x)) else (); - (cons, SS.add (vals, x), sgns, strs))) + (cons, SS.add (vals, x), sgns, strs)) + | L'.SgiClassAbs (x, _) => + (if SS.member (cons, x) then + sgnError env (DuplicateCon (loc, x)) + else + (); + (SS.add (cons, x), vals, sgns, strs)) + | L'.SgiClass (x, _, _) => + (if SS.member (cons, x) then + sgnError env (DuplicateCon (loc, x)) + else + (); + (SS.add (cons, x), vals, sgns, strs))) (SS.empty, SS.empty, SS.empty, SS.empty) sgis' in ((L'.SgnConst sgis', loc), gs) @@ -2160,6 +2168,20 @@ fun dopen (env, denv) {str, strs, sgn} = | L'.SgiTable (_, x, n, c) => (L'.DVal (x, n, (L'.CApp (tableOf (), c), loc), (L'.EModProj (str, strs, x), loc)), loc) + | L'.SgiClassAbs (x, n) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val c = (L'.CModProj (str, strs, x), loc) + in + (L'.DCon (x, n, k, c), loc) + end + | L'.SgiClass (x, n, _) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val c = (L'.CModProj (str, strs, x), loc) + in + (L'.DCon (x, n, k, c), loc) + end in (d, (E.declBinds env' d, denv')) end) @@ -2283,27 +2305,41 @@ fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) = in found (x', n1, k', SOME (L'.CModProj (m1, ms, s), loc)) end + | L'.SgiClassAbs (x', n1) => found (x', n1, + (L'.KArrow ((L'.KType, loc), + (L'.KType, loc)), loc), + NONE) + | L'.SgiClass (x', n1, c) => found (x', n1, + (L'.KArrow ((L'.KType, loc), + (L'.KType, loc)), loc), + SOME c) | _ => NONE end) | L'.SgiCon (x, n2, k2, c2) => seek (fn sgi1All as (sgi1, _) => - case sgi1 of - L'.SgiCon (x', n1, k1, c1) => - if x = x' then - let - fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2), denv) - in - (case unifyCons (env, denv) c1 c2 of - [] => good () - | _ => NONE) - handle CUnify (c1, c2, err) => - (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err)); - good ()) - end - else - NONE - | _ => NONE) + let + fun found (x', n1, k1, c1) = + if x = x' then + let + fun good () = SOME (E.pushCNamedAs env x n2 k2 (SOME c2), denv) + in + (case unifyCons (env, denv) c1 c2 of + [] => good () + | _ => NONE) + handle CUnify (c1, c2, err) => + (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err)); + good ()) + end + else + NONE + in + case sgi1 of + L'.SgiCon (x', n1, k1, c1) => found (x', n1, k1, c1) + | L'.SgiClass (x', n1, c1) => + found (x', n1, (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc), c1) + | _ => NONE + end) | L'.SgiDatatype (x, n2, xs2, xncs2) => seek (fn sgi1All as (sgi1, _) => @@ -2491,6 +2527,54 @@ fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) = else NONE | _ => NONE) + + | L'.SgiClassAbs (x, n2) => + seek (fn sgi1All as (sgi1, _) => + let + fun found (x', n1, co) = + if x = x' then + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val env = E.pushCNamedAs env x n1 k co + in + SOME (if n1 = n2 then + env + else + E.pushCNamedAs env x n2 k (SOME (L'.CNamed n1, loc2)), + denv) + end + else + NONE + in + case sgi1 of + L'.SgiClassAbs (x', n1) => found (x', n1, NONE) + | L'.SgiClass (x', n1, c) => found (x', n1, SOME c) + | _ => NONE + end) + | L'.SgiClass (x, n2, c2) => + seek (fn sgi1All as (sgi1, _) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + + fun found (x', n1, c1) = + if x = x' then + let + fun good () = SOME (E.pushCNamedAs env x n2 k (SOME c2), denv) + in + (case unifyCons (env, denv) c1 c2 of + [] => good () + | _ => NONE) + handle CUnify (c1, c2, err) => + (sgnError env (SgiWrongCon (sgi1All, c1, sgi2All, c2, err)); + good ()) + end + else + NONE + in + case sgi1 of + L'.SgiClass (x', n1, c1) => found (x', n1, c1) + | _ => NONE + end) end in ignore (foldl folder (env, denv) sgis2) @@ -2849,6 +2933,17 @@ fun elabDecl ((d, loc), (env, denv, gs)) = ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs)) end + | L.DClass (x, c) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + val (c', ck, gs) = elabCon (env, denv) c + val (env, n) = E.pushCNamed env x k (SOME c') + val env = E.pushClass env n + in + checkKind env c' ck k; + ([(L'.DCon (x, n, k, c'), loc)], (env, denv, [])) + end + and elabStr (env, denv) (str, loc) = case str of L.StrConst ds => @@ -2949,6 +3044,30 @@ and elabStr (env, denv) (str, loc) = (SS.add (vals, x), x) in ((L'.SgiTable (tn, x, n, c), loc) :: sgis, cons, vals, sgns, strs) + end + | L'.SgiClassAbs (x, n) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + + val (cons, x) = + if SS.member (cons, x) then + (cons, "?" ^ x) + else + (SS.add (cons, x), x) + in + ((L'.SgiClassAbs (x, n), loc) :: sgis, cons, vals, sgns, strs) + end + | L'.SgiClass (x, n, c) => + let + val k = (L'.KArrow ((L'.KType, loc), (L'.KType, loc)), loc) + + val (cons, x) = + if SS.member (cons, x) then + (cons, "?" ^ x) + else + (SS.add (cons, x), x) + in + ((L'.SgiClass (x, n, c), loc) :: sgis, cons, vals, sgns, strs) end) ([], SS.empty, SS.empty, SS.empty, SS.empty) sgis -- cgit v1.2.3