summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/c/driver.c51
-rw-r--r--src/c/urweb.c22
-rw-r--r--src/cjr_print.sml97
-rw-r--r--src/compiler.sml2
-rw-r--r--src/monoize.sml33
5 files changed, 195 insertions, 10 deletions
diff --git a/src/c/driver.c b/src/c/driver.c
index e6538616..f7456ed9 100644
--- a/src/c/driver.c
+++ b/src/c/driver.c
@@ -10,6 +10,8 @@
#include <pthread.h>
+#include <mhash.h>
+
#include "urweb.h"
int uw_backlog = 10;
@@ -102,6 +104,46 @@ static uw_context new_context() {
return ctx;
}
+#define KEYSIZE 16
+#define PASSSIZE 4
+
+#define HASH_ALGORITHM MHASH_SHA256
+#define HASH_BLOCKSIZE 32
+#define KEYGEN_ALGORITHM KEYGEN_MCRYPT
+
+int uw_hash_blocksize = HASH_BLOCKSIZE;
+
+static int password[PASSSIZE];
+static unsigned char private_key[KEYSIZE];
+
+static void init_crypto() {
+ KEYGEN kg = {{HASH_ALGORITHM, HASH_ALGORITHM}};
+ int i;
+
+ assert(mhash_get_block_size(HASH_ALGORITHM) == HASH_BLOCKSIZE);
+
+ for (i = 0; i < PASSSIZE; ++i)
+ password[i] = rand();
+
+ if (mhash_keygen_ext(KEYGEN_ALGORITHM, kg,
+ private_key, sizeof(private_key),
+ (unsigned char*)password, sizeof(password)) < 0) {
+ printf("Key generation failed\n");
+ exit(1);
+ }
+}
+
+void uw_sign(const char *in, char *out) {
+ MHASH td;
+
+ td = mhash_hmac_init(HASH_ALGORITHM, private_key, sizeof(private_key),
+ mhash_get_hash_pblock(HASH_ALGORITHM));
+
+ mhash(td, in, strlen(in));
+ if (mhash_hmac_deinit(td, out) < 0)
+ printf("Signing failed");
+}
+
static void *worker(void *data) {
int me = *(int *)data, retries_left = MAX_RETRIES;
uw_context ctx = new_context();
@@ -344,9 +386,13 @@ static void sigint(int signum) {
}
static void initialize() {
- uw_context ctx = new_context();
+ uw_context ctx;
failure_kind fk;
+ init_crypto();
+
+ ctx = new_context();
+
if (!ctx)
exit(1);
@@ -411,6 +457,7 @@ int main(int argc, char *argv[]) {
}
}
+ uw_global_init();
initialize();
names = calloc(nthreads, sizeof(int));
@@ -444,8 +491,6 @@ int main(int argc, char *argv[]) {
sin_size = sizeof their_addr;
- uw_global_init();
-
printf("Listening on port %d....\n", uw_port);
{
diff --git a/src/c/urweb.c b/src/c/urweb.c
index d3a93af9..bd42352f 100644
--- a/src/c/urweb.c
+++ b/src/c/urweb.c
@@ -1981,3 +1981,25 @@ failure_kind uw_initialize(uw_context ctx) {
uw_Basis_string uw_Basis_bless(uw_context ctx, uw_Basis_string s) {
return s;
}
+
+uw_Basis_string uw_unnull(uw_Basis_string s) {
+ return s ? s : "";
+}
+
+extern int uw_hash_blocksize;
+
+uw_Basis_string uw_Basis_makeSigString(uw_context ctx, uw_Basis_string sig) {
+ uw_Basis_string r = uw_malloc(ctx, 2 * uw_hash_blocksize + 1);
+ int i;
+
+ for (i = 0; i < uw_hash_blocksize; ++i)
+ sprintf(&r[2*i], "%.02X", ((unsigned char *)sig)[i]);
+
+ return r;
+}
+
+extern uw_Basis_string uw_cookie_sig(uw_context);
+
+uw_Basis_string uw_Basis_sigString(uw_context ctx, uw_unit u) {
+ return uw_cookie_sig(ctx);
+}
diff --git a/src/cjr_print.sml b/src/cjr_print.sml
index e834300d..774b2b75 100644
--- a/src/cjr_print.sml
+++ b/src/cjr_print.sml
@@ -2198,6 +2198,26 @@ fun is_not_null t =
(TOption _, _) => false
| _ => true
+fun sigName fields =
+ let
+ fun inFields s = List.exists (fn (s', _) => s' = s) fields
+
+ fun getSigName n =
+ let
+ val s = "Sig" ^ Int.toString n
+ in
+ if inFields s then
+ getSigName (n + 1)
+ else
+ s
+ end
+ in
+ if inFields "Sig" then
+ getSigName 0
+ else
+ "Sig"
+ end
+
fun p_file env (ds, ps) =
let
val (pds, env) = ListUtil.foldlMap (fn (d, env) =>
@@ -2214,6 +2234,7 @@ fun p_file env (ds, ps) =
(TRecord i, _) =>
let
val xts = E.lookupStruct env i
+ val xts = (sigName xts, (TRecord 0, ErrorMsg.dummySpan)) :: xts
val xtsSet = SS.addList (SS.empty, map #1 xts)
in
foldl (fn ((x, _), fields) =>
@@ -2245,6 +2266,8 @@ fun p_file env (ds, ps) =
end)
SM.empty fields
+ val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
fun makeSwitch (fnums, i) =
case SM.foldl (fn (n, NotFound) => Found n
| (n, Error) => Error
@@ -2328,10 +2351,10 @@ fun p_file env (ds, ps) =
fun p_page (ek, s, n, ts, ran, side) =
let
- val (ts, defInputs, inputsVar) =
+ val (ts, defInputs, inputsVar, fields) =
case ek of
- Core.Link => (List.take (ts, length ts - 1), string "", string "")
- | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "")
+ Core.Link => (List.take (ts, length ts - 1), string "", string "", NONE)
+ | Core.Rpc _ => (List.take (ts, length ts - 1), string "", string "", NONE)
| Core.Action _ =>
case List.nth (ts, length ts - 2) of
(TRecord i, _) =>
@@ -2392,12 +2415,43 @@ fun p_file env (ds, ps) =
newline],
box [string ",",
space,
- string "uw_inputs"])
+ string "uw_inputs"],
+ SOME xts)
end
| _ => raise Fail "CjrPrint: Last argument to an action isn't a record"
+
+ fun couldWrite ek =
+ case ek of
+ Link => false
+ | Action ef => ef = ReadWrite
+ | Rpc ef => ef = ReadWrite
in
- box [string "if (!strncmp(request, \"",
+ box [if couldWrite ek then
+ box [string "{",
+ newline,
+ string "uw_Basis_string sig = ",
+ case fields of
+ NONE => string "uw_Basis_requestHeader(ctx, \"UrWeb-Sig\")"
+ | SOME fields =>
+ case SM.find (fnums, sigName fields) of
+ NONE => raise Fail "CjrPrint: sig name wasn't assigned a number"
+ | SOME inum =>
+ string ("uw_get_input(ctx, " ^ Int.toString inum ^ ")"),
+ string ";",
+ newline,
+ string "if (sig == NULL) uw_error(ctx, FATAL, \"Missing cookie signature\");",
+ newline,
+ string "if (strcmp(sig, uw_cookie_sig(ctx)))",
+ newline,
+ box [string "uw_error(ctx, FATAL, \"Wrong cookie signature\");",
+ newline],
+ string "}",
+ newline]
+ else
+ box [],
+
+ string "if (!strncmp(request, \"",
string (String.toString s),
string "\", ",
string (Int.toString (size s)),
@@ -2745,6 +2799,18 @@ fun p_file env (ds, ps) =
string "}"]
val hasDb = List.exists (fn (DDatabase _, _) => true | _ => false) ds
+
+ val cookies = List.mapPartial (fn (DCookie s, _) => SOME s | _ => NONE) ds
+
+ val cookieCode = foldl (fn (cookie, acc) =>
+ SOME (case acc of
+ NONE => string ("uw_unnull(uw_Basis_get_cookie(ctx, \""
+ ^ cookie ^ "\"))")
+ | SOME acc => box [string ("uw_Basis_strcat(ctx, uw_unnull(uw_Basis_get_cookie(ctx, \""
+ ^ cookie ^ "\")), uw_Basis_strcat(ctx, \"/\", "),
+ acc,
+ string "))"]))
+ NONE cookies
in
box [string "#include <stdio.h>",
newline,
@@ -2783,6 +2849,27 @@ fun p_file env (ds, ps) =
string "}",
newline,
newline,
+
+ string "extern void uw_sign(const char *in, char *out);",
+ newline,
+ string "extern int uw_hash_blocksize;",
+ newline,
+ string "uw_Basis_string uw_cookie_sig(uw_context ctx) {",
+ newline,
+ box [string "uw_Basis_string r = uw_malloc(ctx, uw_hash_blocksize);",
+ newline,
+ string "uw_sign(",
+ case cookieCode of
+ NONE => string "\"\""
+ | SOME code => code,
+ string ", r);",
+ newline,
+ string "return uw_Basis_makeSigString(ctx, r);",
+ newline],
+ string "}",
+ newline,
+ newline,
+
string "void uw_handle(uw_context ctx, char *request) {",
newline,
string "if (!strcmp(request, \"/app.js\")) {",
diff --git a/src/compiler.sml b/src/compiler.sml
index 5223abe9..cf54c3cf 100644
--- a/src/compiler.sml
+++ b/src/compiler.sml
@@ -611,7 +611,7 @@ fun compileC {cname, oname, ename, libs, profile} =
val driver_o = clibFile "driver.o"
val compile = "gcc " ^ Config.gccArgs ^ " -Wstrict-prototypes -Werror -O3 -I include -c " ^ cname ^ " -o " ^ oname
- val link = "gcc -Werror -O3 -lm -pthread " ^ libs ^ " " ^ urweb_o ^ " " ^ oname ^ " " ^ driver_o ^ " -o " ^ ename
+ val link = "gcc -Werror -O3 -lm -lmhash -pthread " ^ libs ^ " " ^ urweb_o ^ " " ^ oname ^ " " ^ driver_o ^ " -o " ^ ename
val (compile, link) =
if profile then
diff --git a/src/monoize.sml b/src/monoize.sml
index 0c05cf90..a979e5ed 100644
--- a/src/monoize.sml
+++ b/src/monoize.sml
@@ -2399,7 +2399,7 @@ fun monoExp (env, st, fm) (all as (e, loc)) =
| L.EApp ((L.ECApp (
(L.ECApp ((L.EFfi ("Basis", "form"), _), _), _),
- _), _),
+ (L.CRecord (_, fields), _)), _),
xml) =>
let
fun findSubmit (e, _) =
@@ -2468,7 +2468,38 @@ fun monoExp (env, st, fm) (all as (e, loc)) =
fm)
end
+ fun inFields s = List.exists (fn ((L.CName s', _), _) => s' = s
+ | _ => true) fields
+
+ fun getSigName () =
+ let
+ fun getSigName' n =
+ let
+ val s = "Sig" ^ Int.toString n
+ in
+ if inFields s then
+ getSigName' (n + 1)
+ else
+ s
+ end
+ in
+ if inFields "Sig" then
+ getSigName' 0
+ else
+ "Sig"
+ end
+
+ val sigName = getSigName ()
+ val sigSet = (L'.EFfiApp ("Basis", "sigString", [(L'.ERecord [], loc)]), loc)
+ val sigSet = (L'.EStrcat ((L'.EPrim (Prim.String ("<input type=\"hidden\" name=\""
+ ^ sigName
+ ^ "\" value=\"")), loc),
+ sigSet), loc)
+ val sigSet = (L'.EStrcat (sigSet,
+ (L'.EPrim (Prim.String "\">"), loc)), loc)
+
val (xml, fm) = monoExp (env, st, fm) xml
+ val xml = (L'.EStrcat (sigSet, xml), loc)
in
((L'.EStrcat ((L'.EStrcat ((L'.EPrim (Prim.String "<form method=\"post\""), loc),
(L'.EStrcat (action,