summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cjr_print.sml51
-rw-r--r--src/mysql.sml94
-rw-r--r--src/postgres.sml32
-rw-r--r--src/prepare.sml45
-rw-r--r--src/settings.sig5
-rw-r--r--src/settings.sml8
6 files changed, 165 insertions, 70 deletions
diff --git a/src/cjr_print.sml b/src/cjr_print.sml
index 835faad5..375cc4b8 100644
--- a/src/cjr_print.sml
+++ b/src/cjr_print.sml
@@ -1737,41 +1737,26 @@ fun p_exp' par env (e, loc) =
string "}))"]
| ENextval {seq, prepared} =>
- let
- val query = case seq of
- (EPrim (Prim.String s), loc) =>
- (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc)
- | _ =>
- let
- val query = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc)
- in
- (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), query]), loc)
- end
- in
- box [string "(uw_begin_region(ctx), ",
- string "({",
- newline,
- string "uw_Basis_int n;",
- newline,
-
- case prepared of
- NONE => box [string "char *query = ",
- p_exp env query,
- string ";",
- newline,
- newline,
+ box [string "({",
+ newline,
+ string "uw_Basis_int n;",
+ newline,
- #nextval (Settings.currentDbms ()) loc]
- | SOME (id, query) => #nextvalPrepared (Settings.currentDbms ()) {loc = loc,
- id = id,
- query = query},
- newline,
- newline,
+ case prepared of
+ NONE => #nextval (Settings.currentDbms ()) {loc = loc,
+ seqE = p_exp env seq,
+ seqName = case #1 seq of
+ EPrim (Prim.String s) => SOME s
+ | _ => NONE}
+ | SOME (id, query) => #nextvalPrepared (Settings.currentDbms ()) {loc = loc,
+ id = id,
+ query = query},
+ newline,
+ newline,
- string "n;",
- newline,
- string "}))"]
- end
+ string "n;",
+ newline,
+ string "})"]
| EUnurlify (e, t) =>
let
diff --git a/src/mysql.sml b/src/mysql.sml
index d8847424..439e8444 100644
--- a/src/mysql.sml
+++ b/src/mysql.sml
@@ -907,6 +907,36 @@ fun queryPrepared {loc, id, query, inputs, cols, doCols} =
newline,
newline,
+ string "if (stmt == NULL) {",
+ newline,
+ box [string "stmt = mysql_stmt_init(conn->conn);",
+ newline,
+ string "if (stmt == NULL) uw_error(ctx, FATAL, \"Out of memory allocating prepared statement\");",
+ newline,
+ string "if (mysql_stmt_prepare(stmt, \"",
+ string (String.toString query),
+ string "\", ",
+ string (Int.toString (size query)),
+ string ")) {",
+ newline,
+ box [string "char msg[1024];",
+ newline,
+ string "strncpy(msg, mysql_stmt_error(stmt), 1024);",
+ newline,
+ string "msg[1023] = 0;",
+ newline,
+ string "uw_error(ctx, FATAL, \"Error preparing statement: %s\", msg);",
+ newline],
+ string "}",
+ newline,
+ string "conn->p",
+ string (Int.toString id),
+ string " = stmt;",
+ newline],
+ string "}",
+ newline,
+ newline,
+
string "memset(in, 0, sizeof in);",
newline,
p_list_sepi (box []) (fn i => fn t =>
@@ -1129,6 +1159,36 @@ fun dmlPrepared {loc, id, dml, inputs} =
newline,
newline,
+ string "if (stmt == NULL) {",
+ newline,
+ box [string "stmt = mysql_stmt_init(conn->conn);",
+ newline,
+ string "if (stmt == NULL) uw_error(ctx, FATAL, \"Out of memory allocating prepared statement\");",
+ newline,
+ string "if (mysql_stmt_prepare(stmt, \"",
+ string (String.toString dml),
+ string "\", ",
+ string (Int.toString (size dml)),
+ string ")) {",
+ newline,
+ box [string "char msg[1024];",
+ newline,
+ string "strncpy(msg, mysql_stmt_error(stmt), 1024);",
+ newline,
+ string "msg[1023] = 0;",
+ newline,
+ string "uw_error(ctx, FATAL, \"Error preparing statement: %s\", msg);",
+ newline],
+ string "}",
+ newline,
+ string "conn->p",
+ string (Int.toString id),
+ string " = stmt;",
+ newline],
+ string "}",
+ newline,
+ newline,
+
string "memset(in, 0, sizeof in);",
newline,
p_list_sepi (box []) (fn i => fn t =>
@@ -1280,8 +1340,35 @@ fun dmlPrepared {loc, id, dml, inputs} =
string (String.toString dml),
string "\""]}]
-fun nextval _ = box []
-fun nextvalPrepared _ = box []
+fun nextval {loc, seqE, seqName} =
+ box [string "uw_conn *conn = uw_get_db(ctx);",
+ newline,
+ string "char *insert = ",
+ case seqName of
+ SOME s => string ("\"INSERT INTO " ^ s ^ " VALUES ()\"")
+ | NONE => box [string "uw_Basis_strcat(ctx, \"INSERT INTO \", uw_Basis_strcat(ctx, ",
+ seqE,
+ string ", \" VALUES ()\"))"],
+ string ";",
+ newline,
+ string "char *delete = ",
+ case seqName of
+ SOME s => string ("\"DELETE FROM " ^ s ^ "\"")
+ | NONE => box [string "uw_Basis_strcat(ctx, \"DELETE FROM \", ",
+ seqE,
+ string ")"],
+ string ";",
+ newline,
+ newline,
+
+ string "if (mysql_query(conn->conn, insert)) uw_error(ctx, FATAL, \"'nextval' INSERT failed\");",
+ newline,
+ string "n = mysql_insert_id(conn->conn);",
+ newline,
+ string "if (mysql_query(conn->conn, delete)) uw_error(ctx, FATAL, \"'nextval' DELETE failed\");",
+ newline]
+
+fun nextvalPrepared _ = raise Fail "MySQL.nextvalPrepared called"
fun sqlifyString s = "'" ^ String.translate (fn #"'" => "\\'"
| #"\\" => "\\\\"
@@ -1314,6 +1401,7 @@ val () = addDbms {name = "mysql",
p_blank = p_blank,
supportsDeleteAs = false,
createSequence = fn s => "CREATE TABLE " ^ s ^ " (id INTEGER PRIMARY KEY AUTO_INCREMENT)",
- textKeysNeedLengths = true}
+ textKeysNeedLengths = true,
+ supportsNextval = false}
end
diff --git a/src/postgres.sml b/src/postgres.sml
index 26825363..24166258 100644
--- a/src/postgres.sml
+++ b/src/postgres.sml
@@ -805,13 +805,28 @@ fun nextvalCommon {loc, query} =
string "PQclear(res);",
newline]
-fun nextval loc =
- box [string "PGconn *conn = uw_get_db(ctx);",
- newline,
- string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);",
- newline,
- newline,
- nextvalCommon {loc = loc, query = string "query"}]
+open Cjr
+
+fun nextval {loc, seqE, seqName} =
+ let
+ val query = case seqName of
+ SOME s =>
+ string ("SELECT NEXTVAL('" ^ s ^ "')")
+ | _ => box [string "uw_Basis_strcat(ctx, \"SELECT NEXTVAL('\", uw_Basis_strcat(ctx, ",
+ seqE,
+ string ", \"')\"))"]
+ in
+ box [string "char *query = ",
+ query,
+ string ";",
+ newline,
+ string "PGconn *conn = uw_get_db(ctx);",
+ newline,
+ string "PGresult *res = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);",
+ newline,
+ newline,
+ nextvalCommon {loc = loc, query = string "query"}]
+ end
fun nextvalPrepared {loc, id, query} =
box [string "PGconn *conn = uw_get_db(ctx);",
@@ -862,7 +877,8 @@ val () = addDbms {name = "postgres",
p_blank = p_blank,
supportsDeleteAs = true,
createSequence = fn s => "CREATE SEQUENCE " ^ s,
- textKeysNeedLengths = false}
+ textKeysNeedLengths = false,
+ supportsNextval = true}
val () = setDbms "postgres"
diff --git a/src/prepare.sml b/src/prepare.sml
index 0a8ca7a2..29def780 100644
--- a/src/prepare.sml
+++ b/src/prepare.sml
@@ -216,27 +216,30 @@ fun prepExp (e as (_, loc), sns) =
end)
| ENextval {seq, ...} =>
- let
- val s = case seq of
- (EPrim (Prim.String s), loc) =>
- (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc)
- | _ =>
- let
- val s' = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc)
- in
- (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), s']), loc)
- end
- in
- case prepString (s, [], 0) of
- NONE => (e, sns)
- | SOME (ss, n) =>
- let
- val s = String.concat (rev ss)
- in
- ((ENextval {seq = seq, prepared = SOME (#2 sns, s)}, loc),
- ((s, n) :: #1 sns, #2 sns + 1))
- end
- end
+ if #supportsNextval (Settings.currentDbms ()) then
+ let
+ val s = case seq of
+ (EPrim (Prim.String s), loc) =>
+ (EPrim (Prim.String ("SELECT NEXTVAL('" ^ s ^ "')")), loc)
+ | _ =>
+ let
+ val s' = (EFfiApp ("Basis", "strcat", [seq, (EPrim (Prim.String "')"), loc)]), loc)
+ in
+ (EFfiApp ("Basis", "strcat", [(EPrim (Prim.String "SELECT NEXTVAL('"), loc), s']), loc)
+ end
+ in
+ case prepString (s, [], 0) of
+ NONE => (e, sns)
+ | SOME (ss, n) =>
+ let
+ val s = String.concat (rev ss)
+ in
+ ((ENextval {seq = seq, prepared = SOME (#2 sns, s)}, loc),
+ ((s, n) :: #1 sns, #2 sns + 1))
+ end
+ end
+ else
+ (e, sns)
| EUnurlify (e, t) =>
let
diff --git a/src/settings.sig b/src/settings.sig
index 873bbcb9..c7855856 100644
--- a/src/settings.sig
+++ b/src/settings.sig
@@ -142,14 +142,15 @@ signature SETTINGS = sig
dml : ErrorMsg.span -> Print.PD.pp_desc,
dmlPrepared : {loc : ErrorMsg.span, id : int, dml : string,
inputs : sql_type list} -> Print.PD.pp_desc,
- nextval : ErrorMsg.span -> Print.PD.pp_desc,
+ nextval : {loc : ErrorMsg.span, seqE : Print.PD.pp_desc, seqName : string option} -> Print.PD.pp_desc,
nextvalPrepared : {loc : ErrorMsg.span, id : int, query : string} -> Print.PD.pp_desc,
sqlifyString : string -> string,
p_cast : string * sql_type -> string,
p_blank : int * sql_type -> string (* Prepared statement input *),
supportsDeleteAs : bool,
createSequence : string -> string,
- textKeysNeedLengths : bool
+ textKeysNeedLengths : bool,
+ supportsNextval : bool
}
val addDbms : dbms -> unit
diff --git a/src/settings.sml b/src/settings.sml
index 99fa748d..7393013e 100644
--- a/src/settings.sml
+++ b/src/settings.sml
@@ -332,14 +332,15 @@ type dbms = {
dml : ErrorMsg.span -> Print.PD.pp_desc,
dmlPrepared : {loc : ErrorMsg.span, id : int, dml : string,
inputs : sql_type list} -> Print.PD.pp_desc,
- nextval : ErrorMsg.span -> Print.PD.pp_desc,
+ nextval : {loc : ErrorMsg.span, seqName : string option, seqE : Print.PD.pp_desc} -> Print.PD.pp_desc,
nextvalPrepared : {loc : ErrorMsg.span, id : int, query : string} -> Print.PD.pp_desc,
sqlifyString : string -> string,
p_cast : string * sql_type -> string,
p_blank : int * sql_type -> string,
supportsDeleteAs : bool,
createSequence : string -> string,
- textKeysNeedLengths : bool
+ textKeysNeedLengths : bool,
+ supportsNextval : bool
}
val dbmses = ref ([] : dbms list)
@@ -359,7 +360,8 @@ val curDb = ref ({name = "",
p_blank = fn _ => "",
supportsDeleteAs = false,
createSequence = fn _ => "",
- textKeysNeedLengths = false} : dbms)
+ textKeysNeedLengths = false,
+ supportsNextval = false} : dbms)
fun addDbms v = dbmses := v :: !dbmses
fun setDbms s =