diff options
-rw-r--r-- | src/core/tsi/transport_security.c | 8 | ||||
-rw-r--r-- | test/core/tsi/transport_security_test.c | 63 |
2 files changed, 69 insertions, 2 deletions
diff --git a/src/core/tsi/transport_security.c b/src/core/tsi/transport_security.c index c39e584496..db219a50a6 100644 --- a/src/core/tsi/transport_security.c +++ b/src/core/tsi/transport_security.c @@ -145,7 +145,9 @@ void tsi_frame_protector_destroy(tsi_frame_protector *self) { tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, unsigned char *bytes, size_t *bytes_size) { - if (self == NULL) return TSI_INVALID_ARGUMENT; + if (self == NULL || bytes == NULL || bytes_size == NULL) { + return TSI_INVALID_ARGUMENT; + } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size); } @@ -153,7 +155,9 @@ tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker *self, const unsigned char *bytes, size_t *bytes_size) { - if (self == NULL) return TSI_INVALID_ARGUMENT; + if (self == NULL || bytes == NULL || bytes_size == NULL) { + return TSI_INVALID_ARGUMENT; + } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; return self->vtable->process_bytes_from_peer(self, bytes, bytes_size); } diff --git a/test/core/tsi/transport_security_test.c b/test/core/tsi/transport_security_test.c index 858b92fc9d..7ce343987b 100644 --- a/test/core/tsi/transport_security_test.c +++ b/test/core/tsi/transport_security_test.c @@ -43,6 +43,7 @@ #include <openssl/crypto.h> #include "src/core/support/string.h" +#include "src/core/tsi/fake_transport_security.h" #include "src/core/tsi/ssl_transport_security.h" #include "test/core/util/test_config.h" @@ -296,8 +297,70 @@ static void test_peer_matches_name(void) { } } +typedef struct { + tsi_result res; + const char *str; +} tsi_result_string_pair; + +static void test_result_strings(void) { + const tsi_result_string_pair results[] = { + {TSI_OK, "TSI_OK"}, + {TSI_UNKNOWN_ERROR, "TSI_UNKNOWN_ERROR"}, + {TSI_INVALID_ARGUMENT, "TSI_INVALID_ARGUMENT"}, + {TSI_PERMISSION_DENIED, "TSI_PERMISSION_DENIED"}, + {TSI_INCOMPLETE_DATA, "TSI_INCOMPLETE_DATA"}, + {TSI_FAILED_PRECONDITION, "TSI_FAILED_PRECONDITION"}, + {TSI_UNIMPLEMENTED, "TSI_UNIMPLEMENTED"}, + {TSI_INTERNAL_ERROR, "TSI_INTERNAL_ERROR"}, + {TSI_DATA_CORRUPTED, "TSI_DATA_CORRUPTED"}, + {TSI_NOT_FOUND, "TSI_NOT_FOUND"}, + {TSI_PROTOCOL_FAILURE, "TSI_PROTOCOL_FAILURE"}, + {TSI_HANDSHAKE_IN_PROGRESS, "TSI_HANDSHAKE_IN_PROGRESS"}, + {TSI_OUT_OF_RESOURCES, "TSI_OUT_OF_RESOURCES"}}; + size_t i; + for (i = 0; i < GPR_ARRAY_SIZE(results); i++) { + GPR_ASSERT(strcmp(results[i].str, tsi_result_to_string(results[i].res)) == + 0); + } + GPR_ASSERT(strcmp("UNKNOWN", tsi_result_to_string((tsi_result)42)) == 0); +} + +static void test_protector_invalid_args(void) { + GPR_ASSERT(tsi_frame_protector_protect(NULL, NULL, NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_frame_protector_protect_flush(NULL, NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_frame_protector_unprotect(NULL, NULL, NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); +} + +static void test_handshaker_invalid_args(void) { + GPR_ASSERT(tsi_handshaker_get_result(NULL) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_extract_peer(NULL, NULL) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_create_frame_protector(NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_process_bytes_from_peer(NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(tsi_handshaker_get_bytes_to_send_to_peer(NULL, NULL, NULL) == + TSI_INVALID_ARGUMENT); +} + +static void test_handshaker_invalid_state(void) { + tsi_handshaker *h = tsi_create_fake_handshaker(0); + tsi_peer peer; + tsi_frame_protector *p; + GPR_ASSERT(tsi_handshaker_extract_peer(h, &peer) == TSI_FAILED_PRECONDITION); + GPR_ASSERT(tsi_handshaker_create_frame_protector(h, NULL, &p) == + TSI_FAILED_PRECONDITION); + tsi_handshaker_destroy(h); +} + int main(int argc, char **argv) { grpc_test_init(argc, argv); test_peer_matches_name(); + test_result_strings(); + test_protector_invalid_args(); + test_handshaker_invalid_args(); + test_handshaker_invalid_state(); return 0; } |