From c8fa648dbc2489ca4a56abbb27d94819fb75b5ec Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Sat, 16 Aug 2008 12:35:46 -0400 Subject: Inferring sql_type's --- src/elaborate.sml | 77 +++++++++++++++++++++++++++++++++++++++++----------- src/lacweb.grm | 4 +++ src/source.sml | 2 ++ src/source_print.sml | 2 ++ tests/where.lac | 1 + 5 files changed, 70 insertions(+), 16 deletions(-) diff --git a/src/elaborate.sml b/src/elaborate.sml index 41c9e6df..a03904d0 100644 --- a/src/elaborate.sml +++ b/src/elaborate.sml @@ -47,6 +47,8 @@ end structure SS = BinarySetFn(SK) structure SM = BinaryMapFn(SK) +val basis_r = ref 0 + fun elabExplicitness e = case e of L.Explicit => L'.Explicit @@ -862,9 +864,7 @@ and unifyCons' (env, denv) c1 c2 = and unifyCons'' (env, denv) (c1All as (c1, loc)) (c2All as (c2, _)) = let - fun err f = (prefaces "unifyCons'' fails" [("c1All", p_con env c1All), - ("c2All", p_con env c2All)]; - raise CUnify' (f (c1All, c2All))) + fun err f = raise CUnify' (f (c1All, c2All)) fun isRecord () = unifyRecordCons (env, denv) (c1All, c2All) in @@ -985,6 +985,7 @@ datatype exp_error = | PatHasNoArg of ErrorMsg.span | Inexhaustive of ErrorMsg.span | DuplicatePatField of ErrorMsg.span * string + | SqlInfer of ErrorMsg.span * L'.con fun expError env err = case err of @@ -1027,7 +1028,10 @@ fun expError env err = ErrorMsg.errorAt loc "Inexhaustive 'case'" | DuplicatePatField (loc, s) => ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern") - + | SqlInfer (loc, c) => + (ErrorMsg.errorAt loc "Can't infer SQL-ness of type"; + eprefaces' [("Type", p_con env c)]) + fun checkCon (env, denv) e c1 c2 = unifyCons (env, denv) c1 c2 handle CUnify (c1, c2, err) => @@ -1415,6 +1419,51 @@ fun elabExp (env, denv) (eAll as (e, loc)) = ((L'.EModProj (n, ms, s), loc), t, []) end) + | L.EApp (e1, (L.ESqlInfer, _)) => + let + val (e1', t1, gs1) = elabExp (env, denv) e1 + val (e1', t1, gs2) = elabHead (env, denv) e1' t1 + val (t1, gs3) = hnormCon (env, denv) t1 + in + case t1 of + (L'.TFun ((L'.CApp ((L'.CModProj (basis, [], "sql_type"), _), + t), _), ran), _) => + if basis <> !basis_r then + raise Fail "Bad use of ESqlInfer [1]" + else + let + val (t, gs4) = hnormCon (env, denv) t + + fun error () = expError env (SqlInfer (loc, t)) + in + case t of + (L'.CModProj (basis, [], x), _) => + (if basis <> !basis_r then + error () + else + case x of + "bool" => () + | "int" => () + | "float" => () + | "string" => () + | _ => error (); + ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_" ^ x), loc)), loc), + ran, gs1 @ gs2 @ gs3 @ gs4)) + | (L'.CUnif (_, (L'.KType, _), _, r), _) => + let + val t = (L'.CModProj (basis, [], "int"), loc) + in + r := SOME t; + ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_int"), loc)), loc), + ran, gs1 @ gs2 @ gs3 @ gs4) + end + | _ => (error (); + (eerror, cerror, [])) + end + | _ => raise Fail "Bad use of ESqlInfer [2]" + end + | L.ESqlInfer => raise Fail "Bad use of ESqlInfer [3]" + | L.EApp (e1, e2) => let val (e1', t1, gs1) = elabExp (env, denv) e1 @@ -1736,12 +1785,7 @@ fun strError env err = val hnormSgn = E.hnormSgn -fun tableOf' env = - case E.lookupStr env "Basis" of - NONE => raise Fail "Elaborate.tableOf: Can't find Basis" - | SOME (n, _) => n - -fun tableOf env = (L'.CModProj (tableOf' env, [], "sql_table"), ErrorMsg.dummySpan) +fun tableOf () = (L'.CModProj (!basis_r, [], "sql_table"), ErrorMsg.dummySpan) fun elabSgn_item ((sgi, loc), (env, denv, gs)) = case sgi of @@ -1911,10 +1955,10 @@ fun elabSgn_item ((sgi, loc), (env, denv, gs)) = | L.SgiTable (x, c) => let val (c', k, gs) = elabCon (env, denv) c - val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc) + val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc) in checkKind env c' k (L'.KRecord (L'.KType, loc), loc); - ([(L'.SgiTable (tableOf' env, x, n, c'), loc)], (env, denv, gs)) + ([(L'.SgiTable (!basis_r, x, n, c'), loc)], (env, denv, gs)) end and elabSgn (env, denv) (sgn, loc) = @@ -2114,7 +2158,7 @@ fun dopen (env, denv) {str, strs, sgn} = | L'.SgiConstraint (c1, c2) => (L'.DConstraint (c1, c2), loc) | L'.SgiTable (_, x, n, c) => - (L'.DVal (x, n, (L'.CApp (tableOf env, c), loc), + (L'.DVal (x, n, (L'.CApp (tableOf (), c), loc), (L'.EModProj (str, strs, x), loc)), loc) in (d, (E.declBinds env' d, denv')) @@ -2363,7 +2407,7 @@ fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) = NONE | L'.SgiTable (_, x', n1, c1) => if x = x' then - (case unifyCons (env, denv) (L'.CApp (tableOf env, c1), loc) c2 of + (case unifyCons (env, denv) (L'.CApp (tableOf (), c1), loc) c2 of [] => SOME (env, denv) | _ => NONE) handle CUnify (c1, c2, err) => @@ -2799,10 +2843,10 @@ fun elabDecl ((d, loc), (env, denv, gs)) = | L.DTable (x, c) => let val (c', k, gs') = elabCon (env, denv) c - val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc) + val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc) in checkKind env c' k (L'.KRecord (L'.KType, loc), loc); - ([(L'.DTable (tableOf' env, x, n, c'), loc)], (env, denv, gs' @ gs)) + ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs)) end and elabStr (env, denv) (str, loc) = @@ -2979,6 +3023,7 @@ fun elabFile basis env file = raise Fail "Unresolved disjointness constraints in Basis") val (env', basis_n) = E.pushStrNamed env "Basis" sgn + val () = basis_r := basis_n val (ds, (env', _)) = dopen (env', D.empty) {str = basis_n, strs = [], sgn = sgn} diff --git a/src/lacweb.grm b/src/lacweb.grm index 13e464c4..464f5f82 100644 --- a/src/lacweb.grm +++ b/src/lacweb.grm @@ -632,6 +632,10 @@ sqlexp : TRUE (sql_inject (EVar (["Basis"], "True"), EVar (["Basis"], "sql_bool"), s (FALSEleft, FALSEright))) + | LBRACE eexp RBRACE (sql_inject (#1 eexp, + ESqlInfer, + s (LBRACEleft, RBRACEright))) + wopt : (sql_inject (EVar (["Basis"], "True"), EVar (["Basis"], "sql_bool"), ErrorMsg.dummySpan)) diff --git a/src/source.sml b/src/source.sml index de0c296d..70851c73 100644 --- a/src/source.sml +++ b/src/source.sml @@ -119,6 +119,8 @@ datatype exp' = | ECut of exp * con | EFold + | ESqlInfer + | ECase of exp * (pat * exp) list withtype exp = exp' located diff --git a/src/source_print.sml b/src/source_print.sml index a953d7f6..ceb331f0 100644 --- a/src/source_print.sml +++ b/src/source_print.sml @@ -286,6 +286,8 @@ fun p_exp' par (e, _) = space, p_exp e]) pes]) + | ESqlInfer => string "" + and p_exp e = p_exp' false e fun p_datatype (x, xs, cons) = diff --git a/tests/where.lac b/tests/where.lac index 1454583b..c7bd6167 100644 --- a/tests/where.lac +++ b/tests/where.lac @@ -4,3 +4,4 @@ table t2 : {A : float, D : int} val q1 = (SELECT * FROM t1) val q2 = (SELECT * FROM t1 WHERE TRUE) val q3 = (SELECT * FROM t1 WHERE FALSE) +val q4 = (SELECT * FROM t1 WHERE {True}) -- cgit v1.2.3