diff options
Diffstat (limited to 'src/cjr_print.sml')
-rw-r--r-- | src/cjr_print.sml | 101 |
1 files changed, 81 insertions, 20 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml index 06154b91..d7e426c3 100644 --- a/src/cjr_print.sml +++ b/src/cjr_print.sml @@ -408,24 +408,61 @@ fun p_unsql wontLeakStrings env (tAll as (t, loc)) e = box [string "uw_Basis_strdup(ctx, ", e, string ")"] | TFfi ("Basis", "bool") => box [string "uw_Basis_stringToBool_error(ctx, ", e, string ")"] | TFfi ("Basis", "time") => box [string "uw_Basis_stringToTime_error(ctx, ", e, string ")"] + | _ => (ErrorMsg.errorAt loc "Don't know how to unmarshal type from SQL"; Print.eprefaces' [("Type", p_typ env tAll)]; string "ERROR") +fun p_getcol wontLeakStrings env (tAll as (t, loc)) i = + case t of + TOption t => + box [string "(PQgetisnull (res, i, ", + string (Int.toString i), + string ") ? NULL : ", + case t of + (TFfi ("Basis", "string"), _) => p_getcol wontLeakStrings env t i + | _ => box [string "({", + newline, + p_typ env t, + space, + string "*tmp = uw_malloc(ctx, sizeof(", + p_typ env t, + string "));", + newline, + string "*tmp = ", + p_getcol wontLeakStrings env t i, + string ";", + newline, + string "tmp;", + newline, + string "})"], + string ")"] + + | _ => + p_unsql wontLeakStrings env tAll + (box [string "PQgetvalue(res, i, ", + string (Int.toString i), + string ")"]) + datatype sql_type = Int | Float | String | Bool | Time + | Nullable of sql_type + +fun p_sql_type' t = + case t of + Int => "uw_Basis_int" + | Float => "uw_Basis_float" + | String => "uw_Basis_string" + | Bool => "uw_Basis_bool" + | Time => "uw_Basis_time" + | Nullable String => "uw_Basis_string" + | Nullable t => p_sql_type' t ^ "*" -fun p_sql_type t = - string (case t of - Int => "uw_Basis_int" - | Float => "uw_Basis_float" - | String => "uw_Basis_string" - | Bool => "uw_Basis_bool" - | Time => "uw_Basis_time") +fun p_sql_type t = string (p_sql_type' t) fun getPargs (e, _) = case e of @@ -448,6 +485,12 @@ fun p_ensql t e = | String => e | Bool => box [string "(", e, string " ? \"TRUE\" : \"FALSE\")"] | Time => box [string "uw_Basis_sqlifyTime(ctx, ", e, string ")"] + | Nullable String => e + | Nullable t => box [string "(", + e, + string " == NULL ? NULL : ", + p_ensql t (box [string "*", e]), + string ")"] fun notLeaky env allowHeapAllocated = let @@ -1169,10 +1212,7 @@ fun p_exp' par env (e, loc) = space, string "=", space, - p_unsql wontLeakStrings env t - (box [string "PQgetvalue(res, i, ", - string (Int.toString i), - string ")"]), + p_getcol wontLeakStrings env t i, string ";", newline]) outputs, @@ -1660,7 +1700,10 @@ fun p_decl env (dAll as (d, _) : decl) = string "}", newline] - | DPreparedStatements [] => box [] + | DPreparedStatements [] => + box [string "static void uw_db_prepare(uw_context ctx) {", + newline, + string "}"] | DPreparedStatements ss => box [string "static void uw_db_prepare(uw_context ctx) {", newline, @@ -1708,7 +1751,7 @@ datatype 'a search = | NotFound | Error -fun p_sqltype' env (tAll as (t, loc)) = +fun p_sqltype'' env (tAll as (t, loc)) = case t of TFfi ("Basis", "int") => "int8" | TFfi ("Basis", "float") => "float8" @@ -1719,8 +1762,25 @@ fun p_sqltype' env (tAll as (t, loc)) = Print.eprefaces' [("Type", p_typ env tAll)]; "ERROR") +fun p_sqltype' env (tAll as (t, loc)) = + case t of + (TOption t, _) => p_sqltype'' env t + | _ => p_sqltype'' env t ^ " NOT NULL" + fun p_sqltype env t = string (p_sqltype' env t) +fun p_sqltype_base' env t = + case t of + (TOption t, _) => p_sqltype'' env t + | _ => p_sqltype'' env t + +fun p_sqltype_base env t = string (p_sqltype_base' env t) + +fun is_not_null t = + case t of + (TOption _, _) => false + | _ => true + fun p_file env (ds, ps) = let val (pds, env) = ListUtil.foldlMap (fn (d, env) => @@ -1997,8 +2057,13 @@ fun p_file env (ds, ps) = Char.toLower (ident x), "' AND atttypid = (SELECT oid FROM pg_type", " WHERE typname = '", - p_sqltype' env t, - "'))"]) xts), + p_sqltype_base' env t, + "') AND attnotnull = ", + if is_not_null t then + "TRUE" + else + "FALSE", + ")"]) xts), ")"] val q'' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '", @@ -2295,11 +2360,7 @@ fun p_sql env (ds, _) = box [string "uw_", string (CharVector.map Char.toLower x), space, - p_sqltype env t, - space, - string "NOT", - space, - string "NULL"]) xts, + p_sqltype env (t, ErrorMsg.dummySpan)]) xts, string ");", newline, newline] |