diff options
Diffstat (limited to 'src/core/lib/security/transport/client_auth_filter.cc')
-rw-r--r-- | src/core/lib/security/transport/client_auth_filter.cc | 100 |
1 files changed, 50 insertions, 50 deletions
diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 6955e8698e..66f86b8bc5 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -55,7 +55,7 @@ struct call_data { // that the memory is not initialized. void destroy() { grpc_credentials_mdelem_array_destroy(&md_array); - grpc_call_credentials_unref(creds); + creds.reset(); grpc_slice_unref_internal(host); grpc_slice_unref_internal(method); grpc_auth_metadata_context_reset(&auth_md_context); @@ -64,7 +64,7 @@ struct call_data { gpr_arena* arena; grpc_call_stack* owning_call; grpc_call_combiner* call_combiner; - grpc_call_credentials* creds = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> creds; grpc_slice host = grpc_empty_slice(); grpc_slice method = grpc_empty_slice(); /* pollset{_set} bound to this call; if we need to make external @@ -83,8 +83,18 @@ struct call_data { /* We can have a per-channel credentials. */ struct channel_data { - grpc_channel_security_connector* security_connector; - grpc_auth_context* auth_context; + channel_data(grpc_channel_security_connector* security_connector, + grpc_auth_context* auth_context) + : security_connector( + security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")), + auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {} + ~channel_data() { + security_connector.reset(DEBUG_LOCATION, "client_auth_filter"); + auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; }; } // namespace @@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset( gpr_free(const_cast<char*>(auth_md_context->method_name)); auth_md_context->method_name = nullptr; } - GRPC_AUTH_CONTEXT_UNREF( - (grpc_auth_context*)auth_md_context->channel_auth_context, - "grpc_auth_metadata_context"); - auth_md_context->channel_auth_context = nullptr; + if (auth_md_context->channel_auth_context != nullptr) { + const_cast<grpc_auth_context*>(auth_md_context->channel_auth_context) + ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context"); + auth_md_context->channel_auth_context = nullptr; + } } static void add_error(grpc_error** combined, grpc_error* error) { @@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build( auth_md_context->service_url = service_url; auth_md_context->method_name = method_name; auth_md_context->channel_auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context"); + auth_context == nullptr + ? nullptr + : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context") + .release(); gpr_free(service); gpr_free(host_and_port); } @@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); if (error != GRPC_ERROR_NONE) { - grpc_call_credentials_cancel_get_request_metadata( - calld->creds, &calld->md_array, GRPC_ERROR_REF(error)); + calld->creds->cancel_get_request_metadata(&calld->md_array, + GRPC_ERROR_REF(error)); } } @@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem, static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); grpc_call_credentials* channel_call_creds = - chand->security_connector->request_metadata_creds; + chand->security_connector->mutable_request_metadata_creds(); int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); if (channel_call_creds == nullptr && !call_creds_has_md) { @@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem, } if (channel_call_creds != nullptr && call_creds_has_md) { - calld->creds = grpc_composite_call_credentials_create(channel_call_creds, - ctx->creds, nullptr); + calld->creds = grpc_core::RefCountedPtr<grpc_call_credentials>( + grpc_composite_call_credentials_create(channel_call_creds, + ctx->creds.get(), nullptr)); if (calld->creds == nullptr) { grpc_transport_stream_op_batch_finish_with_failure( batch, @@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem, return; } } else { - calld->creds = grpc_call_credentials_ref( - call_creds_has_md ? ctx->creds : channel_call_creds); + calld->creds = + call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref(); } grpc_auth_metadata_context_build( - chand->security_connector->base.url_scheme, calld->host, calld->method, - chand->auth_context, &calld->auth_md_context); + chand->security_connector->url_scheme(), calld->host, calld->method, + chand->auth_context.get(), &calld->auth_md_context); GPR_ASSERT(calld->pollent != nullptr); GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata"); GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata, batch, grpc_schedule_on_exec_ctx); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - calld->creds, calld->pollent, calld->auth_md_context, - &calld->md_array, &calld->async_result_closure, &error)) { + if (calld->creds->get_request_metadata( + calld->pollent, calld->auth_md_context, &calld->md_array, + &calld->async_result_closure, &error)) { // Synchronous return; invoke on_credentials_metadata() directly. on_credentials_metadata(batch, error); GRPC_ERROR_UNREF(error); @@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) { call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); if (error != GRPC_ERROR_NONE) { - grpc_channel_security_connector_cancel_check_call_host( - chand->security_connector, &calld->async_result_closure, - GRPC_ERROR_REF(error)); + chand->security_connector->cancel_check_call_host( + &calld->async_result_closure, GRPC_ERROR_REF(error)); } } @@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch( GPR_ASSERT(batch->payload->context != nullptr); if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) { batch->payload->context[GRPC_CONTEXT_SECURITY].value = - grpc_client_security_context_create(calld->arena); + grpc_client_security_context_create(calld->arena, /*creds=*/nullptr); batch->payload->context[GRPC_CONTEXT_SECURITY].destroy = grpc_client_security_context_destroy; } grpc_client_security_context* sec_ctx = static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); - GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter"); + sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); sec_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter"); } if (batch->send_initial_metadata) { @@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch( grpc_schedule_on_exec_ctx); char* call_host = grpc_slice_to_c_string(calld->host); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_channel_security_connector_check_call_host( - chand->security_connector, call_host, chand->auth_context, + if (chand->security_connector->check_call_host( + call_host, chand->auth_context.get(), &calld->async_result_closure, &error)) { // Synchronous return; invoke on_host_checked() directly. on_host_checked(batch, error); @@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem, /* Constructor for channel_data */ static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { + /* The first and the last filters tend to be implemented differently to + handle the case that there's no 'next' filter to call on the up or down + path */ + GPR_ASSERT(!args->is_last); grpc_security_connector* sc = grpc_security_connector_find_in_args(args->channel_args); if (sc == nullptr) { @@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Auth context missing from client auth filter args"); } - - /* grab pointers to our data from the channel element */ - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - - /* The first and the last filters tend to be implemented differently to - handle the case that there's no 'next' filter to call on the up or down - path */ - GPR_ASSERT(!args->is_last); - - /* initialize members */ - chand->security_connector = - reinterpret_cast<grpc_channel_security_connector*>( - GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter")); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter"); + new (elem->channel_data) channel_data( + static_cast<grpc_channel_security_connector*>(sc), auth_context); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { - /* grab pointers to our data from the channel element */ channel_data* chand = static_cast<channel_data*>(elem->channel_data); - grpc_channel_security_connector* sc = chand->security_connector; - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter"); - } - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter"); + chand->~channel_data(); } const grpc_channel_filter grpc_client_auth_filter = { |