diff options
Diffstat (limited to 'src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc')
-rw-r--r-- | src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc | 304 |
1 files changed, 126 insertions, 178 deletions
diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index 1df1021bb1..1b7e58d3ce 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -26,33 +26,34 @@ #include <grpc/support/alloc.h> #include <grpc/support/log.h> +#include <grpc/support/string_util.h> #include <grpc/support/sync.h> #include <grpc/support/thd_id.h> #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/slice/slice_internal.h" #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h" #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" +#include "src/core/tsi/alts/handshaker/alts_shared_resource.h" #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" -#include "src/core/tsi/alts_transport_security.h" - -#define TSI_ALTS_INITIAL_BUFFER_SIZE 256 - -static alts_shared_resource* kSharedResource = alts_get_shared_resource(); /* Main struct for ALTS TSI handshaker. */ -typedef struct alts_tsi_handshaker { +struct alts_tsi_handshaker { tsi_handshaker base; alts_handshaker_client* client; - grpc_slice recv_bytes; grpc_slice target_name; - unsigned char* buffer; - size_t buffer_size; bool is_client; bool has_sent_start_message; + bool has_created_handshaker_client; + char* handshaker_service_url; + grpc_pollset_set* interested_parties; grpc_alts_credentials_options* options; -} alts_tsi_handshaker; + alts_handshaker_client_vtable* client_vtable_for_testing; + grpc_channel* channel; +}; /* Main struct for ALTS TSI handshaker result. */ typedef struct alts_tsi_handshaker_result { @@ -182,7 +183,7 @@ static void handshaker_result_destroy(tsi_handshaker_result* self) { gpr_free(result->peer_identity); gpr_free(result->key_data); gpr_free(result->unused_bytes); - grpc_slice_unref(result->rpc_versions); + grpc_slice_unref_internal(result->rpc_versions); gpr_free(result); } @@ -192,9 +193,9 @@ static const tsi_handshaker_result_vtable result_vtable = { handshaker_result_create_frame_protector, handshaker_result_get_unused_bytes, handshaker_result_destroy}; -static tsi_result create_handshaker_result(grpc_gcp_handshaker_resp* resp, - bool is_client, - tsi_handshaker_result** self) { +tsi_result alts_tsi_handshaker_result_create(grpc_gcp_handshaker_resp* resp, + bool is_client, + tsi_handshaker_result** self) { if (self == nullptr || resp == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()"); return TSI_INVALID_ARGUMENT; @@ -233,6 +234,27 @@ static tsi_result create_handshaker_result(grpc_gcp_handshaker_resp* resp, return TSI_OK; } +/* gRPC provided callback used when gRPC thread model is applied. */ +static void on_handshaker_service_resp_recv(void* arg, grpc_error* error) { + alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg); + if (client == nullptr) { + gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr"); + return; + } + alts_handshaker_client_handle_response(client, true); +} + +/* gRPC provided callback used when dedicatd CQ and thread are used. + * It serves to safely bring the control back to application. */ +static void on_handshaker_service_resp_recv_dedicated(void* arg, + grpc_error* error) { + alts_shared_resource_dedicated* resource = + grpc_alts_get_shared_resource_dedicated(); + grpc_cq_end_op(resource->cq, arg, GRPC_ERROR_NONE, + [](void* done_arg, grpc_cq_completion* storage) {}, nullptr, + &resource->storage); +} + static tsi_result handshaker_next( tsi_handshaker* self, const unsigned char* received_bytes, size_t received_bytes_size, const unsigned char** bytes_to_send, @@ -249,12 +271,36 @@ static tsi_result handshaker_next( alts_tsi_handshaker* handshaker = reinterpret_cast<alts_tsi_handshaker*>(self); tsi_result ok = TSI_OK; - alts_tsi_event* event = nullptr; - ok = alts_tsi_event_create(handshaker, cb, user_data, handshaker->options, - handshaker->target_name, &event); - if (ok != TSI_OK) { - gpr_log(GPR_ERROR, "Failed to create ALTS TSI event"); - return ok; + if (!handshaker->has_created_handshaker_client) { + if (handshaker->channel == nullptr) { + grpc_alts_shared_resource_dedicated_start( + handshaker->handshaker_service_url); + handshaker->interested_parties = + grpc_alts_get_shared_resource_dedicated()->interested_parties; + GPR_ASSERT(handshaker->interested_parties != nullptr); + } + grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr + ? on_handshaker_service_resp_recv_dedicated + : on_handshaker_service_resp_recv; + grpc_channel* channel = + handshaker->channel == nullptr + ? grpc_alts_get_shared_resource_dedicated()->channel + : handshaker->channel; + handshaker->client = alts_grpc_handshaker_client_create( + handshaker, channel, handshaker->handshaker_service_url, + handshaker->interested_parties, handshaker->options, + handshaker->target_name, grpc_cb, cb, user_data, + handshaker->client_vtable_for_testing, handshaker->is_client); + if (handshaker->client == nullptr) { + gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); + return TSI_FAILED_PRECONDITION; + } + handshaker->has_created_handshaker_client = true; + } + if (handshaker->channel == nullptr && + handshaker->client_vtable_for_testing == nullptr) { + GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq, + handshaker->client)); } grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0) ? grpc_empty_slice() @@ -263,18 +309,13 @@ static tsi_result handshaker_next( received_bytes_size); if (!handshaker->has_sent_start_message) { ok = handshaker->is_client - ? alts_handshaker_client_start_client(handshaker->client, event) - : alts_handshaker_client_start_server(handshaker->client, event, - &slice); + ? alts_handshaker_client_start_client(handshaker->client) + : alts_handshaker_client_start_server(handshaker->client, &slice); handshaker->has_sent_start_message = true; } else { - if (!GRPC_SLICE_IS_EMPTY(handshaker->recv_bytes)) { - grpc_slice_unref(handshaker->recv_bytes); - } - handshaker->recv_bytes = grpc_slice_ref(slice); - ok = alts_handshaker_client_next(handshaker->client, event, &slice); + ok = alts_handshaker_client_next(handshaker->client, &slice); } - grpc_slice_unref(slice); + grpc_slice_unref_internal(slice); if (ok != TSI_OK) { gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests"); return ok; @@ -282,6 +323,22 @@ static tsi_result handshaker_next( return TSI_ASYNC; } +/* + * This API will be invoked by a non-gRPC application, and an ExecCtx needs + * to be explicitly created in order to invoke ALTS handshaker client API's + * that assumes the caller is inside gRPC core. + */ +static tsi_result handshaker_next_dedicated( + tsi_handshaker* self, const unsigned char* received_bytes, + size_t received_bytes_size, const unsigned char** bytes_to_send, + size_t* bytes_to_send_size, tsi_handshaker_result** result, + tsi_handshaker_on_next_done_cb cb, void* user_data) { + grpc_core::ExecCtx exec_ctx; + return handshaker_next(self, received_bytes, received_bytes_size, + bytes_to_send, bytes_to_send_size, result, cb, + user_data); +} + static void handshaker_shutdown(tsi_handshaker* self) { GPR_ASSERT(self != nullptr); if (self->handshake_shutdown) { @@ -299,10 +356,12 @@ static void handshaker_destroy(tsi_handshaker* self) { alts_tsi_handshaker* handshaker = reinterpret_cast<alts_tsi_handshaker*>(self); alts_handshaker_client_destroy(handshaker->client); - grpc_slice_unref(handshaker->recv_bytes); - grpc_slice_unref(handshaker->target_name); + grpc_slice_unref_internal(handshaker->target_name); grpc_alts_credentials_options_destroy(handshaker->options); - gpr_free(handshaker->buffer); + if (handshaker->channel != nullptr) { + grpc_channel_destroy(handshaker->channel); + } + gpr_free(handshaker->handshaker_service_url); gpr_free(handshaker); } @@ -312,80 +371,57 @@ static const tsi_handshaker_vtable handshaker_vtable = { nullptr, handshaker_destroy, handshaker_next, handshaker_shutdown}; -static void thread_worker(void* arg) { - while (true) { - grpc_event event = grpc_completion_queue_next( - kSharedResource->cq, gpr_inf_future(GPR_CLOCK_REALTIME), nullptr); - GPR_ASSERT(event.type != GRPC_QUEUE_TIMEOUT); - if (event.type == GRPC_QUEUE_SHUTDOWN) { - /* signal alts_tsi_shutdown() to destroy completion queue. */ - grpc_tsi_alts_signal_for_cq_destroy(); - break; - } - /* event.type == GRPC_OP_COMPLETE. */ - alts_tsi_event* alts_event = static_cast<alts_tsi_event*>(event.tag); - alts_tsi_event_dispatch_to_handshaker(alts_event, event.success); - alts_tsi_event_destroy(alts_event); - } -} - -static void init_shared_resources(const char* handshaker_service_url) { - GPR_ASSERT(handshaker_service_url != nullptr); - gpr_mu_lock(&kSharedResource->mu); - if (kSharedResource->channel == nullptr) { - gpr_cv_init(&kSharedResource->cv); - kSharedResource->channel = - grpc_insecure_channel_create(handshaker_service_url, nullptr, nullptr); - kSharedResource->cq = grpc_completion_queue_create_for_next(nullptr); - kSharedResource->thread = - grpc_core::Thread("alts_tsi_handshaker", &thread_worker, nullptr); - kSharedResource->thread.Start(); - } - gpr_mu_unlock(&kSharedResource->mu); +static const tsi_handshaker_vtable handshaker_vtable_dedicated = { + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + handshaker_destroy, + handshaker_next_dedicated, + handshaker_shutdown}; + +bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) { + GPR_ASSERT(handshaker != nullptr); + return handshaker->base.handshake_shutdown; } tsi_result alts_tsi_handshaker_create( const grpc_alts_credentials_options* options, const char* target_name, - const char* handshaker_service_url, bool is_client, tsi_handshaker** self) { + const char* handshaker_service_url, bool is_client, + grpc_pollset_set* interested_parties, tsi_handshaker** self) { if (handshaker_service_url == nullptr || self == nullptr || options == nullptr || (is_client && target_name == nullptr)) { gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()"); return TSI_INVALID_ARGUMENT; } - init_shared_resources(handshaker_service_url); - alts_handshaker_client* client = alts_grpc_handshaker_client_create( - kSharedResource->channel, kSharedResource->cq, handshaker_service_url); - if (client == nullptr) { - gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); - return TSI_FAILED_PRECONDITION; - } alts_tsi_handshaker* handshaker = static_cast<alts_tsi_handshaker*>(gpr_zalloc(sizeof(*handshaker))); - handshaker->client = client; - handshaker->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE; - handshaker->buffer = - static_cast<unsigned char*>(gpr_zalloc(handshaker->buffer_size)); + bool use_dedicated_cq = interested_parties == nullptr; + handshaker->client = nullptr; handshaker->is_client = is_client; handshaker->has_sent_start_message = false; handshaker->target_name = target_name == nullptr ? grpc_empty_slice() : grpc_slice_from_static_string(target_name); + handshaker->interested_parties = interested_parties; + handshaker->has_created_handshaker_client = false; + handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); handshaker->options = grpc_alts_credentials_options_copy(options); - handshaker->base.vtable = &handshaker_vtable; + handshaker->base.vtable = + use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable; + handshaker->channel = + use_dedicated_cq + ? nullptr + : grpc_insecure_channel_create(handshaker->handshaker_service_url, + nullptr, nullptr); *self = &handshaker->base; return TSI_OK; } -static bool is_handshake_finished_properly(grpc_gcp_handshaker_resp* resp) { - GPR_ASSERT(resp != nullptr); - if (resp->has_result) { - return true; - } - return false; -} - -static void set_unused_bytes(tsi_handshaker_result* self, - grpc_slice* recv_bytes, size_t bytes_consumed) { +void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* self, + grpc_slice* recv_bytes, + size_t bytes_consumed) { GPR_ASSERT(recv_bytes != nullptr && self != nullptr); if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) { return; @@ -400,81 +436,6 @@ static void set_unused_bytes(tsi_handshaker_result* self, result->unused_bytes_size); } -void alts_tsi_handshaker_handle_response(alts_tsi_handshaker* handshaker, - grpc_byte_buffer* recv_buffer, - grpc_status_code status, - grpc_slice* details, - tsi_handshaker_on_next_done_cb cb, - void* user_data, bool is_ok) { - /* Invalid input check. */ - if (cb == nullptr) { - gpr_log(GPR_ERROR, - "cb is nullptr in alts_tsi_handshaker_handle_response()"); - return; - } - if (handshaker == nullptr || recv_buffer == nullptr) { - gpr_log(GPR_ERROR, - "Invalid arguments to alts_tsi_handshaker_handle_response()"); - cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); - return; - } - if (handshaker->base.handshake_shutdown) { - gpr_log(GPR_ERROR, "TSI handshake shutdown"); - cb(TSI_HANDSHAKE_SHUTDOWN, user_data, nullptr, 0, nullptr); - return; - } - /* Failed grpc call check. */ - if (!is_ok || status != GRPC_STATUS_OK) { - gpr_log(GPR_ERROR, "grpc call made to handshaker service failed"); - if (details != nullptr) { - char* error_details = grpc_slice_to_c_string(*details); - gpr_log(GPR_ERROR, "error details:%s", error_details); - gpr_free(error_details); - } - cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); - return; - } - grpc_gcp_handshaker_resp* resp = - alts_tsi_utils_deserialize_response(recv_buffer); - /* Invalid handshaker response check. */ - if (resp == nullptr) { - gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed"); - cb(TSI_DATA_CORRUPTED, user_data, nullptr, 0, nullptr); - return; - } - grpc_slice* slice = static_cast<grpc_slice*>(resp->out_frames.arg); - unsigned char* bytes_to_send = nullptr; - size_t bytes_to_send_size = 0; - if (slice != nullptr) { - bytes_to_send_size = GRPC_SLICE_LENGTH(*slice); - while (bytes_to_send_size > handshaker->buffer_size) { - handshaker->buffer_size *= 2; - handshaker->buffer = static_cast<unsigned char*>( - gpr_realloc(handshaker->buffer, handshaker->buffer_size)); - } - memcpy(handshaker->buffer, GRPC_SLICE_START_PTR(*slice), - bytes_to_send_size); - bytes_to_send = handshaker->buffer; - } - tsi_handshaker_result* result = nullptr; - if (is_handshake_finished_properly(resp)) { - create_handshaker_result(resp, handshaker->is_client, &result); - set_unused_bytes(result, &handshaker->recv_bytes, resp->bytes_consumed); - } - grpc_status_code code = static_cast<grpc_status_code>(resp->status.code); - if (code != GRPC_STATUS_OK) { - grpc_slice* details = static_cast<grpc_slice*>(resp->status.details.arg); - if (details != nullptr) { - char* error_details = grpc_slice_to_c_string(*details); - gpr_log(GPR_ERROR, "Error from handshaker service:%s", error_details); - gpr_free(error_details); - } - } - grpc_gcp_handshaker_resp_destroy(resp); - cb(alts_tsi_utils_convert_to_tsi_result(code), user_data, bytes_to_send, - bytes_to_send_size, result); -} - namespace grpc_core { namespace internal { @@ -484,29 +445,16 @@ bool alts_tsi_handshaker_get_has_sent_start_message_for_testing( return handshaker->has_sent_start_message; } -bool alts_tsi_handshaker_get_is_client_for_testing( - alts_tsi_handshaker* handshaker) { +void alts_tsi_handshaker_set_client_vtable_for_testing( + alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) { GPR_ASSERT(handshaker != nullptr); - return handshaker->is_client; -} - -void alts_tsi_handshaker_set_recv_bytes_for_testing( - alts_tsi_handshaker* handshaker, grpc_slice* slice) { - GPR_ASSERT(handshaker != nullptr && slice != nullptr); - handshaker->recv_bytes = grpc_slice_ref(*slice); + handshaker->client_vtable_for_testing = vtable; } -grpc_slice alts_tsi_handshaker_get_recv_bytes_for_testing( +bool alts_tsi_handshaker_get_is_client_for_testing( alts_tsi_handshaker* handshaker) { GPR_ASSERT(handshaker != nullptr); - return handshaker->recv_bytes; -} - -void alts_tsi_handshaker_set_client_for_testing( - alts_tsi_handshaker* handshaker, alts_handshaker_client* client) { - GPR_ASSERT(handshaker != nullptr && client != nullptr); - alts_handshaker_client_destroy(handshaker->client); - handshaker->client = client; + return handshaker->is_client; } alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing( |