summaryrefslogtreecommitdiff
path: root/src/cjr_print.sml
diff options
context:
space:
mode:
Diffstat (limited to 'src/cjr_print.sml')
-rw-r--r--src/cjr_print.sml101
1 files changed, 81 insertions, 20 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml
index 06154b91..d7e426c3 100644
--- a/src/cjr_print.sml
+++ b/src/cjr_print.sml
@@ -408,24 +408,61 @@ fun p_unsql wontLeakStrings env (tAll as (t, loc)) e =
box [string "uw_Basis_strdup(ctx, ", e, string ")"]
| TFfi ("Basis", "bool") => box [string "uw_Basis_stringToBool_error(ctx, ", e, string ")"]
| TFfi ("Basis", "time") => box [string "uw_Basis_stringToTime_error(ctx, ", e, string ")"]
+
| _ => (ErrorMsg.errorAt loc "Don't know how to unmarshal type from SQL";
Print.eprefaces' [("Type", p_typ env tAll)];
string "ERROR")
+fun p_getcol wontLeakStrings env (tAll as (t, loc)) i =
+ case t of
+ TOption t =>
+ box [string "(PQgetisnull (res, i, ",
+ string (Int.toString i),
+ string ") ? NULL : ",
+ case t of
+ (TFfi ("Basis", "string"), _) => p_getcol wontLeakStrings env t i
+ | _ => box [string "({",
+ newline,
+ p_typ env t,
+ space,
+ string "*tmp = uw_malloc(ctx, sizeof(",
+ p_typ env t,
+ string "));",
+ newline,
+ string "*tmp = ",
+ p_getcol wontLeakStrings env t i,
+ string ";",
+ newline,
+ string "tmp;",
+ newline,
+ string "})"],
+ string ")"]
+
+ | _ =>
+ p_unsql wontLeakStrings env tAll
+ (box [string "PQgetvalue(res, i, ",
+ string (Int.toString i),
+ string ")"])
+
datatype sql_type =
Int
| Float
| String
| Bool
| Time
+ | Nullable of sql_type
+
+fun p_sql_type' t =
+ case t of
+ Int => "uw_Basis_int"
+ | Float => "uw_Basis_float"
+ | String => "uw_Basis_string"
+ | Bool => "uw_Basis_bool"
+ | Time => "uw_Basis_time"
+ | Nullable String => "uw_Basis_string"
+ | Nullable t => p_sql_type' t ^ "*"
-fun p_sql_type t =
- string (case t of
- Int => "uw_Basis_int"
- | Float => "uw_Basis_float"
- | String => "uw_Basis_string"
- | Bool => "uw_Basis_bool"
- | Time => "uw_Basis_time")
+fun p_sql_type t = string (p_sql_type' t)
fun getPargs (e, _) =
case e of
@@ -448,6 +485,12 @@ fun p_ensql t e =
| String => e
| Bool => box [string "(", e, string " ? \"TRUE\" : \"FALSE\")"]
| Time => box [string "uw_Basis_sqlifyTime(ctx, ", e, string ")"]
+ | Nullable String => e
+ | Nullable t => box [string "(",
+ e,
+ string " == NULL ? NULL : ",
+ p_ensql t (box [string "*", e]),
+ string ")"]
fun notLeaky env allowHeapAllocated =
let
@@ -1169,10 +1212,7 @@ fun p_exp' par env (e, loc) =
space,
string "=",
space,
- p_unsql wontLeakStrings env t
- (box [string "PQgetvalue(res, i, ",
- string (Int.toString i),
- string ")"]),
+ p_getcol wontLeakStrings env t i,
string ";",
newline]) outputs,
@@ -1660,7 +1700,10 @@ fun p_decl env (dAll as (d, _) : decl) =
string "}",
newline]
- | DPreparedStatements [] => box []
+ | DPreparedStatements [] =>
+ box [string "static void uw_db_prepare(uw_context ctx) {",
+ newline,
+ string "}"]
| DPreparedStatements ss =>
box [string "static void uw_db_prepare(uw_context ctx) {",
newline,
@@ -1708,7 +1751,7 @@ datatype 'a search =
| NotFound
| Error
-fun p_sqltype' env (tAll as (t, loc)) =
+fun p_sqltype'' env (tAll as (t, loc)) =
case t of
TFfi ("Basis", "int") => "int8"
| TFfi ("Basis", "float") => "float8"
@@ -1719,8 +1762,25 @@ fun p_sqltype' env (tAll as (t, loc)) =
Print.eprefaces' [("Type", p_typ env tAll)];
"ERROR")
+fun p_sqltype' env (tAll as (t, loc)) =
+ case t of
+ (TOption t, _) => p_sqltype'' env t
+ | _ => p_sqltype'' env t ^ " NOT NULL"
+
fun p_sqltype env t = string (p_sqltype' env t)
+fun p_sqltype_base' env t =
+ case t of
+ (TOption t, _) => p_sqltype'' env t
+ | _ => p_sqltype'' env t
+
+fun p_sqltype_base env t = string (p_sqltype_base' env t)
+
+fun is_not_null t =
+ case t of
+ (TOption _, _) => false
+ | _ => true
+
fun p_file env (ds, ps) =
let
val (pds, env) = ListUtil.foldlMap (fn (d, env) =>
@@ -1997,8 +2057,13 @@ fun p_file env (ds, ps) =
Char.toLower (ident x),
"' AND atttypid = (SELECT oid FROM pg_type",
" WHERE typname = '",
- p_sqltype' env t,
- "'))"]) xts),
+ p_sqltype_base' env t,
+ "') AND attnotnull = ",
+ if is_not_null t then
+ "TRUE"
+ else
+ "FALSE",
+ ")"]) xts),
")"]
val q'' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '",
@@ -2295,11 +2360,7 @@ fun p_sql env (ds, _) =
box [string "uw_",
string (CharVector.map Char.toLower x),
space,
- p_sqltype env t,
- space,
- string "NOT",
- space,
- string "NULL"]) xts,
+ p_sqltype env (t, ErrorMsg.dummySpan)]) xts,
string ");",
newline,
newline]