summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cjr_print.sml7
-rw-r--r--src/mysql.sml513
-rw-r--r--src/postgres.sml31
-rw-r--r--src/settings.sig7
-rw-r--r--src/settings.sml10
5 files changed, 543 insertions, 25 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml
index 7d1120b4..fcfa402e 100644
--- a/src/cjr_print.sml
+++ b/src/cjr_print.sml
@@ -1652,7 +1652,7 @@ fun p_exp' par env (e, loc) =
#query (Settings.currentDbms ())
{loc = loc,
- numCols = length outputs,
+ cols = map (fn (_, t) => sql_type_in env t) outputs,
doCols = doCols}]
| SOME (id, query) =>
box [p_list_sepi newline
@@ -1675,7 +1675,7 @@ fun p_exp' par env (e, loc) =
id = id,
query = query,
inputs = map #2 inputs,
- numCols = length outputs,
+ cols = map (fn (_, t) => sql_type_in env t) outputs,
doCols = doCols}],
newline,
@@ -2797,7 +2797,8 @@ fun p_sql env (ds, _) =
box [string "uw_",
string (CharVector.map Char.toLower x),
space,
- p_sqltype env (t, ErrorMsg.dummySpan)]) xts,
+ string (#p_sql_type (Settings.currentDbms ())
+ (sql_type_in env t))]) xts,
case (pk, csts) of
("", []) => box []
| _ => string ",",
diff --git a/src/mysql.sml b/src/mysql.sml
index 7b02c787..2fcdef2d 100644
--- a/src/mysql.sml
+++ b/src/mysql.sml
@@ -31,6 +31,30 @@ open Settings
open Print.PD
open Print
+fun p_sql_type t =
+ case t of
+ Int => "bigint"
+ | Float => "double"
+ | String => "longtext"
+ | Bool => "bool"
+ | Time => "timestamp"
+ | Blob => "longblob"
+ | Channel => "bigint"
+ | Client => "int"
+ | Nullable t => p_sql_type t
+
+fun p_buffer_type t =
+ case t of
+ Int => "MYSQL_TYPE_LONGLONG"
+ | Float => "MYSQL_TYPE_DOUBLE"
+ | String => "MYSQL_TYPE_STRING"
+ | Bool => "MYSQL_TYPE_LONG"
+ | Time => "MYSQL_TYPE_TIME"
+ | Blob => "MYSQL_TYPE_BLOB"
+ | Channel => "MYSQL_TYPE_LONGLONG"
+ | Client => "MYSQL_TYPE_LONG"
+ | Nullable t => p_buffer_type t
+
fun init {dbstring, prepared = ss, tables, views, sequences} =
let
val host = ref NONE
@@ -138,6 +162,10 @@ fun init {dbstring, prepared = ss, tables, views, sequences} =
newline,
uhoh true "Error preparing statement: %s" ["msg"]],
string "}",
+ newline,
+ string "conn->p",
+ string (Int.toString i),
+ string " = stmt;",
newline]
end)
ss,
@@ -253,12 +281,484 @@ fun init {dbstring, prepared = ss, tables, views, sequences} =
newline]
end
-fun query _ = raise Fail "MySQL query"
-fun queryPrepared _ = raise Fail "MySQL queryPrepared"
-fun dml _ = raise Fail "MySQL dml"
-fun dmlPrepared _ = raise Fail "MySQL dmlPrepared"
-fun nextval _ = raise Fail "MySQL nextval"
-fun nextvalPrepared _ = raise Fail "MySQL nextvalPrepared"
+fun p_getcol {wontLeakStrings = _, col = i, typ = t} =
+ let
+ fun getter t =
+ case t of
+ String => box [string "({",
+ newline,
+ string "uw_Basis_string s = uw_malloc(ctx, length",
+ string (Int.toString i),
+ string " + 1);",
+ newline,
+ string "out[",
+ string (Int.toString i),
+ string "].buffer = s;",
+ newline,
+ string "out[",
+ string (Int.toString i),
+ string "].buffer_length = length",
+ string (Int.toString i),
+ string " + 1;",
+ newline,
+ string "mysql_stmt_fetch_column(stmt, &out[",
+ string (Int.toString i),
+ string "], ",
+ string (Int.toString i),
+ string ", 0);",
+ newline,
+ string "s[length",
+ string (Int.toString i),
+ string "] = 0;",
+ newline,
+ string "s;",
+ newline,
+ string "})"]
+ | Blob => box [string "({",
+ newline,
+ string "uw_Basis_blob b = {length",
+ string (Int.toString i),
+ string ", uw_malloc(ctx, length",
+ string (Int.toString i),
+ string ")};",
+ newline,
+ string "out[",
+ string (Int.toString i),
+ string "].buffer = b.data;",
+ newline,
+ string "out[",
+ string (Int.toString i),
+ string "].buffer_length = length",
+ string (Int.toString i),
+ string ";",
+ newline,
+ string "mysql_stmt_fetch_column(stmt, &out[",
+ string (Int.toString i),
+ string "], ",
+ string (Int.toString i),
+ string ", 0);",
+ newline,
+ string "b;",
+ newline,
+ string "})"]
+ | Time => box [string "({",
+ string "MYSQL_TIME *mt = buffer",
+ string (Int.toString i),
+ string ";",
+ newline,
+ newline,
+ string "struct tm t = {mt->second, mt->minute, mt->hour, mt->day, mt->month, mt->year, 0, 0, -1};",
+ newline,
+ string "mktime(&tm);",
+ newline,
+ string "})"]
+ | _ => box [string "buffer",
+ string (Int.toString i)]
+ in
+ case t of
+ Nullable t => box [string "(is_null",
+ string (Int.toString i),
+ string " ? NULL : ",
+ case t of
+ String => getter t
+ | _ => box [string "({",
+ newline,
+ string (p_sql_ctype t),
+ space,
+ string "*tmp = uw_malloc(ctx, sizeof(",
+ string (p_sql_ctype t),
+ string "));",
+ newline,
+ string "*tmp = ",
+ getter t,
+ string ";",
+ newline,
+ string "tmp;",
+ newline,
+ string "})"],
+ string ")"]
+ | _ => box [string "(is_null",
+ string (Int.toString i),
+ string " ? ",
+ box [string "({",
+ string (p_sql_ctype t),
+ space,
+ string "tmp;",
+ newline,
+ string "uw_error(ctx, FATAL, \"Unexpectedly NULL field #",
+ string (Int.toString i),
+ string "\");",
+ newline,
+ string "tmp;",
+ newline,
+ string "})"],
+ string " : ",
+ getter t,
+ string ")"]
+ end
+
+fun queryCommon {loc, query, cols, doCols} =
+ box [string "int n, r;",
+ newline,
+ string "MYSQL_BIND out[",
+ string (Int.toString (length cols)),
+ string "];",
+ newline,
+ p_list_sepi (box []) (fn i => fn t =>
+ let
+ fun buffers t =
+ case t of
+ String => box [string "unsigned long length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | Blob => box [string "unsigned long length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | _ => box [string (p_sql_ctype t),
+ space,
+ string "buffer",
+ string (Int.toString i),
+ string ";",
+ newline]
+ in
+ box [string "my_bool is_null",
+ string (Int.toString i),
+ string ";",
+ newline,
+ case t of
+ Nullable t => buffers t
+ | _ => buffers t,
+ newline]
+ end) cols,
+ newline,
+
+ string "memset(out, 0, sizeof out);",
+ newline,
+ p_list_sepi (box []) (fn i => fn t =>
+ let
+ fun buffers t =
+ case t of
+ String => box []
+ | Blob => box []
+ | _ => box [string "out[",
+ string (Int.toString i),
+ string "].buffer = &buffer",
+ string (Int.toString i),
+ string ";",
+ newline]
+ in
+ box [string "out[",
+ string (Int.toString i),
+ string "].buffer_type = ",
+ string (p_buffer_type t),
+ string ";",
+ newline,
+ string "out[",
+ string (Int.toString i),
+ string "].is_null = &is_null",
+ string (Int.toString i),
+ string ";",
+ newline,
+
+ case t of
+ Nullable t => buffers t
+ | _ => buffers t,
+ newline]
+ end) cols,
+ newline,
+
+ string "if (mysql_stmt_execute(stmt)) uw_error(ctx, FATAL, \"",
+ string (ErrorMsg.spanToString loc),
+ string ": Error executing query\");",
+ newline,
+ newline,
+
+ string "if (mysql_stmt_store_result(stmt)) uw_error(ctx, FATAL, \"",
+ string (ErrorMsg.spanToString loc),
+ string ": Error storing query result\");",
+ newline,
+ newline,
+
+ string "if (mysql_stmt_bind_result(stmt, out)) uw_error(ctx, FATAL, \"",
+ string (ErrorMsg.spanToString loc),
+ string ": Error binding query result\");",
+ newline,
+ newline,
+
+ string "uw_end_region(ctx);",
+ newline,
+ string "while ((r = mysql_stmt_fetch(stmt)) == 0) {",
+ newline,
+ doCols p_getcol,
+ string "}",
+ newline,
+ newline,
+
+ string "if (r != MYSQL_NO_DATA) uw_error(ctx, FATAL, \"",
+ string (ErrorMsg.spanToString loc),
+ string ": query result fetching failed\");",
+ newline]
+
+fun query {loc, cols, doCols} =
+ box [string "uw_conn *conn = uw_get_db(ctx);",
+ newline,
+ string "MYSQL_stmt *stmt = mysql_stmt_init(conn->conn);",
+ newline,
+ string "if (stmt == NULL) uw_error(ctx, \"",
+ string (ErrorMsg.spanToString loc),
+ string ": can't allocate temporary prepared statement\");",
+ newline,
+ string "uw_push_cleanup(ctx, (void (*)(void *))mysql_stmt_close, stmt);",
+ newline,
+ string "if (mysql_stmt_prepare(stmt, query, strlen(query))) uw_error(ctx, FATAL, \"",
+ string (ErrorMsg.spanToString loc),
+ string "\");",
+ newline,
+ newline,
+
+ p_list_sepi (box []) (fn i => fn t =>
+ let
+ fun buffers t =
+ case t of
+ String => box []
+ | Blob => box []
+ | _ => box [string "out[",
+ string (Int.toString i),
+ string "].buffer = &buffer",
+ string (Int.toString i),
+ string ";",
+ newline]
+ in
+ box [string "in[",
+ string (Int.toString i),
+ string "].buffer_type = ",
+ string (p_buffer_type t),
+ string ";",
+ newline,
+
+ case t of
+ Nullable t => box [string "in[",
+ string (Int.toString i),
+ string "].is_null = &is_null",
+ string (Int.toString i),
+ string ";",
+ newline,
+ buffers t]
+ | _ => buffers t,
+ newline]
+ end) cols,
+ newline,
+
+ queryCommon {loc = loc, cols = cols, doCols = doCols, query = string "query"},
+
+ string "uw_pop_cleanup(ctx);",
+ newline]
+
+fun p_ensql t e =
+ case t of
+ Int => box [string "uw_Basis_attrifyInt(ctx, ", e, string ")"]
+ | Float => box [string "uw_Basis_attrifyFloat(ctx, ", e, string ")"]
+ | String => e
+ | Bool => box [string "(", e, string " ? \"TRUE\" : \"FALSE\")"]
+ | Time => box [string "uw_Basis_attrifyTime(ctx, ", e, string ")"]
+ | Blob => box [e, string ".data"]
+ | Channel => box [string "uw_Basis_attrifyChannel(ctx, ", e, string ")"]
+ | Client => box [string "uw_Basis_attrifyClient(ctx, ", e, string ")"]
+ | Nullable String => e
+ | Nullable t => box [string "(",
+ e,
+ string " == NULL ? NULL : ",
+ p_ensql t (box [string "(*", e, string ")"]),
+ string ")"]
+
+fun queryPrepared {loc, id, query, inputs, cols, doCols} =
+ box [string "uw_conn *conn = uw_get_db(ctx);",
+ newline,
+ string "MYSQL_BIND in[",
+ string (Int.toString (length inputs)),
+ string "];",
+ newline,
+ p_list_sepi (box []) (fn i => fn t =>
+ let
+ fun buffers t =
+ case t of
+ String => box [string "unsigned long in_length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | Blob => box [string "unsigned long in_length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | Time => box [string (p_sql_ctype t),
+ space,
+ string "in_buffer",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | _ => box []
+ in
+ box [case t of
+ Nullable t => box [string "my_bool in_is_null",
+ string (Int.toString i),
+ string ";",
+ newline,
+ buffers t]
+ | _ => buffers t,
+ newline]
+ end) inputs,
+ string "MYSQL_STMT *stmt = conn->p",
+ string (Int.toString id),
+ string ";",
+ newline,
+ newline,
+
+ string "memset(in, 0, sizeof in);",
+ newline,
+ p_list_sepi (box []) (fn i => fn t =>
+ let
+ fun buffers t =
+ case t of
+ String => box [string "in[",
+ string (Int.toString i),
+ string "].buffer = arg",
+ string (Int.toString (i + 1)),
+ string ";",
+ newline,
+ string "in_length",
+ string (Int.toString i),
+ string "= in[",
+ string (Int.toString i),
+ string "].buffer_length = strlen(arg",
+ string (Int.toString (i + 1)),
+ string ");",
+ newline,
+ string "in[",
+ string (Int.toString i),
+ string "].length = &in_length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | Blob => box [string "in[",
+ string (Int.toString i),
+ string "].buffer = arg",
+ string (Int.toString (i + 1)),
+ string ".data;",
+ newline,
+ string "in_length",
+ string (Int.toString i),
+ string "= in[",
+ string (Int.toString i),
+ string "].buffer_length = arg",
+ string (Int.toString (i + 1)),
+ string ".size;",
+ newline,
+ string "in[",
+ string (Int.toString i),
+ string "].length = &in_length",
+ string (Int.toString i),
+ string ";",
+ newline]
+ | Time =>
+ let
+ fun oneField dst src =
+ box [string "in_buffer",
+ string (Int.toString i),
+ string ".",
+ string dst,
+ string " = tms.tm_",
+ string src,
+ string ";",
+ newline]
+ in
+ box [string "({",
+ newline,
+ string "struct tm tms;",
+ newline,
+ string "if (localtime_r(&arg",
+ string (Int.toString (i + 1)),
+ string ", &tm) == NULL) uw_error(\"",
+ string (ErrorMsg.spanToString loc),
+ string ": error converting to MySQL time\");",
+ newline,
+ oneField "year" "year",
+ oneField "month" "mon",
+ oneField "day" "mday",
+ oneField "hour" "hour",
+ oneField "minute" "min",
+ oneField "second" "sec",
+ newline,
+ string "in[",
+ string (Int.toString i),
+ string "].buffer = &in_buffer",
+ string (Int.toString i),
+ string ";",
+ newline]
+ end
+
+ | _ => box [string "in[",
+ string (Int.toString i),
+ string "].buffer = &arg",
+ string (Int.toString (i + 1)),
+ string ";",
+ newline]
+ in
+ box [string "in[",
+ string (Int.toString i),
+ string "].buffer_type = ",
+ string (p_buffer_type t),
+ string ";",
+ newline,
+
+ case t of
+ Nullable t => box [string "in[",
+ string (Int.toString i),
+ string "].is_null = &in_is_null",
+ string (Int.toString i),
+ string ";",
+ newline,
+ string "if (arg",
+ string (Int.toString (i + 1)),
+ string " == NULL) {",
+ newline,
+ box [string "in_is_null",
+ string (Int.toString i),
+ string " = 1;",
+ newline],
+ string "} else {",
+ box [case t of
+ String => box []
+ | _ =>
+ box [string (p_sql_ctype t),
+ space,
+ string "arg",
+ string (Int.toString (i + 1)),
+ string " = *arg",
+ string (Int.toString (i + 1)),
+ string ";",
+ newline],
+ string "in_is_null",
+ string (Int.toString i),
+ string " = 0;",
+ newline,
+ buffers t,
+ newline]]
+
+ | _ => buffers t,
+ newline]
+ end) inputs,
+ newline,
+
+ queryCommon {loc = loc, cols = cols, doCols = doCols, query = box [string "\"",
+ string (String.toString query),
+ string "\""]}]
+
+fun dml _ = box []
+fun dmlPrepared _ = box []
+fun nextval _ = box []
+fun nextvalPrepared _ = box []
val () = addDbms {name = "mysql",
header = "mysql/mysql.h",
@@ -276,6 +776,7 @@ val () = addDbms {name = "mysql",
string "}",
newline],
init = init,
+ p_sql_type = p_sql_type,
query = query,
queryPrepared = queryPrepared,
dml = dml,
diff --git a/src/postgres.sml b/src/postgres.sml
index 07a68607..ca71798f 100644
--- a/src/postgres.sml
+++ b/src/postgres.sml
@@ -34,6 +34,18 @@ open Print
val ident = String.translate (fn #"'" => "PRIME"
| ch => str ch)
+fun p_sql_type t =
+ case t of
+ Int => "int8"
+ | Float => "float8"
+ | String => "text"
+ | Bool => "bool"
+ | Time => "timestamp"
+ | Blob => "bytea"
+ | Channel => "int8"
+ | Client => "int4"
+ | Nullable t => p_sql_type t
+
fun p_sql_type_base t =
case t of
Int => "bigint"
@@ -540,7 +552,7 @@ fun p_getcol {wontLeakStrings, col = i, typ = t} =
getter t
end
-fun queryCommon {loc, query, numCols, doCols} =
+fun queryCommon {loc, query, cols, doCols} =
box [string "int n, i;",
newline,
newline,
@@ -564,7 +576,7 @@ fun queryCommon {loc, query, numCols, doCols} =
newline,
string "if (PQnfields(res) != ",
- string (Int.toString numCols),
+ string (Int.toString (length cols)),
string ") {",
newline,
box [string "int nf = PQnfields(res);",
@@ -574,7 +586,7 @@ fun queryCommon {loc, query, numCols, doCols} =
string "uw_error(ctx, FATAL, \"",
string (ErrorMsg.spanToString loc),
string ": Query returned %d columns instead of ",
- string (Int.toString numCols),
+ string (Int.toString (length cols)),
string ":\\n%s\\n%s\", nf, ",
query,
string ", PQerrorMessage(conn));",
@@ -598,13 +610,13 @@ fun queryCommon {loc, query, numCols, doCols} =
string "uw_pop_cleanup(ctx);",
newline]
-fun query {loc, numCols, doCols} =
+fun query {loc, cols, doCols} =
box [string "PGconn *conn = uw_get_db(ctx);",
newline,
string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);",
newline,
newline,
- queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = string "query"}]
+ queryCommon {loc = loc, cols = cols, doCols = doCols, query = string "query"}]
fun p_ensql t e =
case t of
@@ -623,7 +635,7 @@ fun p_ensql t e =
p_ensql t (box [string "(*", e, string ")"]),
string ")"]
-fun queryPrepared {loc, id, query, inputs, numCols, doCols} =
+fun queryPrepared {loc, id, query, inputs, cols, doCols} =
box [string "PGconn *conn = uw_get_db(ctx);",
newline,
string "const int paramFormats[] = { ",
@@ -662,9 +674,9 @@ fun queryPrepared {loc, id, query, inputs, numCols, doCols} =
string ", NULL, paramValues, paramLengths, paramFormats, 0);"],
newline,
newline,
- queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = box [string "\"",
- string (String.toString query),
- string "\""]}]
+ queryCommon {loc = loc, cols = cols, doCols = doCols, query = box [string "\"",
+ string (String.toString query),
+ string "\""]}]
fun dmlCommon {loc, dml} =
box [string "if (res == NULL) uw_error(ctx, FATAL, \"Out of memory allocating DML result.\");",
@@ -821,6 +833,7 @@ val () = addDbms {name = "postgres",
link = "-lpq",
global_init = box [string "void uw_client_init() { }",
newline],
+ p_sql_type = p_sql_type,
init = init,
query = query,
queryPrepared = queryPrepared,
diff --git a/src/settings.sig b/src/settings.sig
index 5406d1de..14e6338d 100644
--- a/src/settings.sig
+++ b/src/settings.sig
@@ -112,7 +112,7 @@ signature SETTINGS = sig
| Client
| Nullable of sql_type
- val p_sql_type : sql_type -> string
+ val p_sql_ctype : sql_type -> string
val isBlob : sql_type -> bool
val isNotNull : sql_type -> bool
@@ -125,18 +125,19 @@ signature SETTINGS = sig
(* Pass these linker arguments *)
global_init : Print.PD.pp_desc,
(* Define uw_client_init() *)
+ p_sql_type : sql_type -> string,
init : {dbstring : string,
prepared : (string * int) list,
tables : (string * (string * sql_type) list) list,
views : (string * (string * sql_type) list) list,
sequences : string list} -> Print.PD.pp_desc,
(* Define uw_db_init(), uw_db_close(), uw_db_begin(), uw_db_commit(), and uw_db_rollback() *)
- query : {loc : ErrorMsg.span, numCols : int,
+ query : {loc : ErrorMsg.span, cols : sql_type list,
doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
-> Print.PD.pp_desc}
-> Print.PD.pp_desc,
queryPrepared : {loc : ErrorMsg.span, id : int, query : string,
- inputs : sql_type list, numCols : int,
+ inputs : sql_type list, cols : sql_type list,
doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
-> Print.PD.pp_desc}
-> Print.PD.pp_desc,
diff --git a/src/settings.sml b/src/settings.sml
index a242768f..f2c2461d 100644
--- a/src/settings.sml
+++ b/src/settings.sml
@@ -285,7 +285,7 @@ datatype sql_type =
| Client
| Nullable of sql_type
-fun p_sql_type t =
+fun p_sql_ctype t =
let
open Print.PD
open Print
@@ -300,7 +300,7 @@ fun p_sql_type t =
| Channel => "uw_Basis_channel"
| Client => "uw_Basis_client"
| Nullable String => "uw_Basis_string"
- | Nullable t => p_sql_type t ^ "*"
+ | Nullable t => p_sql_ctype t ^ "*"
end
fun isBlob Blob = true
@@ -315,17 +315,18 @@ type dbms = {
header : string,
link : string,
global_init : Print.PD.pp_desc,
+ p_sql_type : sql_type -> string,
init : {dbstring : string,
prepared : (string * int) list,
tables : (string * (string * sql_type) list) list,
views : (string * (string * sql_type) list) list,
sequences : string list} -> Print.PD.pp_desc,
- query : {loc : ErrorMsg.span, numCols : int,
+ query : {loc : ErrorMsg.span, cols : sql_type list,
doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
-> Print.PD.pp_desc}
-> Print.PD.pp_desc,
queryPrepared : {loc : ErrorMsg.span, id : int, query : string,
- inputs : sql_type list, numCols : int,
+ inputs : sql_type list, cols : sql_type list,
doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc)
-> Print.PD.pp_desc}
-> Print.PD.pp_desc,
@@ -341,6 +342,7 @@ val curDb = ref ({name = "",
header = "",
link = "",
global_init = Print.box [],
+ p_sql_type = fn _ => "",
init = fn _ => Print.box [],
query = fn _ => Print.box [],
queryPrepared = fn _ => Print.box [],