diff options
Diffstat (limited to 'src/core/tsi/ssl_transport_security.c')
-rw-r--r-- | src/core/tsi/ssl_transport_security.c | 113 |
1 files changed, 105 insertions, 8 deletions
diff --git a/src/core/tsi/ssl_transport_security.c b/src/core/tsi/ssl_transport_security.c index 1fd65928f9..7ebf9dd96f 100644 --- a/src/core/tsi/ssl_transport_security.c +++ b/src/core/tsi/ssl_transport_security.c @@ -67,7 +67,13 @@ /* --- Structure definitions. ---*/ +struct tsi_ssl_handshaker_factory { + const tsi_ssl_handshaker_factory_vtable *vtable; + gpr_refcount refcount; +}; + struct tsi_ssl_client_handshaker_factory { + tsi_ssl_handshaker_factory base; SSL_CTX *ssl_context; unsigned char *alpn_protocol_list; size_t alpn_protocol_list_length; @@ -77,6 +83,7 @@ struct tsi_ssl_server_handshaker_factory { /* Several contexts to support SNI. The tsi_peer array contains the subject names of the server certificates associated with the contexts at the same index. */ + tsi_ssl_handshaker_factory base; SSL_CTX **ssl_contexts; tsi_peer *ssl_context_x509_subject_names; size_t ssl_context_count; @@ -90,6 +97,7 @@ typedef struct { BIO *into_ssl; BIO *from_ssl; tsi_result result; + tsi_ssl_handshaker_factory *factory_ref; } tsi_ssl_handshaker; typedef struct { @@ -846,6 +854,47 @@ static const tsi_frame_protector_vtable frame_protector_vtable = { ssl_protector_destroy, }; +/* --- tsi_server_handshaker_factory methods implementation. --- */ + +static void tsi_ssl_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *self) { + if (self == NULL) return; + + if (self->vtable != NULL && self->vtable->destroy != NULL) { + self->vtable->destroy(self); + } + /* Note, we don't free(self) here because this object is always directly + * embedded in another object. If tsi_ssl_handshaker_factory_init allocates + * any memory, it should be free'd here. */ +} + +static tsi_ssl_handshaker_factory *tsi_ssl_handshaker_factory_ref( + tsi_ssl_handshaker_factory *self) { + if (self == NULL) return NULL; + gpr_refn(&self->refcount, 1); + return self; +} + +static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory *self) { + if (self == NULL) return; + + if (gpr_unref(&self->refcount)) { + tsi_ssl_handshaker_factory_destroy(self); + } +} + +static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {NULL}; + +/* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for + * allocating memory for the factory. */ +static void tsi_ssl_handshaker_factory_init( + tsi_ssl_handshaker_factory *factory) { + GPR_ASSERT(factory != NULL); + + factory->vtable = &handshaker_factory_vtable; + gpr_ref_init(&factory->refcount, 1); +} + /* --- tsi_handshaker methods implementation. ---*/ static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, @@ -1013,6 +1062,7 @@ static tsi_result ssl_handshaker_create_frame_protector( static void ssl_handshaker_destroy(tsi_handshaker *self) { tsi_ssl_handshaker *impl = (tsi_ssl_handshaker *)self; SSL_free(impl->ssl); /* The BIO objects are owned by ssl */ + tsi_ssl_handshaker_factory_unref(impl->factory_ref); gpr_free(impl); } @@ -1030,6 +1080,7 @@ static const tsi_handshaker_vtable handshaker_vtable = { static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client, const char *server_name_indication, + tsi_ssl_handshaker_factory *factory, tsi_handshaker **handshaker) { SSL *ssl = SSL_new(ctx); BIO *into_ssl = NULL; @@ -1085,6 +1136,8 @@ static tsi_result create_tsi_ssl_handshaker(SSL_CTX *ctx, int is_client, impl->from_ssl = from_ssl; impl->result = TSI_HANDSHAKE_IN_PROGRESS; impl->base.vtable = &handshaker_vtable; + impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory); + *handshaker = &impl->base; return TSI_OK; } @@ -1121,11 +1174,20 @@ tsi_result tsi_ssl_client_handshaker_factory_create_handshaker( tsi_ssl_client_handshaker_factory *self, const char *server_name_indication, tsi_handshaker **handshaker) { return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication, - handshaker); + &self->base, handshaker); } -void tsi_ssl_client_handshaker_factory_destroy( +void tsi_ssl_client_handshaker_factory_unref( tsi_ssl_client_handshaker_factory *self) { + if (self == NULL) return; + tsi_ssl_handshaker_factory_unref(&self->base); +} + +static void tsi_ssl_client_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *factory) { + if (factory == NULL) return; + tsi_ssl_client_handshaker_factory *self = + (tsi_ssl_client_handshaker_factory *)factory; if (self->ssl_context != NULL) SSL_CTX_free(self->ssl_context); if (self->alpn_protocol_list != NULL) gpr_free(self->alpn_protocol_list); gpr_free(self); @@ -1150,11 +1212,21 @@ tsi_result tsi_ssl_server_handshaker_factory_create_handshaker( if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT; /* Create the handshaker with the first context. We will switch if needed because of SNI in ssl_server_handshaker_factory_servername_callback. */ - return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, handshaker); + return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, NULL, &self->base, + handshaker); } -void tsi_ssl_server_handshaker_factory_destroy( +void tsi_ssl_server_handshaker_factory_unref( tsi_ssl_server_handshaker_factory *self) { + if (self == NULL) return; + tsi_ssl_handshaker_factory_unref(&self->base); +} + +static void tsi_ssl_server_handshaker_factory_destroy( + tsi_ssl_handshaker_factory *factory) { + if (factory == NULL) return; + tsi_ssl_server_handshaker_factory *self = + (tsi_ssl_server_handshaker_factory *)factory; size_t i; for (i = 0; i < self->ssl_context_count; i++) { if (self->ssl_contexts[i] != NULL) { @@ -1263,6 +1335,9 @@ static int server_handshaker_factory_npn_advertised_callback( /* --- tsi_ssl_handshaker_factory constructors. --- */ +static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = { + tsi_ssl_client_handshaker_factory_destroy}; + tsi_result tsi_create_ssl_client_handshaker_factory( const tsi_ssl_pem_key_cert_pair *pem_key_cert_pair, const char *pem_root_certs, const char *cipher_suites, @@ -1285,6 +1360,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory( } impl = gpr_zalloc(sizeof(*impl)); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &client_handshaker_factory_vtable; + impl->ssl_context = ssl_context; do { @@ -1322,7 +1400,7 @@ tsi_result tsi_create_ssl_client_handshaker_factory( } } while (0); if (result != TSI_OK) { - tsi_ssl_client_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NULL); @@ -1332,6 +1410,9 @@ tsi_result tsi_create_ssl_client_handshaker_factory( return TSI_OK; } +static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = { + tsi_ssl_server_handshaker_factory_destroy}; + tsi_result tsi_create_ssl_server_handshaker_factory( const tsi_ssl_pem_key_cert_pair *pem_key_cert_pairs, size_t num_key_cert_pairs, const char *pem_client_root_certs, @@ -1364,12 +1445,15 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( } impl = gpr_zalloc(sizeof(*impl)); + tsi_ssl_handshaker_factory_init(&impl->base); + impl->base.vtable = &server_handshaker_factory_vtable; + impl->ssl_contexts = gpr_zalloc(num_key_cert_pairs * sizeof(SSL_CTX *)); impl->ssl_context_x509_subject_names = gpr_zalloc(num_key_cert_pairs * sizeof(tsi_peer)); if (impl->ssl_contexts == NULL || impl->ssl_context_x509_subject_names == NULL) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return TSI_OUT_OF_RESOURCES; } impl->ssl_context_count = num_key_cert_pairs; @@ -1379,7 +1463,7 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( &impl->alpn_protocol_list, &impl->alpn_protocol_list_length); if (result != TSI_OK) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } } @@ -1451,10 +1535,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_ex( } while (0); if (result != TSI_OK) { - tsi_ssl_server_handshaker_factory_destroy(impl); + tsi_ssl_handshaker_factory_unref(&impl->base); return result; } } + *factory = impl; return TSI_OK; } @@ -1501,3 +1586,15 @@ int tsi_ssl_peer_matches_name(const tsi_peer *peer, const char *name) { return 0; /* Not found. */ } + +/* --- Testing support. --- */ +const tsi_ssl_handshaker_factory_vtable *tsi_ssl_handshaker_factory_swap_vtable( + tsi_ssl_handshaker_factory *factory, + tsi_ssl_handshaker_factory_vtable *new_vtable) { + GPR_ASSERT(factory != NULL); + GPR_ASSERT(factory->vtable != NULL); + + const tsi_ssl_handshaker_factory_vtable *orig_vtable = factory->vtable; + factory->vtable = new_vtable; + return orig_vtable; +} |