summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/elaborate.sml77
-rw-r--r--src/lacweb.grm4
-rw-r--r--src/source.sml2
-rw-r--r--src/source_print.sml2
-rw-r--r--tests/where.lac1
5 files changed, 70 insertions, 16 deletions
diff --git a/src/elaborate.sml b/src/elaborate.sml
index 41c9e6df..a03904d0 100644
--- a/src/elaborate.sml
+++ b/src/elaborate.sml
@@ -47,6 +47,8 @@ end
structure SS = BinarySetFn(SK)
structure SM = BinaryMapFn(SK)
+val basis_r = ref 0
+
fun elabExplicitness e =
case e of
L.Explicit => L'.Explicit
@@ -862,9 +864,7 @@ and unifyCons' (env, denv) c1 c2 =
and unifyCons'' (env, denv) (c1All as (c1, loc)) (c2All as (c2, _)) =
let
- fun err f = (prefaces "unifyCons'' fails" [("c1All", p_con env c1All),
- ("c2All", p_con env c2All)];
- raise CUnify' (f (c1All, c2All)))
+ fun err f = raise CUnify' (f (c1All, c2All))
fun isRecord () = unifyRecordCons (env, denv) (c1All, c2All)
in
@@ -985,6 +985,7 @@ datatype exp_error =
| PatHasNoArg of ErrorMsg.span
| Inexhaustive of ErrorMsg.span
| DuplicatePatField of ErrorMsg.span * string
+ | SqlInfer of ErrorMsg.span * L'.con
fun expError env err =
case err of
@@ -1027,7 +1028,10 @@ fun expError env err =
ErrorMsg.errorAt loc "Inexhaustive 'case'"
| DuplicatePatField (loc, s) =>
ErrorMsg.errorAt loc ("Duplicate record field " ^ s ^ " in pattern")
-
+ | SqlInfer (loc, c) =>
+ (ErrorMsg.errorAt loc "Can't infer SQL-ness of type";
+ eprefaces' [("Type", p_con env c)])
+
fun checkCon (env, denv) e c1 c2 =
unifyCons (env, denv) c1 c2
handle CUnify (c1, c2, err) =>
@@ -1415,6 +1419,51 @@ fun elabExp (env, denv) (eAll as (e, loc)) =
((L'.EModProj (n, ms, s), loc), t, [])
end)
+ | L.EApp (e1, (L.ESqlInfer, _)) =>
+ let
+ val (e1', t1, gs1) = elabExp (env, denv) e1
+ val (e1', t1, gs2) = elabHead (env, denv) e1' t1
+ val (t1, gs3) = hnormCon (env, denv) t1
+ in
+ case t1 of
+ (L'.TFun ((L'.CApp ((L'.CModProj (basis, [], "sql_type"), _),
+ t), _), ran), _) =>
+ if basis <> !basis_r then
+ raise Fail "Bad use of ESqlInfer [1]"
+ else
+ let
+ val (t, gs4) = hnormCon (env, denv) t
+
+ fun error () = expError env (SqlInfer (loc, t))
+ in
+ case t of
+ (L'.CModProj (basis, [], x), _) =>
+ (if basis <> !basis_r then
+ error ()
+ else
+ case x of
+ "bool" => ()
+ | "int" => ()
+ | "float" => ()
+ | "string" => ()
+ | _ => error ();
+ ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_" ^ x), loc)), loc),
+ ran, gs1 @ gs2 @ gs3 @ gs4))
+ | (L'.CUnif (_, (L'.KType, _), _, r), _) =>
+ let
+ val t = (L'.CModProj (basis, [], "int"), loc)
+ in
+ r := SOME t;
+ ((L'.EApp (e1', (L'.EModProj (basis, [], "sql_int"), loc)), loc),
+ ran, gs1 @ gs2 @ gs3 @ gs4)
+ end
+ | _ => (error ();
+ (eerror, cerror, []))
+ end
+ | _ => raise Fail "Bad use of ESqlInfer [2]"
+ end
+ | L.ESqlInfer => raise Fail "Bad use of ESqlInfer [3]"
+
| L.EApp (e1, e2) =>
let
val (e1', t1, gs1) = elabExp (env, denv) e1
@@ -1736,12 +1785,7 @@ fun strError env err =
val hnormSgn = E.hnormSgn
-fun tableOf' env =
- case E.lookupStr env "Basis" of
- NONE => raise Fail "Elaborate.tableOf: Can't find Basis"
- | SOME (n, _) => n
-
-fun tableOf env = (L'.CModProj (tableOf' env, [], "sql_table"), ErrorMsg.dummySpan)
+fun tableOf () = (L'.CModProj (!basis_r, [], "sql_table"), ErrorMsg.dummySpan)
fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
case sgi of
@@ -1911,10 +1955,10 @@ fun elabSgn_item ((sgi, loc), (env, denv, gs)) =
| L.SgiTable (x, c) =>
let
val (c', k, gs) = elabCon (env, denv) c
- val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
+ val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc)
in
checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
- ([(L'.SgiTable (tableOf' env, x, n, c'), loc)], (env, denv, gs))
+ ([(L'.SgiTable (!basis_r, x, n, c'), loc)], (env, denv, gs))
end
and elabSgn (env, denv) (sgn, loc) =
@@ -2114,7 +2158,7 @@ fun dopen (env, denv) {str, strs, sgn} =
| L'.SgiConstraint (c1, c2) =>
(L'.DConstraint (c1, c2), loc)
| L'.SgiTable (_, x, n, c) =>
- (L'.DVal (x, n, (L'.CApp (tableOf env, c), loc),
+ (L'.DVal (x, n, (L'.CApp (tableOf (), c), loc),
(L'.EModProj (str, strs, x), loc)), loc)
in
(d, (E.declBinds env' d, denv'))
@@ -2363,7 +2407,7 @@ fun subSgn (env, denv) sgn1 (sgn2 as (_, loc2)) =
NONE
| L'.SgiTable (_, x', n1, c1) =>
if x = x' then
- (case unifyCons (env, denv) (L'.CApp (tableOf env, c1), loc) c2 of
+ (case unifyCons (env, denv) (L'.CApp (tableOf (), c1), loc) c2 of
[] => SOME (env, denv)
| _ => NONE)
handle CUnify (c1, c2, err) =>
@@ -2799,10 +2843,10 @@ fun elabDecl ((d, loc), (env, denv, gs)) =
| L.DTable (x, c) =>
let
val (c', k, gs') = elabCon (env, denv) c
- val (env, n) = E.pushENamed env x (L'.CApp (tableOf env, c'), loc)
+ val (env, n) = E.pushENamed env x (L'.CApp (tableOf (), c'), loc)
in
checkKind env c' k (L'.KRecord (L'.KType, loc), loc);
- ([(L'.DTable (tableOf' env, x, n, c'), loc)], (env, denv, gs' @ gs))
+ ([(L'.DTable (!basis_r, x, n, c'), loc)], (env, denv, gs' @ gs))
end
and elabStr (env, denv) (str, loc) =
@@ -2979,6 +3023,7 @@ fun elabFile basis env file =
raise Fail "Unresolved disjointness constraints in Basis")
val (env', basis_n) = E.pushStrNamed env "Basis" sgn
+ val () = basis_r := basis_n
val (ds, (env', _)) = dopen (env', D.empty) {str = basis_n, strs = [], sgn = sgn}
diff --git a/src/lacweb.grm b/src/lacweb.grm
index 13e464c4..464f5f82 100644
--- a/src/lacweb.grm
+++ b/src/lacweb.grm
@@ -632,6 +632,10 @@ sqlexp : TRUE (sql_inject (EVar (["Basis"], "True"),
EVar (["Basis"], "sql_bool"),
s (FALSEleft, FALSEright)))
+ | LBRACE eexp RBRACE (sql_inject (#1 eexp,
+ ESqlInfer,
+ s (LBRACEleft, RBRACEright)))
+
wopt : (sql_inject (EVar (["Basis"], "True"),
EVar (["Basis"], "sql_bool"),
ErrorMsg.dummySpan))
diff --git a/src/source.sml b/src/source.sml
index de0c296d..70851c73 100644
--- a/src/source.sml
+++ b/src/source.sml
@@ -119,6 +119,8 @@ datatype exp' =
| ECut of exp * con
| EFold
+ | ESqlInfer
+
| ECase of exp * (pat * exp) list
withtype exp = exp' located
diff --git a/src/source_print.sml b/src/source_print.sml
index a953d7f6..ceb331f0 100644
--- a/src/source_print.sml
+++ b/src/source_print.sml
@@ -286,6 +286,8 @@ fun p_exp' par (e, _) =
space,
p_exp e]) pes])
+ | ESqlInfer => string "<sql-infer>"
+
and p_exp e = p_exp' false e
fun p_datatype (x, xs, cons) =
diff --git a/tests/where.lac b/tests/where.lac
index 1454583b..c7bd6167 100644
--- a/tests/where.lac
+++ b/tests/where.lac
@@ -4,3 +4,4 @@ table t2 : {A : float, D : int}
val q1 = (SELECT * FROM t1)
val q2 = (SELECT * FROM t1 WHERE TRUE)
val q3 = (SELECT * FROM t1 WHERE FALSE)
+val q4 = (SELECT * FROM t1 WHERE {True})