diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cjr_print.sml | 7 | ||||
-rw-r--r-- | src/mysql.sml | 513 | ||||
-rw-r--r-- | src/postgres.sml | 31 | ||||
-rw-r--r-- | src/settings.sig | 7 | ||||
-rw-r--r-- | src/settings.sml | 10 |
5 files changed, 543 insertions, 25 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml index 7d1120b4..fcfa402e 100644 --- a/src/cjr_print.sml +++ b/src/cjr_print.sml @@ -1652,7 +1652,7 @@ fun p_exp' par env (e, loc) = #query (Settings.currentDbms ()) {loc = loc, - numCols = length outputs, + cols = map (fn (_, t) => sql_type_in env t) outputs, doCols = doCols}] | SOME (id, query) => box [p_list_sepi newline @@ -1675,7 +1675,7 @@ fun p_exp' par env (e, loc) = id = id, query = query, inputs = map #2 inputs, - numCols = length outputs, + cols = map (fn (_, t) => sql_type_in env t) outputs, doCols = doCols}], newline, @@ -2797,7 +2797,8 @@ fun p_sql env (ds, _) = box [string "uw_", string (CharVector.map Char.toLower x), space, - p_sqltype env (t, ErrorMsg.dummySpan)]) xts, + string (#p_sql_type (Settings.currentDbms ()) + (sql_type_in env t))]) xts, case (pk, csts) of ("", []) => box [] | _ => string ",", diff --git a/src/mysql.sml b/src/mysql.sml index 7b02c787..2fcdef2d 100644 --- a/src/mysql.sml +++ b/src/mysql.sml @@ -31,6 +31,30 @@ open Settings open Print.PD open Print +fun p_sql_type t = + case t of + Int => "bigint" + | Float => "double" + | String => "longtext" + | Bool => "bool" + | Time => "timestamp" + | Blob => "longblob" + | Channel => "bigint" + | Client => "int" + | Nullable t => p_sql_type t + +fun p_buffer_type t = + case t of + Int => "MYSQL_TYPE_LONGLONG" + | Float => "MYSQL_TYPE_DOUBLE" + | String => "MYSQL_TYPE_STRING" + | Bool => "MYSQL_TYPE_LONG" + | Time => "MYSQL_TYPE_TIME" + | Blob => "MYSQL_TYPE_BLOB" + | Channel => "MYSQL_TYPE_LONGLONG" + | Client => "MYSQL_TYPE_LONG" + | Nullable t => p_buffer_type t + fun init {dbstring, prepared = ss, tables, views, sequences} = let val host = ref NONE @@ -138,6 +162,10 @@ fun init {dbstring, prepared = ss, tables, views, sequences} = newline, uhoh true "Error preparing statement: %s" ["msg"]], string "}", + newline, + string "conn->p", + string (Int.toString i), + string " = stmt;", newline] end) ss, @@ -253,12 +281,484 @@ fun init {dbstring, prepared = ss, tables, views, sequences} = newline] end -fun query _ = raise Fail "MySQL query" -fun queryPrepared _ = raise Fail "MySQL queryPrepared" -fun dml _ = raise Fail "MySQL dml" -fun dmlPrepared _ = raise Fail "MySQL dmlPrepared" -fun nextval _ = raise Fail "MySQL nextval" -fun nextvalPrepared _ = raise Fail "MySQL nextvalPrepared" +fun p_getcol {wontLeakStrings = _, col = i, typ = t} = + let + fun getter t = + case t of + String => box [string "({", + newline, + string "uw_Basis_string s = uw_malloc(ctx, length", + string (Int.toString i), + string " + 1);", + newline, + string "out[", + string (Int.toString i), + string "].buffer = s;", + newline, + string "out[", + string (Int.toString i), + string "].buffer_length = length", + string (Int.toString i), + string " + 1;", + newline, + string "mysql_stmt_fetch_column(stmt, &out[", + string (Int.toString i), + string "], ", + string (Int.toString i), + string ", 0);", + newline, + string "s[length", + string (Int.toString i), + string "] = 0;", + newline, + string "s;", + newline, + string "})"] + | Blob => box [string "({", + newline, + string "uw_Basis_blob b = {length", + string (Int.toString i), + string ", uw_malloc(ctx, length", + string (Int.toString i), + string ")};", + newline, + string "out[", + string (Int.toString i), + string "].buffer = b.data;", + newline, + string "out[", + string (Int.toString i), + string "].buffer_length = length", + string (Int.toString i), + string ";", + newline, + string "mysql_stmt_fetch_column(stmt, &out[", + string (Int.toString i), + string "], ", + string (Int.toString i), + string ", 0);", + newline, + string "b;", + newline, + string "})"] + | Time => box [string "({", + string "MYSQL_TIME *mt = buffer", + string (Int.toString i), + string ";", + newline, + newline, + string "struct tm t = {mt->second, mt->minute, mt->hour, mt->day, mt->month, mt->year, 0, 0, -1};", + newline, + string "mktime(&tm);", + newline, + string "})"] + | _ => box [string "buffer", + string (Int.toString i)] + in + case t of + Nullable t => box [string "(is_null", + string (Int.toString i), + string " ? NULL : ", + case t of + String => getter t + | _ => box [string "({", + newline, + string (p_sql_ctype t), + space, + string "*tmp = uw_malloc(ctx, sizeof(", + string (p_sql_ctype t), + string "));", + newline, + string "*tmp = ", + getter t, + string ";", + newline, + string "tmp;", + newline, + string "})"], + string ")"] + | _ => box [string "(is_null", + string (Int.toString i), + string " ? ", + box [string "({", + string (p_sql_ctype t), + space, + string "tmp;", + newline, + string "uw_error(ctx, FATAL, \"Unexpectedly NULL field #", + string (Int.toString i), + string "\");", + newline, + string "tmp;", + newline, + string "})"], + string " : ", + getter t, + string ")"] + end + +fun queryCommon {loc, query, cols, doCols} = + box [string "int n, r;", + newline, + string "MYSQL_BIND out[", + string (Int.toString (length cols)), + string "];", + newline, + p_list_sepi (box []) (fn i => fn t => + let + fun buffers t = + case t of + String => box [string "unsigned long length", + string (Int.toString i), + string ";", + newline] + | Blob => box [string "unsigned long length", + string (Int.toString i), + string ";", + newline] + | _ => box [string (p_sql_ctype t), + space, + string "buffer", + string (Int.toString i), + string ";", + newline] + in + box [string "my_bool is_null", + string (Int.toString i), + string ";", + newline, + case t of + Nullable t => buffers t + | _ => buffers t, + newline] + end) cols, + newline, + + string "memset(out, 0, sizeof out);", + newline, + p_list_sepi (box []) (fn i => fn t => + let + fun buffers t = + case t of + String => box [] + | Blob => box [] + | _ => box [string "out[", + string (Int.toString i), + string "].buffer = &buffer", + string (Int.toString i), + string ";", + newline] + in + box [string "out[", + string (Int.toString i), + string "].buffer_type = ", + string (p_buffer_type t), + string ";", + newline, + string "out[", + string (Int.toString i), + string "].is_null = &is_null", + string (Int.toString i), + string ";", + newline, + + case t of + Nullable t => buffers t + | _ => buffers t, + newline] + end) cols, + newline, + + string "if (mysql_stmt_execute(stmt)) uw_error(ctx, FATAL, \"", + string (ErrorMsg.spanToString loc), + string ": Error executing query\");", + newline, + newline, + + string "if (mysql_stmt_store_result(stmt)) uw_error(ctx, FATAL, \"", + string (ErrorMsg.spanToString loc), + string ": Error storing query result\");", + newline, + newline, + + string "if (mysql_stmt_bind_result(stmt, out)) uw_error(ctx, FATAL, \"", + string (ErrorMsg.spanToString loc), + string ": Error binding query result\");", + newline, + newline, + + string "uw_end_region(ctx);", + newline, + string "while ((r = mysql_stmt_fetch(stmt)) == 0) {", + newline, + doCols p_getcol, + string "}", + newline, + newline, + + string "if (r != MYSQL_NO_DATA) uw_error(ctx, FATAL, \"", + string (ErrorMsg.spanToString loc), + string ": query result fetching failed\");", + newline] + +fun query {loc, cols, doCols} = + box [string "uw_conn *conn = uw_get_db(ctx);", + newline, + string "MYSQL_stmt *stmt = mysql_stmt_init(conn->conn);", + newline, + string "if (stmt == NULL) uw_error(ctx, \"", + string (ErrorMsg.spanToString loc), + string ": can't allocate temporary prepared statement\");", + newline, + string "uw_push_cleanup(ctx, (void (*)(void *))mysql_stmt_close, stmt);", + newline, + string "if (mysql_stmt_prepare(stmt, query, strlen(query))) uw_error(ctx, FATAL, \"", + string (ErrorMsg.spanToString loc), + string "\");", + newline, + newline, + + p_list_sepi (box []) (fn i => fn t => + let + fun buffers t = + case t of + String => box [] + | Blob => box [] + | _ => box [string "out[", + string (Int.toString i), + string "].buffer = &buffer", + string (Int.toString i), + string ";", + newline] + in + box [string "in[", + string (Int.toString i), + string "].buffer_type = ", + string (p_buffer_type t), + string ";", + newline, + + case t of + Nullable t => box [string "in[", + string (Int.toString i), + string "].is_null = &is_null", + string (Int.toString i), + string ";", + newline, + buffers t] + | _ => buffers t, + newline] + end) cols, + newline, + + queryCommon {loc = loc, cols = cols, doCols = doCols, query = string "query"}, + + string "uw_pop_cleanup(ctx);", + newline] + +fun p_ensql t e = + case t of + Int => box [string "uw_Basis_attrifyInt(ctx, ", e, string ")"] + | Float => box [string "uw_Basis_attrifyFloat(ctx, ", e, string ")"] + | String => e + | Bool => box [string "(", e, string " ? \"TRUE\" : \"FALSE\")"] + | Time => box [string "uw_Basis_attrifyTime(ctx, ", e, string ")"] + | Blob => box [e, string ".data"] + | Channel => box [string "uw_Basis_attrifyChannel(ctx, ", e, string ")"] + | Client => box [string "uw_Basis_attrifyClient(ctx, ", e, string ")"] + | Nullable String => e + | Nullable t => box [string "(", + e, + string " == NULL ? NULL : ", + p_ensql t (box [string "(*", e, string ")"]), + string ")"] + +fun queryPrepared {loc, id, query, inputs, cols, doCols} = + box [string "uw_conn *conn = uw_get_db(ctx);", + newline, + string "MYSQL_BIND in[", + string (Int.toString (length inputs)), + string "];", + newline, + p_list_sepi (box []) (fn i => fn t => + let + fun buffers t = + case t of + String => box [string "unsigned long in_length", + string (Int.toString i), + string ";", + newline] + | Blob => box [string "unsigned long in_length", + string (Int.toString i), + string ";", + newline] + | Time => box [string (p_sql_ctype t), + space, + string "in_buffer", + string (Int.toString i), + string ";", + newline] + | _ => box [] + in + box [case t of + Nullable t => box [string "my_bool in_is_null", + string (Int.toString i), + string ";", + newline, + buffers t] + | _ => buffers t, + newline] + end) inputs, + string "MYSQL_STMT *stmt = conn->p", + string (Int.toString id), + string ";", + newline, + newline, + + string "memset(in, 0, sizeof in);", + newline, + p_list_sepi (box []) (fn i => fn t => + let + fun buffers t = + case t of + String => box [string "in[", + string (Int.toString i), + string "].buffer = arg", + string (Int.toString (i + 1)), + string ";", + newline, + string "in_length", + string (Int.toString i), + string "= in[", + string (Int.toString i), + string "].buffer_length = strlen(arg", + string (Int.toString (i + 1)), + string ");", + newline, + string "in[", + string (Int.toString i), + string "].length = &in_length", + string (Int.toString i), + string ";", + newline] + | Blob => box [string "in[", + string (Int.toString i), + string "].buffer = arg", + string (Int.toString (i + 1)), + string ".data;", + newline, + string "in_length", + string (Int.toString i), + string "= in[", + string (Int.toString i), + string "].buffer_length = arg", + string (Int.toString (i + 1)), + string ".size;", + newline, + string "in[", + string (Int.toString i), + string "].length = &in_length", + string (Int.toString i), + string ";", + newline] + | Time => + let + fun oneField dst src = + box [string "in_buffer", + string (Int.toString i), + string ".", + string dst, + string " = tms.tm_", + string src, + string ";", + newline] + in + box [string "({", + newline, + string "struct tm tms;", + newline, + string "if (localtime_r(&arg", + string (Int.toString (i + 1)), + string ", &tm) == NULL) uw_error(\"", + string (ErrorMsg.spanToString loc), + string ": error converting to MySQL time\");", + newline, + oneField "year" "year", + oneField "month" "mon", + oneField "day" "mday", + oneField "hour" "hour", + oneField "minute" "min", + oneField "second" "sec", + newline, + string "in[", + string (Int.toString i), + string "].buffer = &in_buffer", + string (Int.toString i), + string ";", + newline] + end + + | _ => box [string "in[", + string (Int.toString i), + string "].buffer = &arg", + string (Int.toString (i + 1)), + string ";", + newline] + in + box [string "in[", + string (Int.toString i), + string "].buffer_type = ", + string (p_buffer_type t), + string ";", + newline, + + case t of + Nullable t => box [string "in[", + string (Int.toString i), + string "].is_null = &in_is_null", + string (Int.toString i), + string ";", + newline, + string "if (arg", + string (Int.toString (i + 1)), + string " == NULL) {", + newline, + box [string "in_is_null", + string (Int.toString i), + string " = 1;", + newline], + string "} else {", + box [case t of + String => box [] + | _ => + box [string (p_sql_ctype t), + space, + string "arg", + string (Int.toString (i + 1)), + string " = *arg", + string (Int.toString (i + 1)), + string ";", + newline], + string "in_is_null", + string (Int.toString i), + string " = 0;", + newline, + buffers t, + newline]] + + | _ => buffers t, + newline] + end) inputs, + newline, + + queryCommon {loc = loc, cols = cols, doCols = doCols, query = box [string "\"", + string (String.toString query), + string "\""]}] + +fun dml _ = box [] +fun dmlPrepared _ = box [] +fun nextval _ = box [] +fun nextvalPrepared _ = box [] val () = addDbms {name = "mysql", header = "mysql/mysql.h", @@ -276,6 +776,7 @@ val () = addDbms {name = "mysql", string "}", newline], init = init, + p_sql_type = p_sql_type, query = query, queryPrepared = queryPrepared, dml = dml, diff --git a/src/postgres.sml b/src/postgres.sml index 07a68607..ca71798f 100644 --- a/src/postgres.sml +++ b/src/postgres.sml @@ -34,6 +34,18 @@ open Print val ident = String.translate (fn #"'" => "PRIME" | ch => str ch) +fun p_sql_type t = + case t of + Int => "int8" + | Float => "float8" + | String => "text" + | Bool => "bool" + | Time => "timestamp" + | Blob => "bytea" + | Channel => "int8" + | Client => "int4" + | Nullable t => p_sql_type t + fun p_sql_type_base t = case t of Int => "bigint" @@ -540,7 +552,7 @@ fun p_getcol {wontLeakStrings, col = i, typ = t} = getter t end -fun queryCommon {loc, query, numCols, doCols} = +fun queryCommon {loc, query, cols, doCols} = box [string "int n, i;", newline, newline, @@ -564,7 +576,7 @@ fun queryCommon {loc, query, numCols, doCols} = newline, string "if (PQnfields(res) != ", - string (Int.toString numCols), + string (Int.toString (length cols)), string ") {", newline, box [string "int nf = PQnfields(res);", @@ -574,7 +586,7 @@ fun queryCommon {loc, query, numCols, doCols} = string "uw_error(ctx, FATAL, \"", string (ErrorMsg.spanToString loc), string ": Query returned %d columns instead of ", - string (Int.toString numCols), + string (Int.toString (length cols)), string ":\\n%s\\n%s\", nf, ", query, string ", PQerrorMessage(conn));", @@ -598,13 +610,13 @@ fun queryCommon {loc, query, numCols, doCols} = string "uw_pop_cleanup(ctx);", newline] -fun query {loc, numCols, doCols} = +fun query {loc, cols, doCols} = box [string "PGconn *conn = uw_get_db(ctx);", newline, string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);", newline, newline, - queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = string "query"}] + queryCommon {loc = loc, cols = cols, doCols = doCols, query = string "query"}] fun p_ensql t e = case t of @@ -623,7 +635,7 @@ fun p_ensql t e = p_ensql t (box [string "(*", e, string ")"]), string ")"] -fun queryPrepared {loc, id, query, inputs, numCols, doCols} = +fun queryPrepared {loc, id, query, inputs, cols, doCols} = box [string "PGconn *conn = uw_get_db(ctx);", newline, string "const int paramFormats[] = { ", @@ -662,9 +674,9 @@ fun queryPrepared {loc, id, query, inputs, numCols, doCols} = string ", NULL, paramValues, paramLengths, paramFormats, 0);"], newline, newline, - queryCommon {loc = loc, numCols = numCols, doCols = doCols, query = box [string "\"", - string (String.toString query), - string "\""]}] + queryCommon {loc = loc, cols = cols, doCols = doCols, query = box [string "\"", + string (String.toString query), + string "\""]}] fun dmlCommon {loc, dml} = box [string "if (res == NULL) uw_error(ctx, FATAL, \"Out of memory allocating DML result.\");", @@ -821,6 +833,7 @@ val () = addDbms {name = "postgres", link = "-lpq", global_init = box [string "void uw_client_init() { }", newline], + p_sql_type = p_sql_type, init = init, query = query, queryPrepared = queryPrepared, diff --git a/src/settings.sig b/src/settings.sig index 5406d1de..14e6338d 100644 --- a/src/settings.sig +++ b/src/settings.sig @@ -112,7 +112,7 @@ signature SETTINGS = sig | Client | Nullable of sql_type - val p_sql_type : sql_type -> string + val p_sql_ctype : sql_type -> string val isBlob : sql_type -> bool val isNotNull : sql_type -> bool @@ -125,18 +125,19 @@ signature SETTINGS = sig (* Pass these linker arguments *) global_init : Print.PD.pp_desc, (* Define uw_client_init() *) + p_sql_type : sql_type -> string, init : {dbstring : string, prepared : (string * int) list, tables : (string * (string * sql_type) list) list, views : (string * (string * sql_type) list) list, sequences : string list} -> Print.PD.pp_desc, (* Define uw_db_init(), uw_db_close(), uw_db_begin(), uw_db_commit(), and uw_db_rollback() *) - query : {loc : ErrorMsg.span, numCols : int, + query : {loc : ErrorMsg.span, cols : sql_type list, doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc) -> Print.PD.pp_desc} -> Print.PD.pp_desc, queryPrepared : {loc : ErrorMsg.span, id : int, query : string, - inputs : sql_type list, numCols : int, + inputs : sql_type list, cols : sql_type list, doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc) -> Print.PD.pp_desc} -> Print.PD.pp_desc, diff --git a/src/settings.sml b/src/settings.sml index a242768f..f2c2461d 100644 --- a/src/settings.sml +++ b/src/settings.sml @@ -285,7 +285,7 @@ datatype sql_type = | Client | Nullable of sql_type -fun p_sql_type t = +fun p_sql_ctype t = let open Print.PD open Print @@ -300,7 +300,7 @@ fun p_sql_type t = | Channel => "uw_Basis_channel" | Client => "uw_Basis_client" | Nullable String => "uw_Basis_string" - | Nullable t => p_sql_type t ^ "*" + | Nullable t => p_sql_ctype t ^ "*" end fun isBlob Blob = true @@ -315,17 +315,18 @@ type dbms = { header : string, link : string, global_init : Print.PD.pp_desc, + p_sql_type : sql_type -> string, init : {dbstring : string, prepared : (string * int) list, tables : (string * (string * sql_type) list) list, views : (string * (string * sql_type) list) list, sequences : string list} -> Print.PD.pp_desc, - query : {loc : ErrorMsg.span, numCols : int, + query : {loc : ErrorMsg.span, cols : sql_type list, doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc) -> Print.PD.pp_desc} -> Print.PD.pp_desc, queryPrepared : {loc : ErrorMsg.span, id : int, query : string, - inputs : sql_type list, numCols : int, + inputs : sql_type list, cols : sql_type list, doCols : ({wontLeakStrings : bool, col : int, typ : sql_type} -> Print.PD.pp_desc) -> Print.PD.pp_desc} -> Print.PD.pp_desc, @@ -341,6 +342,7 @@ val curDb = ref ({name = "", header = "", link = "", global_init = Print.box [], + p_sql_type = fn _ => "", init = fn _ => Print.box [], query = fn _ => Print.box [], queryPrepared = fn _ => Print.box [], |