diff options
-rw-r--r-- | src/cjr_print.sml | 51 | ||||
-rw-r--r-- | src/mysql.sml | 94 | ||||
-rw-r--r-- | src/postgres.sml | 32 | ||||
-rw-r--r-- | src/prepare.sml | 45 | ||||
-rw-r--r-- | src/settings.sig | 5 | ||||
-rw-r--r-- | src/settings.sml | 8 |
6 files changed, 165 insertions, 70 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml index 835faad5..375cc4b8 100644 --- a/src/cjr_print.sml +++ b/src/cjr_print.sml @@ -1737,41 +1737,26 @@ fun p_exp' par env (e, loc) = string "}))"] | ENextval {seq, prepared} => - let - val query = case seq of - (EPrim (Prim.String s), loc) => - (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc) - | _ => - let - val query = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc) - in - (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), query]), loc) - end - in - box [string "(uw_begin_region(ctx), ", - string "({", - newline, - string "uw_Basis_int n;", - newline, - - case prepared of - NONE => box [string "char *query = ", - p_exp env query, - string ";", - newline, - newline, + box [string "({", + newline, + string "uw_Basis_int n;", + newline, - #nextval (Settings.currentDbms ()) loc] - | SOME (id, query) => #nextvalPrepared (Settings.currentDbms ()) {loc = loc, - id = id, - query = query}, - newline, - newline, + case prepared of + NONE => #nextval (Settings.currentDbms ()) {loc = loc, + seqE = p_exp env seq, + seqName = case #1 seq of + EPrim (Prim.String s) => SOME s + | _ => NONE} + | SOME (id, query) => #nextvalPrepared (Settings.currentDbms ()) {loc = loc, + id = id, + query = query}, + newline, + newline, - string "n;", - newline, - string "}))"] - end + string "n;", + newline, + string "})"] | EUnurlify (e, t) => let diff --git a/src/mysql.sml b/src/mysql.sml index d8847424..439e8444 100644 --- a/src/mysql.sml +++ b/src/mysql.sml @@ -907,6 +907,36 @@ fun queryPrepared {loc, id, query, inputs, cols, doCols} = newline, newline, + string "if (stmt == NULL) {", + newline, + box [string "stmt = mysql_stmt_init(conn->conn);", + newline, + string "if (stmt == NULL) uw_error(ctx, FATAL, \"Out of memory allocating prepared statement\");", + newline, + string "if (mysql_stmt_prepare(stmt, \"", + string (String.toString query), + string "\", ", + string (Int.toString (size query)), + string ")) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, mysql_stmt_error(stmt), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "uw_error(ctx, FATAL, \"Error preparing statement: %s\", msg);", + newline], + string "}", + newline, + string "conn->p", + string (Int.toString id), + string " = stmt;", + newline], + string "}", + newline, + newline, + string "memset(in, 0, sizeof in);", newline, p_list_sepi (box []) (fn i => fn t => @@ -1129,6 +1159,36 @@ fun dmlPrepared {loc, id, dml, inputs} = newline, newline, + string "if (stmt == NULL) {", + newline, + box [string "stmt = mysql_stmt_init(conn->conn);", + newline, + string "if (stmt == NULL) uw_error(ctx, FATAL, \"Out of memory allocating prepared statement\");", + newline, + string "if (mysql_stmt_prepare(stmt, \"", + string (String.toString dml), + string "\", ", + string (Int.toString (size dml)), + string ")) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, mysql_stmt_error(stmt), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "uw_error(ctx, FATAL, \"Error preparing statement: %s\", msg);", + newline], + string "}", + newline, + string "conn->p", + string (Int.toString id), + string " = stmt;", + newline], + string "}", + newline, + newline, + string "memset(in, 0, sizeof in);", newline, p_list_sepi (box []) (fn i => fn t => @@ -1280,8 +1340,35 @@ fun dmlPrepared {loc, id, dml, inputs} = string (String.toString dml), string "\""]}] -fun nextval _ = box [] -fun nextvalPrepared _ = box [] +fun nextval {loc, seqE, seqName} = + box [string "uw_conn *conn = uw_get_db(ctx);", + newline, + string "char *insert = ", + case seqName of + SOME s => string ("\"INSERT INTO " ^ s ^ " VALUES ()\"") + | NONE => box [string "uw_Basis_strcat(ctx, \"INSERT INTO \", uw_Basis_strcat(ctx, ", + seqE, + string ", \" VALUES ()\"))"], + string ";", + newline, + string "char *delete = ", + case seqName of + SOME s => string ("\"DELETE FROM " ^ s ^ "\"") + | NONE => box [string "uw_Basis_strcat(ctx, \"DELETE FROM \", ", + seqE, + string ")"], + string ";", + newline, + newline, + + string "if (mysql_query(conn->conn, insert)) uw_error(ctx, FATAL, \"'nextval' INSERT failed\");", + newline, + string "n = mysql_insert_id(conn->conn);", + newline, + string "if (mysql_query(conn->conn, delete)) uw_error(ctx, FATAL, \"'nextval' DELETE failed\");", + newline] + +fun nextvalPrepared _ = raise Fail "MySQL.nextvalPrepared called" fun sqlifyString s = "'" ^ String.translate (fn #"'" => "\\'" | #"\\" => "\\\\" @@ -1314,6 +1401,7 @@ val () = addDbms {name = "mysql", p_blank = p_blank, supportsDeleteAs = false, createSequence = fn s => "CREATE TABLE " ^ s ^ " (id INTEGER PRIMARY KEY AUTO_INCREMENT)", - textKeysNeedLengths = true} + textKeysNeedLengths = true, + supportsNextval = false} end diff --git a/src/postgres.sml b/src/postgres.sml index 26825363..24166258 100644 --- a/src/postgres.sml +++ b/src/postgres.sml @@ -805,13 +805,28 @@ fun nextvalCommon {loc, query} = string "PQclear(res);", newline] -fun nextval loc = - box [string "PGconn *conn = uw_get_db(ctx);", - newline, - string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);", - newline, - newline, - nextvalCommon {loc = loc, query = string "query"}] +open Cjr + +fun nextval {loc, seqE, seqName} = + let + val query = case seqName of + SOME s => + string ("SELECT NEXTVAL('" ^ s ^ "')") + | _ => box [string "uw_Basis_strcat(ctx, \"SELECT NEXTVAL('\", uw_Basis_strcat(ctx, ", + seqE, + string ", \"')\"))"] + in + box [string "char *query = ", + query, + string ";", + newline, + string "PGconn *conn = uw_get_db(ctx);", + newline, + string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);", + newline, + newline, + nextvalCommon {loc = loc, query = string "query"}] + end fun nextvalPrepared {loc, id, query} = box [string "PGconn *conn = uw_get_db(ctx);", @@ -862,7 +877,8 @@ val () = addDbms {name = "postgres", p_blank = p_blank, supportsDeleteAs = true, createSequence = fn s => "CREATE SEQUENCE " ^ s, - textKeysNeedLengths = false} + textKeysNeedLengths = false, + supportsNextval = true} val () = setDbms "postgres" diff --git a/src/prepare.sml b/src/prepare.sml index 0a8ca7a2..29def780 100644 --- a/src/prepare.sml +++ b/src/prepare.sml @@ -216,27 +216,30 @@ fun prepExp (e as (_, loc), sns) = end) | ENextval {seq, ...} => - let - val s = case seq of - (EPrim (Prim.String s), loc) => - (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc) - | _ => - let - val s' = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc) - in - (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), s']), loc) - end - in - case prepString (s, [], 0) of - NONE => (e, sns) - | SOME (ss, n) => - let - val s = String.concat (rev ss) - in - ((ENextval {seq = seq, prepared = SOME (#2 sns, s)}, loc), - ((s, n) :: #1 sns, #2 sns + 1)) - end - end + if #supportsNextval (Settings.currentDbms ()) then + let + val s = case seq of + (EPrim (Prim.String s), loc) => + (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc) + | _ => + let + val s' = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc) + in + (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), s']), loc) + end + in + case prepString (s, [], 0) of + NONE => (e, sns) + | SOME (ss, n) => + let + val s = String.concat (rev ss) + in + ((ENextval {seq = seq, prepared = SOME (#2 sns, s)}, loc), + ((s, n) :: #1 sns, #2 sns + 1)) + end + end + else + (e, sns) | EUnurlify (e, t) => let diff --git a/src/settings.sig b/src/settings.sig index 873bbcb9..c7855856 100644 --- a/src/settings.sig +++ b/src/settings.sig @@ -142,14 +142,15 @@ signature SETTINGS = sig dml : ErrorMsg.span -> Print.PD.pp_desc, dmlPrepared : {loc : ErrorMsg.span, id : int, dml : string, inputs : sql_type list} -> Print.PD.pp_desc, - nextval : ErrorMsg.span -> Print.PD.pp_desc, + nextval : {loc : ErrorMsg.span, seqE : Print.PD.pp_desc, seqName : string option} -> Print.PD.pp_desc, nextvalPrepared : {loc : ErrorMsg.span, id : int, query : string} -> Print.PD.pp_desc, sqlifyString : string -> string, p_cast : string * sql_type -> string, p_blank : int * sql_type -> string (* Prepared statement input *), supportsDeleteAs : bool, createSequence : string -> string, - textKeysNeedLengths : bool + textKeysNeedLengths : bool, + supportsNextval : bool } val addDbms : dbms -> unit diff --git a/src/settings.sml b/src/settings.sml index 99fa748d..7393013e 100644 --- a/src/settings.sml +++ b/src/settings.sml @@ -332,14 +332,15 @@ type dbms = { dml : ErrorMsg.span -> Print.PD.pp_desc, dmlPrepared : {loc : ErrorMsg.span, id : int, dml : string, inputs : sql_type list} -> Print.PD.pp_desc, - nextval : ErrorMsg.span -> Print.PD.pp_desc, + nextval : {loc : ErrorMsg.span, seqName : string option, seqE : Print.PD.pp_desc} -> Print.PD.pp_desc, nextvalPrepared : {loc : ErrorMsg.span, id : int, query : string} -> Print.PD.pp_desc, sqlifyString : string -> string, p_cast : string * sql_type -> string, p_blank : int * sql_type -> string, supportsDeleteAs : bool, createSequence : string -> string, - textKeysNeedLengths : bool + textKeysNeedLengths : bool, + supportsNextval : bool } val dbmses = ref ([] : dbms list) @@ -359,7 +360,8 @@ val curDb = ref ({name = "", p_blank = fn _ => "", supportsDeleteAs = false, createSequence = fn _ => "", - textKeysNeedLengths = false} : dbms) + textKeysNeedLengths = false, + supportsNextval = false} : dbms) fun addDbms v = dbmses := v :: !dbmses fun setDbms s = |