From b731727df078c7295aff5309460fb93a2e51c8e5 Mon Sep 17 00:00:00 2001 From: Adam Chlipala Date: Tue, 30 Jun 2009 15:45:10 -0400 Subject: Move all DBMS initialization to #init --- src/postgres.sml | 281 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 275 insertions(+), 6 deletions(-) (limited to 'src/postgres.sml') diff --git a/src/postgres.sml b/src/postgres.sml index b1390bc4..1fdda8ff 100644 --- a/src/postgres.sml +++ b/src/postgres.sml @@ -31,9 +31,275 @@ open Settings open Print.PD open Print -fun init (dbstring, ss) = +val ident = String.translate (fn #"'" => "PRIME" + | ch => str ch) + +fun p_sql_type_base t = + case t of + Int => "int8" + | Float => "float8" + | String => "text" + | Bool => "bool" + | Time => "timestamp" + | Blob => "bytea" + | Channel => "int8" + | Client => "int4" + | Nullable t => p_sql_type_base t + +fun init {dbstring, prepared = ss, tables, sequences} = box [if #persistent (currentProtocol ()) then - box [string "static void uw_db_prepare(uw_context ctx) {", + box [string "static void uw_db_validate(uw_context ctx) {", + newline, + string "PGconn *conn = uw_get_db(ctx);", + newline, + string "PGresult *res;", + newline, + newline, + p_list_sep newline + (fn (s, xts) => + let + val sl = CharVector.map Char.toLower s + + val q = "SELECT COUNT(*) FROM pg_class WHERE relname = '" + ^ sl ^ "'" + + val q' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '", + sl, + "') AND (", + String.concatWith " OR " + (map (fn (x, t) => + String.concat ["(attname = 'uw_", + CharVector.map + Char.toLower (ident x), + "' AND atttypid = (SELECT oid FROM pg_type", + " WHERE typname = '", + p_sql_type_base t, + "') AND attnotnull = ", + if isNotNull t then + "TRUE" + else + "FALSE", + ")"]) xts), + ")"] + + val q'' = String.concat ["SELECT COUNT(*) FROM pg_attribute WHERE attrelid = (SELECT oid FROM pg_class WHERE relname = '", + sl, + "') AND attname LIKE 'uw_%'"] + in + box [string "res = PQexec(conn, \"", + string q, + string "\");", + newline, + newline, + string "if (res == NULL) {", + newline, + box [string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Out of memory allocating query result.\");", + newline], + string "}", + newline, + newline, + string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, PQerrorMessage(conn), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Query failed:\\n", + string q, + string "\\n%s\", msg);", + newline], + string "}", + newline, + newline, + string "if (strcmp(PQgetvalue(res, 0, 0), \"1\")) {", + newline, + box [string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Table '", + string s, + string "' does not exist.\");", + newline], + string "}", + newline, + newline, + string "PQclear(res);", + newline, + + string "res = PQexec(conn, \"", + string q', + string "\");", + newline, + newline, + string "if (res == NULL) {", + newline, + box [string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Out of memory allocating query result.\");", + newline], + string "}", + newline, + newline, + string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, PQerrorMessage(conn), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Query failed:\\n", + string q', + string "\\n%s\", msg);", + newline], + string "}", + newline, + newline, + string "if (strcmp(PQgetvalue(res, 0, 0), \"", + string (Int.toString (length xts)), + string "\")) {", + newline, + box [string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Table '", + string s, + string "' has the wrong column types.\");", + newline], + string "}", + newline, + newline, + string "PQclear(res);", + newline, + newline, + + string "res = PQexec(conn, \"", + string q'', + string "\");", + newline, + newline, + string "if (res == NULL) {", + newline, + box [string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Out of memory allocating query result.\");", + newline], + string "}", + newline, + newline, + string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, PQerrorMessage(conn), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Query failed:\\n", + string q'', + string "\\n%s\", msg);", + newline], + string "}", + newline, + newline, + string "if (strcmp(PQgetvalue(res, 0, 0), \"", + string (Int.toString (length xts)), + string "\")) {", + newline, + box [string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Table '", + string s, + string "' has extra columns.\");", + newline], + string "}", + newline, + newline, + string "PQclear(res);", + newline] + end) tables, + + p_list_sep newline + (fn s => + let + val sl = CharVector.map Char.toLower s + + val q = "SELECT COUNT(*) FROM pg_class WHERE relname = '" + ^ sl ^ "' AND relkind = 'S'" + in + box [string "res = PQexec(conn, \"", + string q, + string "\");", + newline, + newline, + string "if (res == NULL) {", + newline, + box [string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Out of memory allocating query result.\");", + newline], + string "}", + newline, + newline, + string "if (PQresultStatus(res) != PGRES_TUPLES_OK) {", + newline, + box [string "char msg[1024];", + newline, + string "strncpy(msg, PQerrorMessage(conn), 1024);", + newline, + string "msg[1023] = 0;", + newline, + string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Query failed:\\n", + string q, + string "\\n%s\", msg);", + newline], + string "}", + newline, + newline, + string "if (strcmp(PQgetvalue(res, 0, 0), \"1\")) {", + newline, + box [string "PQclear(res);", + newline, + string "PQfinish(conn);", + newline, + string "uw_error(ctx, FATAL, \"Sequence '", + string s, + string "' does not exist.\");", + newline], + string "}", + newline, + newline, + string "PQclear(res);", + newline] + end) sequences, + + string "}", + + string "static void uw_db_prepare(uw_context ctx) {", newline, string "PGconn *conn = uw_get_db(ctx);", newline, @@ -153,7 +419,10 @@ fun init (dbstring, ss) = newline, newline] else - string "static void uw_db_prepare(uw_context ctx) { }", + box [string "static void uw_db_validate(uw_context ctx) { }", + newline, + string "static void uw_db_prepare(uw_context ctx) { }"], + newline, newline, @@ -222,10 +491,10 @@ fun p_getcol {wontLeakStrings, col = i, typ = t} = String => getter t | _ => box [string "({", newline, - p_sql_type t, + string (p_sql_type t), space, string "*tmp = uw_malloc(ctx, sizeof(", - p_sql_type t, + string (p_sql_type t), string "));", newline, string "*tmp = ", @@ -241,7 +510,7 @@ fun p_getcol {wontLeakStrings, col = i, typ = t} = string (Int.toString i), string ") ? ", box [string "({", - p_sql_type t, + string (p_sql_type t), space, string "tmp;", newline, -- cgit v1.2.3