aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/core/tsi/ssl_transport_security.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/tsi/ssl_transport_security.cc')
-rw-r--r--src/core/tsi/ssl_transport_security.cc348
1 files changed, 238 insertions, 110 deletions
diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc
index 0ba6587678..8065a8b185 100644
--- a/src/core/tsi/ssl_transport_security.cc
+++ b/src/core/tsi/ssl_transport_security.cc
@@ -57,6 +57,7 @@ extern "C" {
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
+#define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024
/* Putting a macro like this and littering the source file with #if is really
bad practice.
@@ -105,10 +106,20 @@ typedef struct {
SSL* ssl;
BIO* network_io;
tsi_result result;
+ unsigned char* outgoing_bytes_buffer;
+ size_t outgoing_bytes_buffer_size;
tsi_ssl_handshaker_factory* factory_ref;
} tsi_ssl_handshaker;
typedef struct {
+ tsi_handshaker_result base;
+ SSL* ssl;
+ BIO* network_io;
+ unsigned char* unused_bytes;
+ size_t unused_bytes_size;
+} tsi_ssl_handshaker_result;
+
+typedef struct {
tsi_frame_protector base;
SSL* ssl;
BIO* network_io;
@@ -120,12 +131,14 @@ typedef struct {
/* --- Library Initialization. ---*/
static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
-static gpr_mu* g_openssl_mutexes = nullptr;
static int g_ssl_ctx_ex_factory_index = -1;
+static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
+
+#if OPENSSL_VERSION_NUMBER < 0x10100000
+static gpr_mu* g_openssl_mutexes = nullptr;
static void openssl_locking_cb(int mode, int type, const char* file,
int line) GRPC_UNUSED;
static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
-static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
static void openssl_locking_cb(int mode, int type, const char* file, int line) {
if (mode & CRYPTO_LOCK) {
@@ -138,22 +151,27 @@ static void openssl_locking_cb(int mode, int type, const char* file, int line) {
static unsigned long openssl_thread_id_cb(void) {
return static_cast<unsigned long>(gpr_thd_currentid());
}
+#endif
static void init_openssl(void) {
- int i;
- int num_locks;
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_algorithms();
- num_locks = CRYPTO_num_locks();
- GPR_ASSERT(num_locks > 0);
- g_openssl_mutexes = static_cast<gpr_mu*>(
- gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
- for (i = 0; i < CRYPTO_num_locks(); i++) {
- gpr_mu_init(&g_openssl_mutexes[i]);
- }
- CRYPTO_set_locking_callback(openssl_locking_cb);
- CRYPTO_set_id_callback(openssl_thread_id_cb);
+#if OPENSSL_VERSION_NUMBER < 0x10100000
+ if (!CRYPTO_get_locking_callback()) {
+ int num_locks = CRYPTO_num_locks();
+ GPR_ASSERT(num_locks > 0);
+ g_openssl_mutexes = static_cast<gpr_mu*>(
+ gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
+ for (int i = 0; i < num_locks; i++) {
+ gpr_mu_init(&g_openssl_mutexes[i]);
+ }
+ CRYPTO_set_locking_callback(openssl_locking_cb);
+ CRYPTO_set_id_callback(openssl_thread_id_cb);
+ } else {
+ gpr_log(GPR_INFO, "OpenSSL callback has already been set.");
+ }
+#endif
g_ssl_ctx_ex_factory_index =
SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1);
@@ -987,94 +1005,15 @@ static void tsi_ssl_handshaker_factory_init(
gpr_ref_init(&factory->refcount, 1);
}
-/* --- tsi_handshaker methods implementation. ---*/
-
-static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker* self,
- unsigned char* bytes,
- size_t* bytes_size) {
- tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
- int bytes_read_from_ssl = 0;
- if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 ||
- *bytes_size > INT_MAX) {
- return TSI_INVALID_ARGUMENT;
- }
- GPR_ASSERT(*bytes_size <= INT_MAX);
- bytes_read_from_ssl =
- BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
- if (bytes_read_from_ssl < 0) {
- *bytes_size = 0;
- if (!BIO_should_retry(impl->network_io)) {
- impl->result = TSI_INTERNAL_ERROR;
- return impl->result;
- } else {
- return TSI_OK;
- }
- }
- *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
- return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
-}
-
-static tsi_result ssl_handshaker_get_result(tsi_handshaker* self) {
- tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
- if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
- SSL_is_init_finished(impl->ssl)) {
- impl->result = TSI_OK;
- }
- return impl->result;
-}
+/* --- tsi_handshaker_result methods implementation. ---*/
-static tsi_result ssl_handshaker_process_bytes_from_peer(
- tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size) {
- tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
- int bytes_written_into_ssl_size = 0;
- if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
- return TSI_INVALID_ARGUMENT;
- }
- GPR_ASSERT(*bytes_size <= INT_MAX);
- bytes_written_into_ssl_size =
- BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
- if (bytes_written_into_ssl_size < 0) {
- gpr_log(GPR_ERROR, "Could not write to memory BIO.");
- impl->result = TSI_INTERNAL_ERROR;
- return impl->result;
- }
- *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
-
- if (!tsi_handshaker_is_in_progress(self)) {
- impl->result = TSI_OK;
- return impl->result;
- } else {
- /* Get ready to get some bytes from SSL. */
- int ssl_result = SSL_do_handshake(impl->ssl);
- ssl_result = SSL_get_error(impl->ssl, ssl_result);
- switch (ssl_result) {
- case SSL_ERROR_WANT_READ:
- if (BIO_pending(impl->network_io) == 0) {
- /* We need more data. */
- return TSI_INCOMPLETE_DATA;
- } else {
- return TSI_OK;
- }
- case SSL_ERROR_NONE:
- return TSI_OK;
- default: {
- char err_str[256];
- ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
- gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
- ssl_error_string(ssl_result), err_str);
- impl->result = TSI_PROTOCOL_FAILURE;
- return impl->result;
- }
- }
- }
-}
-
-static tsi_result ssl_handshaker_extract_peer(tsi_handshaker* self,
- tsi_peer* peer) {
+static tsi_result ssl_handshaker_result_extract_peer(
+ const tsi_handshaker_result* self, tsi_peer* peer) {
tsi_result result = TSI_OK;
const unsigned char* alpn_selected = nullptr;
unsigned int alpn_selected_len;
- tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
+ const tsi_ssl_handshaker_result* impl =
+ reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
if (peer_cert != nullptr) {
result = peer_from_x509(peer_cert, 1, peer);
@@ -1120,12 +1059,14 @@ static tsi_result ssl_handshaker_extract_peer(tsi_handshaker* self,
return result;
}
-static tsi_result ssl_handshaker_create_frame_protector(
- tsi_handshaker* self, size_t* max_output_protected_frame_size,
+static tsi_result ssl_handshaker_result_create_frame_protector(
+ const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
tsi_frame_protector** protector) {
size_t actual_max_output_protected_frame_size =
TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
- tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
+ tsi_ssl_handshaker_result* impl =
+ reinterpret_cast<tsi_ssl_handshaker_result*>(
+ const_cast<tsi_handshaker_result*>(self));
tsi_ssl_frame_protector* protector_impl =
static_cast<tsi_ssl_frame_protector*>(
gpr_zalloc(sizeof(*protector_impl)));
@@ -1153,35 +1094,218 @@ static tsi_result ssl_handshaker_create_frame_protector(
return TSI_INTERNAL_ERROR;
}
- /* Transfer ownership of ssl and network_io to the frame protector. It is OK
- * as the caller cannot call anything else but destroy on the handshaker
- * after this call. */
+ /* Transfer ownership of ssl and network_io to the frame protector. */
protector_impl->ssl = impl->ssl;
impl->ssl = nullptr;
protector_impl->network_io = impl->network_io;
impl->network_io = nullptr;
-
protector_impl->base.vtable = &frame_protector_vtable;
*protector = &protector_impl->base;
return TSI_OK;
}
+static tsi_result ssl_handshaker_result_get_unused_bytes(
+ const tsi_handshaker_result* self, const unsigned char** bytes,
+ size_t* bytes_size) {
+ const tsi_ssl_handshaker_result* impl =
+ reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
+ *bytes_size = impl->unused_bytes_size;
+ *bytes = impl->unused_bytes;
+ return TSI_OK;
+}
+
+static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
+ tsi_ssl_handshaker_result* impl =
+ reinterpret_cast<tsi_ssl_handshaker_result*>(self);
+ SSL_free(impl->ssl);
+ BIO_free(impl->network_io);
+ gpr_free(impl->unused_bytes);
+ gpr_free(impl);
+}
+
+static const tsi_handshaker_result_vtable handshaker_result_vtable = {
+ ssl_handshaker_result_extract_peer,
+ nullptr, /* create_zero_copy_grpc_protector */
+ ssl_handshaker_result_create_frame_protector,
+ ssl_handshaker_result_get_unused_bytes,
+ ssl_handshaker_result_destroy,
+};
+
+static tsi_result ssl_handshaker_result_create(
+ tsi_ssl_handshaker* handshaker, const unsigned char* unused_bytes,
+ size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) {
+ if (handshaker == nullptr || handshaker_result == nullptr ||
+ (unused_bytes_size > 0 && unused_bytes == nullptr)) {
+ return TSI_INVALID_ARGUMENT;
+ }
+ tsi_ssl_handshaker_result* result =
+ static_cast<tsi_ssl_handshaker_result*>(gpr_zalloc(sizeof(*result)));
+ result->base.vtable = &handshaker_result_vtable;
+ /* Transfer ownership of ssl and network_io to the handshaker result. */
+ result->ssl = handshaker->ssl;
+ handshaker->ssl = nullptr;
+ result->network_io = handshaker->network_io;
+ handshaker->network_io = nullptr;
+ if (unused_bytes_size > 0) {
+ result->unused_bytes =
+ static_cast<unsigned char*>(gpr_malloc(unused_bytes_size));
+ memcpy(result->unused_bytes, unused_bytes, unused_bytes_size);
+ }
+ result->unused_bytes_size = unused_bytes_size;
+ *handshaker_result = &result->base;
+ return TSI_OK;
+}
+
+/* --- tsi_handshaker methods implementation. ---*/
+
+static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
+ tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) {
+ int bytes_read_from_ssl = 0;
+ if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 ||
+ *bytes_size > INT_MAX) {
+ return TSI_INVALID_ARGUMENT;
+ }
+ GPR_ASSERT(*bytes_size <= INT_MAX);
+ bytes_read_from_ssl =
+ BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
+ if (bytes_read_from_ssl < 0) {
+ *bytes_size = 0;
+ if (!BIO_should_retry(impl->network_io)) {
+ impl->result = TSI_INTERNAL_ERROR;
+ return impl->result;
+ } else {
+ return TSI_OK;
+ }
+ }
+ *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
+ return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
+}
+
+static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
+ if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
+ SSL_is_init_finished(impl->ssl)) {
+ impl->result = TSI_OK;
+ }
+ return impl->result;
+}
+
+static tsi_result ssl_handshaker_process_bytes_from_peer(
+ tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) {
+ int bytes_written_into_ssl_size = 0;
+ if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
+ return TSI_INVALID_ARGUMENT;
+ }
+ GPR_ASSERT(*bytes_size <= INT_MAX);
+ bytes_written_into_ssl_size =
+ BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
+ if (bytes_written_into_ssl_size < 0) {
+ gpr_log(GPR_ERROR, "Could not write to memory BIO.");
+ impl->result = TSI_INTERNAL_ERROR;
+ return impl->result;
+ }
+ *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
+
+ if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
+ impl->result = TSI_OK;
+ return impl->result;
+ } else {
+ /* Get ready to get some bytes from SSL. */
+ int ssl_result = SSL_do_handshake(impl->ssl);
+ ssl_result = SSL_get_error(impl->ssl, ssl_result);
+ switch (ssl_result) {
+ case SSL_ERROR_WANT_READ:
+ if (BIO_pending(impl->network_io) == 0) {
+ /* We need more data. */
+ return TSI_INCOMPLETE_DATA;
+ } else {
+ return TSI_OK;
+ }
+ case SSL_ERROR_NONE:
+ return TSI_OK;
+ default: {
+ char err_str[256];
+ ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
+ gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
+ ssl_error_string(ssl_result), err_str);
+ impl->result = TSI_PROTOCOL_FAILURE;
+ return impl->result;
+ }
+ }
+ }
+}
+
static void ssl_handshaker_destroy(tsi_handshaker* self) {
tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
SSL_free(impl->ssl);
BIO_free(impl->network_io);
+ gpr_free(impl->outgoing_bytes_buffer);
tsi_ssl_handshaker_factory_unref(impl->factory_ref);
gpr_free(impl);
}
+static tsi_result ssl_handshaker_next(
+ tsi_handshaker* self, const unsigned char* received_bytes,
+ size_t received_bytes_size, const unsigned char** bytes_to_send,
+ size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
+ tsi_handshaker_on_next_done_cb cb, void* user_data) {
+ /* Input sanity check. */
+ if ((received_bytes_size > 0 && received_bytes == nullptr) ||
+ bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
+ handshaker_result == nullptr) {
+ return TSI_INVALID_ARGUMENT;
+ }
+ /* If there are received bytes, process them first. */
+ tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
+ tsi_result status = TSI_OK;
+ size_t bytes_consumed = received_bytes_size;
+ if (received_bytes_size > 0) {
+ status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes,
+ &bytes_consumed);
+ if (status != TSI_OK) return status;
+ }
+ /* Get bytes to send to the peer, if available. */
+ size_t offset = 0;
+ do {
+ size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
+ status = ssl_handshaker_get_bytes_to_send_to_peer(
+ impl, impl->outgoing_bytes_buffer + offset, &to_send_size);
+ offset += to_send_size;
+ if (status == TSI_INCOMPLETE_DATA) {
+ impl->outgoing_bytes_buffer_size *= 2;
+ impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
+ impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
+ }
+ } while (status == TSI_INCOMPLETE_DATA);
+ if (status != TSI_OK) return status;
+ *bytes_to_send = impl->outgoing_bytes_buffer;
+ *bytes_to_send_size = offset;
+ /* If handshake completes, create tsi_handshaker_result. */
+ if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
+ *handshaker_result = nullptr;
+ } else {
+ size_t unused_bytes_size = received_bytes_size - bytes_consumed;
+ const unsigned char* unused_bytes =
+ unused_bytes_size == 0 ? nullptr : received_bytes + bytes_consumed;
+ status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
+ handshaker_result);
+ if (status == TSI_OK) {
+ /* Indicates that the handshake has completed and that a handshaker_result
+ * has been created. */
+ self->handshaker_result_created = true;
+ }
+ }
+ return status;
+}
+
static const tsi_handshaker_vtable handshaker_vtable = {
- ssl_handshaker_get_bytes_to_send_to_peer,
- ssl_handshaker_process_bytes_from_peer,
- ssl_handshaker_get_result,
- ssl_handshaker_extract_peer,
- ssl_handshaker_create_frame_protector,
+ nullptr, /* get_bytes_to_send_to_peer -- deprecated */
+ nullptr, /* process_bytes_from_peer -- deprecated */
+ nullptr, /* get_result -- deprecated */
+ nullptr, /* extract_peer -- deprecated */
+ nullptr, /* create_frame_protector -- deprecated */
ssl_handshaker_destroy,
- nullptr,
+ ssl_handshaker_next,
+ nullptr, /* shutdown */
};
/* --- tsi_ssl_handshaker_factory common methods. --- */
@@ -1259,6 +1383,10 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
impl->ssl = ssl;
impl->network_io = network_io;
impl->result = TSI_HANDSHAKE_IN_PROGRESS;
+ impl->outgoing_bytes_buffer_size =
+ TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
+ impl->outgoing_bytes_buffer =
+ static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
impl->base.vtable = &handshaker_vtable;
impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);