aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/core/tsi/transport_security.c8
-rw-r--r--test/core/tsi/transport_security_test.c63
2 files changed, 69 insertions, 2 deletions
diff --git a/src/core/tsi/transport_security.c b/src/core/tsi/transport_security.c
index c39e584496..db219a50a6 100644
--- a/src/core/tsi/transport_security.c
+++ b/src/core/tsi/transport_security.c
@@ -145,7 +145,9 @@ void tsi_frame_protector_destroy(tsi_frame_protector *self) {
tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self,
unsigned char *bytes,
size_t *bytes_size) {
- if (self == NULL) return TSI_INVALID_ARGUMENT;
+ if (self == NULL || bytes == NULL || bytes_size == NULL) {
+ return TSI_INVALID_ARGUMENT;
+ }
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size);
}
@@ -153,7 +155,9 @@ tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self,
tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker *self,
const unsigned char *bytes,
size_t *bytes_size) {
- if (self == NULL) return TSI_INVALID_ARGUMENT;
+ if (self == NULL || bytes == NULL || bytes_size == NULL) {
+ return TSI_INVALID_ARGUMENT;
+ }
if (self->frame_protector_created) return TSI_FAILED_PRECONDITION;
return self->vtable->process_bytes_from_peer(self, bytes, bytes_size);
}
diff --git a/test/core/tsi/transport_security_test.c b/test/core/tsi/transport_security_test.c
index 858b92fc9d..7ce343987b 100644
--- a/test/core/tsi/transport_security_test.c
+++ b/test/core/tsi/transport_security_test.c
@@ -43,6 +43,7 @@
#include <openssl/crypto.h>
#include "src/core/support/string.h"
+#include "src/core/tsi/fake_transport_security.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "test/core/util/test_config.h"
@@ -296,8 +297,70 @@ static void test_peer_matches_name(void) {
}
}
+typedef struct {
+ tsi_result res;
+ const char *str;
+} tsi_result_string_pair;
+
+static void test_result_strings(void) {
+ const tsi_result_string_pair results[] = {
+ {TSI_OK, "TSI_OK"},
+ {TSI_UNKNOWN_ERROR, "TSI_UNKNOWN_ERROR"},
+ {TSI_INVALID_ARGUMENT, "TSI_INVALID_ARGUMENT"},
+ {TSI_PERMISSION_DENIED, "TSI_PERMISSION_DENIED"},
+ {TSI_INCOMPLETE_DATA, "TSI_INCOMPLETE_DATA"},
+ {TSI_FAILED_PRECONDITION, "TSI_FAILED_PRECONDITION"},
+ {TSI_UNIMPLEMENTED, "TSI_UNIMPLEMENTED"},
+ {TSI_INTERNAL_ERROR, "TSI_INTERNAL_ERROR"},
+ {TSI_DATA_CORRUPTED, "TSI_DATA_CORRUPTED"},
+ {TSI_NOT_FOUND, "TSI_NOT_FOUND"},
+ {TSI_PROTOCOL_FAILURE, "TSI_PROTOCOL_FAILURE"},
+ {TSI_HANDSHAKE_IN_PROGRESS, "TSI_HANDSHAKE_IN_PROGRESS"},
+ {TSI_OUT_OF_RESOURCES, "TSI_OUT_OF_RESOURCES"}};
+ size_t i;
+ for (i = 0; i < GPR_ARRAY_SIZE(results); i++) {
+ GPR_ASSERT(strcmp(results[i].str, tsi_result_to_string(results[i].res)) ==
+ 0);
+ }
+ GPR_ASSERT(strcmp("UNKNOWN", tsi_result_to_string((tsi_result)42)) == 0);
+}
+
+static void test_protector_invalid_args(void) {
+ GPR_ASSERT(tsi_frame_protector_protect(NULL, NULL, NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_frame_protector_protect_flush(NULL, NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_frame_protector_unprotect(NULL, NULL, NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+}
+
+static void test_handshaker_invalid_args(void) {
+ GPR_ASSERT(tsi_handshaker_get_result(NULL) == TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_handshaker_extract_peer(NULL, NULL) == TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_handshaker_create_frame_protector(NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_handshaker_process_bytes_from_peer(NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+ GPR_ASSERT(tsi_handshaker_get_bytes_to_send_to_peer(NULL, NULL, NULL) ==
+ TSI_INVALID_ARGUMENT);
+}
+
+static void test_handshaker_invalid_state(void) {
+ tsi_handshaker *h = tsi_create_fake_handshaker(0);
+ tsi_peer peer;
+ tsi_frame_protector *p;
+ GPR_ASSERT(tsi_handshaker_extract_peer(h, &peer) == TSI_FAILED_PRECONDITION);
+ GPR_ASSERT(tsi_handshaker_create_frame_protector(h, NULL, &p) ==
+ TSI_FAILED_PRECONDITION);
+ tsi_handshaker_destroy(h);
+}
+
int main(int argc, char **argv) {
grpc_test_init(argc, argv);
test_peer_matches_name();
+ test_result_strings();
+ test_protector_invalid_args();
+ test_handshaker_invalid_args();
+ test_handshaker_invalid_state();
return 0;
}