diff options
Diffstat (limited to 'src')
177 files changed, 6133 insertions, 4200 deletions
diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index 70aac47231..dd741f1e2d 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -35,10 +35,10 @@ #include "src/core/ext/filters/client_channel/http_connect_handshaker.h" #include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/request_routing.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" #include "src/core/ext/filters/client_channel/resolver_result_parsing.h" #include "src/core/ext/filters/client_channel/retry_throttle.h" -#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel.h" #include "src/core/ext/filters/deadline/deadline_filter.h" #include "src/core/lib/backoff/backoff.h" @@ -63,7 +63,6 @@ #include "src/core/lib/transport/static_metadata.h" #include "src/core/lib/transport/status_metadata.h" -using grpc_core::ServerAddressList; using grpc_core::internal::ClientChannelMethodParams; using grpc_core::internal::ClientChannelMethodParamsTable; using grpc_core::internal::ProcessedResolverResult; @@ -88,31 +87,18 @@ grpc_core::TraceFlag grpc_client_channel_trace(false, "client_channel"); struct external_connectivity_watcher; typedef struct client_channel_channel_data { - grpc_core::OrphanablePtr<grpc_core::Resolver> resolver; - bool started_resolving; + grpc_core::ManualConstructor<grpc_core::RequestRouter> request_router; + bool deadline_checking_enabled; - grpc_client_channel_factory* client_channel_factory; bool enable_retries; size_t per_rpc_retry_buffer_size; /** combiner protecting all variables below in this data structure */ grpc_combiner* combiner; - /** currently active load balancer */ - grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> lb_policy; /** retry throttle data */ grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data; /** maps method names to method_parameters structs */ grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> method_params_table; - /** incoming resolver result - set by resolver.next() */ - grpc_channel_args* resolver_result; - /** a list of closures that are all waiting for resolver result to come in */ - grpc_closure_list waiting_for_resolver_result_closures; - /** resolver callback */ - grpc_closure on_resolver_result_changed; - /** connectivity state being tracked */ - grpc_connectivity_state_tracker state_tracker; - /** when an lb_policy arrives, should we try to exit idle */ - bool exit_idle_when_lb_policy_arrives; /** owning stack */ grpc_channel_stack* owning_stack; /** interested parties (owned) */ @@ -129,418 +115,40 @@ typedef struct client_channel_channel_data { grpc_core::UniquePtr<char> info_lb_policy_name; /** service config in JSON form */ grpc_core::UniquePtr<char> info_service_config_json; - /* backpointer to grpc_channel's channelz node */ - grpc_core::channelz::ClientChannelNode* channelz_channel; - /* caches if the last resolution event contained addresses */ - bool previous_resolution_contained_addresses; } channel_data; -typedef struct { - channel_data* chand; - /** used as an identifier, don't dereference it because the LB policy may be - * non-existing when the callback is run */ - grpc_core::LoadBalancingPolicy* lb_policy; - grpc_closure closure; -} reresolution_request_args; - -/** We create one watcher for each new lb_policy that is returned from a - resolver, to watch for state changes from the lb_policy. When a state - change is seen, we update the channel, and create a new watcher. */ -typedef struct { - channel_data* chand; - grpc_closure on_changed; - grpc_connectivity_state state; - grpc_core::LoadBalancingPolicy* lb_policy; -} lb_policy_connectivity_watcher; - -static void watch_lb_policy_locked(channel_data* chand, - grpc_core::LoadBalancingPolicy* lb_policy, - grpc_connectivity_state current_state); - -static const char* channel_connectivity_state_change_string( - grpc_connectivity_state state) { - switch (state) { - case GRPC_CHANNEL_IDLE: - return "Channel state change to IDLE"; - case GRPC_CHANNEL_CONNECTING: - return "Channel state change to CONNECTING"; - case GRPC_CHANNEL_READY: - return "Channel state change to READY"; - case GRPC_CHANNEL_TRANSIENT_FAILURE: - return "Channel state change to TRANSIENT_FAILURE"; - case GRPC_CHANNEL_SHUTDOWN: - return "Channel state change to SHUTDOWN"; - } - GPR_UNREACHABLE_CODE(return "UNKNOWN"); -} - -static void set_channel_connectivity_state_locked(channel_data* chand, - grpc_connectivity_state state, - grpc_error* error, - const char* reason) { - /* TODO: Improve failure handling: - * - Make it possible for policies to return GRPC_CHANNEL_TRANSIENT_FAILURE. - * - Hand over pending picks from old policies during the switch that happens - * when resolver provides an update. */ - if (chand->lb_policy != nullptr) { - if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { - /* cancel picks with wait_for_ready=false */ - chand->lb_policy->CancelMatchingPicksLocked( - /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY, - /* check= */ 0, GRPC_ERROR_REF(error)); - } else if (state == GRPC_CHANNEL_SHUTDOWN) { - /* cancel all picks */ - chand->lb_policy->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0, - GRPC_ERROR_REF(error)); - } - } - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: setting connectivity state to %s", chand, - grpc_connectivity_state_name(state)); - } - if (chand->channelz_channel != nullptr) { - chand->channelz_channel->AddTraceEvent( - grpc_core::channelz::ChannelTrace::Severity::Info, - grpc_slice_from_static_string( - channel_connectivity_state_change_string(state))); - } - grpc_connectivity_state_set(&chand->state_tracker, state, error, reason); -} - -static void on_lb_policy_state_changed_locked(void* arg, grpc_error* error) { - lb_policy_connectivity_watcher* w = - static_cast<lb_policy_connectivity_watcher*>(arg); - /* check if the notification is for the latest policy */ - if (w->lb_policy == w->chand->lb_policy.get()) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: lb_policy=%p state changed to %s", w->chand, - w->lb_policy, grpc_connectivity_state_name(w->state)); - } - set_channel_connectivity_state_locked(w->chand, w->state, - GRPC_ERROR_REF(error), "lb_changed"); - if (w->state != GRPC_CHANNEL_SHUTDOWN) { - watch_lb_policy_locked(w->chand, w->lb_policy, w->state); - } - } - GRPC_CHANNEL_STACK_UNREF(w->chand->owning_stack, "watch_lb_policy"); - gpr_free(w); -} - -static void watch_lb_policy_locked(channel_data* chand, - grpc_core::LoadBalancingPolicy* lb_policy, - grpc_connectivity_state current_state) { - lb_policy_connectivity_watcher* w = - static_cast<lb_policy_connectivity_watcher*>(gpr_malloc(sizeof(*w))); - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "watch_lb_policy"); - w->chand = chand; - GRPC_CLOSURE_INIT(&w->on_changed, on_lb_policy_state_changed_locked, w, - grpc_combiner_scheduler(chand->combiner)); - w->state = current_state; - w->lb_policy = lb_policy; - lb_policy->NotifyOnStateChangeLocked(&w->state, &w->on_changed); -} - -static void start_resolving_locked(channel_data* chand) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: starting name resolution", chand); - } - GPR_ASSERT(!chand->started_resolving); - chand->started_resolving = true; - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "resolver"); - chand->resolver->NextLocked(&chand->resolver_result, - &chand->on_resolver_result_changed); -} - -// Invoked from the resolver NextLocked() callback when the resolver -// is shutting down. -static void on_resolver_shutdown_locked(channel_data* chand, - grpc_error* error) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down", chand); - } - if (chand->lb_policy != nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand, - chand->lb_policy.get()); - } - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } - if (chand->resolver != nullptr) { - // This should never happen; it can only be triggered by a resolver - // implementation spotaneously deciding to report shutdown without - // being orphaned. This code is included just to be defensive. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: spontaneous shutdown from resolver %p", - chand, chand->resolver.get()); - } - chand->resolver.reset(); - set_channel_connectivity_state_locked( - chand, GRPC_CHANNEL_SHUTDOWN, - GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Resolver spontaneous shutdown", &error, 1), - "resolver_spontaneous_shutdown"); - } - grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures, - GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Channel disconnected", &error, 1)); - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "resolver"); - grpc_channel_args_destroy(chand->resolver_result); - chand->resolver_result = nullptr; - GRPC_ERROR_UNREF(error); -} - -static void request_reresolution_locked(void* arg, grpc_error* error) { - reresolution_request_args* args = - static_cast<reresolution_request_args*>(arg); - channel_data* chand = args->chand; - // If this invocation is for a stale LB policy, treat it as an LB shutdown - // signal. - if (args->lb_policy != chand->lb_policy.get() || error != GRPC_ERROR_NONE || - chand->resolver == nullptr) { - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "re-resolution"); - gpr_free(args); - return; - } - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: started name re-resolving", chand); - } - chand->resolver->RequestReresolutionLocked(); - // Give back the closure to the LB policy. - chand->lb_policy->SetReresolutionClosureLocked(&args->closure); -} - -using TraceStringVector = grpc_core::InlinedVector<char*, 3>; - -// Creates a new LB policy, replacing any previous one. -// If the new policy is created successfully, sets *connectivity_state and -// *connectivity_error to its initial connectivity state; otherwise, -// leaves them unchanged. -static void create_new_lb_policy_locked( - channel_data* chand, char* lb_policy_name, grpc_json* lb_config, - grpc_connectivity_state* connectivity_state, - grpc_error** connectivity_error, TraceStringVector* trace_strings) { - grpc_core::LoadBalancingPolicy::Args lb_policy_args; - lb_policy_args.combiner = chand->combiner; - lb_policy_args.client_channel_factory = chand->client_channel_factory; - lb_policy_args.args = chand->resolver_result; - lb_policy_args.lb_config = lb_config; - grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> new_lb_policy = - grpc_core::LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( - lb_policy_name, lb_policy_args); - if (GPR_UNLIKELY(new_lb_policy == nullptr)) { - gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name); - if (chand->channelz_channel != nullptr) { - char* str; - gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name); - trace_strings->push_back(str); - } - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: created new LB policy \"%s\" (%p)", chand, - lb_policy_name, new_lb_policy.get()); - } - if (chand->channelz_channel != nullptr) { - char* str; - gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name); - trace_strings->push_back(str); - } - // Swap out the LB policy and update the fds in - // chand->interested_parties. - if (chand->lb_policy != nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand, - chand->lb_policy.get()); - } - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy->HandOffPendingPicksLocked(new_lb_policy.get()); - } - chand->lb_policy = std::move(new_lb_policy); - grpc_pollset_set_add_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - // Set up re-resolution callback. - reresolution_request_args* args = - static_cast<reresolution_request_args*>(gpr_zalloc(sizeof(*args))); - args->chand = chand; - args->lb_policy = chand->lb_policy.get(); - GRPC_CLOSURE_INIT(&args->closure, request_reresolution_locked, args, - grpc_combiner_scheduler(chand->combiner)); - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "re-resolution"); - chand->lb_policy->SetReresolutionClosureLocked(&args->closure); - // Get the new LB policy's initial connectivity state and start a - // connectivity watch. - GRPC_ERROR_UNREF(*connectivity_error); - *connectivity_state = - chand->lb_policy->CheckConnectivityLocked(connectivity_error); - if (chand->exit_idle_when_lb_policy_arrives) { - chand->lb_policy->ExitIdleLocked(); - chand->exit_idle_when_lb_policy_arrives = false; - } - watch_lb_policy_locked(chand, chand->lb_policy.get(), *connectivity_state); - } -} - -static void maybe_add_trace_message_for_address_changes_locked( - channel_data* chand, TraceStringVector* trace_strings) { - const ServerAddressList* addresses = - grpc_core::FindServerAddressListChannelArg(chand->resolver_result); - const bool resolution_contains_addresses = - addresses != nullptr && addresses->size() > 0; - if (!resolution_contains_addresses && - chand->previous_resolution_contained_addresses) { - trace_strings->push_back(gpr_strdup("Address list became empty")); - } else if (resolution_contains_addresses && - !chand->previous_resolution_contained_addresses) { - trace_strings->push_back(gpr_strdup("Address list became non-empty")); - } - chand->previous_resolution_contained_addresses = - resolution_contains_addresses; -} - -static void concatenate_and_add_channel_trace_locked( - channel_data* chand, TraceStringVector* trace_strings) { - if (!trace_strings->empty()) { - gpr_strvec v; - gpr_strvec_init(&v); - gpr_strvec_add(&v, gpr_strdup("Resolution event: ")); - bool is_first = 1; - for (size_t i = 0; i < trace_strings->size(); ++i) { - if (!is_first) gpr_strvec_add(&v, gpr_strdup(", ")); - is_first = false; - gpr_strvec_add(&v, (*trace_strings)[i]); - } - char* flat; - size_t flat_len = 0; - flat = gpr_strvec_flatten(&v, &flat_len); - chand->channelz_channel->AddTraceEvent( - grpc_core::channelz::ChannelTrace::Severity::Info, - grpc_slice_new(flat, flat_len, gpr_free)); - gpr_strvec_destroy(&v); - } -} - -// Callback invoked when a resolver result is available. -static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { +// Synchronous callback from chand->request_router to process a resolver +// result update. +static bool process_resolver_result_locked(void* arg, + const grpc_channel_args& args, + const char** lb_policy_name, + grpc_json** lb_policy_config) { channel_data* chand = static_cast<channel_data*>(arg); + ProcessedResolverResult resolver_result(args, chand->enable_retries); + grpc_core::UniquePtr<char> service_config_json = + resolver_result.service_config_json(); if (grpc_client_channel_trace.enabled()) { - const char* disposition = - chand->resolver_result != nullptr - ? "" - : (error == GRPC_ERROR_NONE ? " (transient error)" - : " (resolver shutdown)"); - gpr_log(GPR_INFO, - "chand=%p: got resolver result: resolver_result=%p error=%s%s", - chand, chand->resolver_result, grpc_error_string(error), - disposition); + gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"", + chand, service_config_json.get()); } - // Handle shutdown. - if (error != GRPC_ERROR_NONE || chand->resolver == nullptr) { - on_resolver_shutdown_locked(chand, GRPC_ERROR_REF(error)); - return; - } - // Data used to set the channel's connectivity state. - bool set_connectivity_state = true; - // We only want to trace the address resolution in the follow cases: - // (a) Address resolution resulted in service config change. - // (b) Address resolution that causes number of backends to go from - // zero to non-zero. - // (c) Address resolution that causes number of backends to go from - // non-zero to zero. - // (d) Address resolution that causes a new LB policy to be created. - // - // we track a list of strings to eventually be concatenated and traced. - TraceStringVector trace_strings; - grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; - grpc_error* connectivity_error = - GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy"); - // chand->resolver_result will be null in the case of a transient - // resolution error. In that case, we don't have any new result to - // process, which means that we keep using the previous result (if any). - if (chand->resolver_result == nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: resolver transient failure", chand); - } - // Don't override connectivity state if we already have an LB policy. - if (chand->lb_policy != nullptr) set_connectivity_state = false; - } else { - // Parse the resolver result. - ProcessedResolverResult resolver_result(chand->resolver_result, - chand->enable_retries); - chand->retry_throttle_data = resolver_result.retry_throttle_data(); - chand->method_params_table = resolver_result.method_params_table(); - grpc_core::UniquePtr<char> service_config_json = - resolver_result.service_config_json(); - if (service_config_json != nullptr && grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"", - chand, service_config_json.get()); - } - grpc_core::UniquePtr<char> lb_policy_name = - resolver_result.lb_policy_name(); - grpc_json* lb_policy_config = resolver_result.lb_policy_config(); - // Check to see if we're already using the right LB policy. - // Note: It's safe to use chand->info_lb_policy_name here without - // taking a lock on chand->info_mu, because this function is the - // only thing that modifies its value, and it can only be invoked - // once at any given time. - bool lb_policy_name_changed = - chand->info_lb_policy_name == nullptr || - strcmp(chand->info_lb_policy_name.get(), lb_policy_name.get()) != 0; - if (chand->lb_policy != nullptr && !lb_policy_name_changed) { - // Continue using the same LB policy. Update with new addresses. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: updating existing LB policy \"%s\" (%p)", - chand, lb_policy_name.get(), chand->lb_policy.get()); - } - chand->lb_policy->UpdateLocked(*chand->resolver_result, lb_policy_config); - // No need to set the channel's connectivity state; the existing - // watch on the LB policy will take care of that. - set_connectivity_state = false; - } else { - // Instantiate new LB policy. - create_new_lb_policy_locked(chand, lb_policy_name.get(), lb_policy_config, - &connectivity_state, &connectivity_error, - &trace_strings); - } - // Note: It's safe to use chand->info_service_config_json here without - // taking a lock on chand->info_mu, because this function is the - // only thing that modifies its value, and it can only be invoked - // once at any given time. - if (chand->channelz_channel != nullptr) { - if (((service_config_json == nullptr) != - (chand->info_service_config_json == nullptr)) || - (service_config_json != nullptr && - strcmp(service_config_json.get(), - chand->info_service_config_json.get()) != 0)) { - // TODO(ncteisen): might be worth somehow including a snippet of the - // config in the trace, at the risk of bloating the trace logs. - trace_strings.push_back(gpr_strdup("Service config changed")); - } - maybe_add_trace_message_for_address_changes_locked(chand, &trace_strings); - concatenate_and_add_channel_trace_locked(chand, &trace_strings); - } - // Swap out the data used by cc_get_channel_info(). - gpr_mu_lock(&chand->info_mu); - chand->info_lb_policy_name = std::move(lb_policy_name); - chand->info_service_config_json = std::move(service_config_json); - gpr_mu_unlock(&chand->info_mu); - // Clean up. - grpc_channel_args_destroy(chand->resolver_result); - chand->resolver_result = nullptr; - } - // Set the channel's connectivity state if needed. - if (set_connectivity_state) { - set_channel_connectivity_state_locked( - chand, connectivity_state, connectivity_error, "resolver_result"); - } else { - GRPC_ERROR_UNREF(connectivity_error); - } - // Invoke closures that were waiting for results and renew the watch. - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - chand->resolver->NextLocked(&chand->resolver_result, - &chand->on_resolver_result_changed); + // Update channel state. + chand->retry_throttle_data = resolver_result.retry_throttle_data(); + chand->method_params_table = resolver_result.method_params_table(); + // Swap out the data used by cc_get_channel_info(). + gpr_mu_lock(&chand->info_mu); + chand->info_lb_policy_name = resolver_result.lb_policy_name(); + const bool service_config_changed = + ((service_config_json == nullptr) != + (chand->info_service_config_json == nullptr)) || + (service_config_json != nullptr && + strcmp(service_config_json.get(), + chand->info_service_config_json.get()) != 0); + chand->info_service_config_json = std::move(service_config_json); + gpr_mu_unlock(&chand->info_mu); + // Return results. + *lb_policy_name = chand->info_lb_policy_name.get(); + *lb_policy_config = resolver_result.lb_policy_config(); + return service_config_changed; } static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { @@ -550,15 +158,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); if (op->on_connectivity_state_change != nullptr) { - grpc_connectivity_state_notify_on_state_change( - &chand->state_tracker, op->connectivity_state, - op->on_connectivity_state_change); + chand->request_router->NotifyOnConnectivityStateChange( + op->connectivity_state, op->on_connectivity_state_change); op->on_connectivity_state_change = nullptr; op->connectivity_state = nullptr; } if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) { - if (chand->lb_policy == nullptr) { + if (chand->request_router->lb_policy() == nullptr) { grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Ping with no load balancing"); GRPC_CLOSURE_SCHED(op->send_ping.on_initiate, GRPC_ERROR_REF(error)); @@ -567,7 +174,8 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { grpc_error* error = GRPC_ERROR_NONE; grpc_core::LoadBalancingPolicy::PickState pick_state; // Pick must return synchronously, because pick_state.on_complete is null. - GPR_ASSERT(chand->lb_policy->PickLocked(&pick_state, &error)); + GPR_ASSERT( + chand->request_router->lb_policy()->PickLocked(&pick_state, &error)); if (pick_state.connected_subchannel != nullptr) { pick_state.connected_subchannel->Ping(op->send_ping.on_initiate, op->send_ping.on_ack); @@ -586,37 +194,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { } if (op->disconnect_with_error != GRPC_ERROR_NONE) { - if (chand->resolver != nullptr) { - set_channel_connectivity_state_locked( - chand, GRPC_CHANNEL_SHUTDOWN, - GRPC_ERROR_REF(op->disconnect_with_error), "disconnect"); - chand->resolver.reset(); - if (!chand->started_resolving) { - grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures, - GRPC_ERROR_REF(op->disconnect_with_error)); - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - } - if (chand->lb_policy != nullptr) { - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } - } - GRPC_ERROR_UNREF(op->disconnect_with_error); + chand->request_router->ShutdownLocked(op->disconnect_with_error); } if (op->reset_connect_backoff) { - if (chand->resolver != nullptr) { - chand->resolver->ResetBackoffLocked(); - chand->resolver->RequestReresolutionLocked(); - } - if (chand->lb_policy != nullptr) { - chand->lb_policy->ResetBackoffLocked(); - } + chand->request_router->ResetConnectionBackoffLocked(); } GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "start_transport_op"); - GRPC_CLOSURE_SCHED(op->on_consumed, GRPC_ERROR_NONE); } @@ -667,12 +252,9 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); chand->owning_stack = args->channel_stack; - GRPC_CLOSURE_INIT(&chand->on_resolver_result_changed, - on_resolver_result_changed_locked, chand, - grpc_combiner_scheduler(chand->combiner)); + chand->deadline_checking_enabled = + grpc_deadline_checking_enabled(args->channel_args); chand->interested_parties = grpc_pollset_set_create(); - grpc_connectivity_state_init(&chand->state_tracker, GRPC_CHANNEL_IDLE, - "client_channel"); grpc_client_channel_start_backup_polling(chand->interested_parties); // Record max per-RPC retry buffer size. const grpc_arg* arg = grpc_channel_args_find( @@ -682,8 +264,6 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, // Record enable_retries. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_ENABLE_RETRIES); chand->enable_retries = grpc_channel_arg_get_bool(arg, true); - chand->channelz_channel = nullptr; - chand->previous_resolution_contained_addresses = false; // Record client channel factory. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_CLIENT_CHANNEL_FACTORY); @@ -695,9 +275,7 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "client channel factory arg must be a pointer"); } - grpc_client_channel_factory_ref( - static_cast<grpc_client_channel_factory*>(arg->value.pointer.p)); - chand->client_channel_factory = + grpc_client_channel_factory* client_channel_factory = static_cast<grpc_client_channel_factory*>(arg->value.pointer.p); // Get server name to resolve, using proxy mapper if needed. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_SERVER_URI); @@ -713,39 +291,24 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, grpc_channel_args* new_args = nullptr; grpc_proxy_mappers_map_name(arg->value.string, args->channel_args, &proxy_name, &new_args); - // Instantiate resolver. - chand->resolver = grpc_core::ResolverRegistry::CreateResolver( - proxy_name != nullptr ? proxy_name : arg->value.string, - new_args != nullptr ? new_args : args->channel_args, - chand->interested_parties, chand->combiner); - if (proxy_name != nullptr) gpr_free(proxy_name); - if (new_args != nullptr) grpc_channel_args_destroy(new_args); - if (chand->resolver == nullptr) { - return GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed"); - } - chand->deadline_checking_enabled = - grpc_deadline_checking_enabled(args->channel_args); - return GRPC_ERROR_NONE; + // Instantiate request router. + grpc_client_channel_factory_ref(client_channel_factory); + grpc_error* error = GRPC_ERROR_NONE; + chand->request_router.Init( + chand->owning_stack, chand->combiner, client_channel_factory, + chand->interested_parties, &grpc_client_channel_trace, + process_resolver_result_locked, chand, + proxy_name != nullptr ? proxy_name : arg->value.string /* target_uri */, + new_args != nullptr ? new_args : args->channel_args, &error); + gpr_free(proxy_name); + grpc_channel_args_destroy(new_args); + return error; } /* Destructor for channel_data */ static void cc_destroy_channel_elem(grpc_channel_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - if (chand->resolver != nullptr) { - // The only way we can get here is if we never started resolving, - // because we take a ref to the channel stack when we start - // resolving and do not release it until the resolver callback is - // invoked after the resolver shuts down. - chand->resolver.reset(); - } - if (chand->client_channel_factory != nullptr) { - grpc_client_channel_factory_unref(chand->client_channel_factory); - } - if (chand->lb_policy != nullptr) { - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } + chand->request_router.Destroy(); // TODO(roth): Once we convert the filter API to C++, there will no // longer be any need to explicitly reset these smart pointer data members. chand->info_lb_policy_name.reset(); @@ -753,7 +316,6 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) { chand->retry_throttle_data.reset(); chand->method_params_table.reset(); grpc_client_channel_stop_backup_polling(chand->interested_parties); - grpc_connectivity_state_destroy(&chand->state_tracker); grpc_pollset_set_destroy(chand->interested_parties); GRPC_COMBINER_UNREF(chand->combiner, "client_channel"); gpr_mu_destroy(&chand->info_mu); @@ -810,6 +372,7 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) { // - add census stats for retries namespace { + struct call_data; // State used for starting a retryable batch on a subchannel call. @@ -894,12 +457,12 @@ struct subchannel_call_retry_state { bool completed_recv_initial_metadata : 1; bool started_recv_trailing_metadata : 1; bool completed_recv_trailing_metadata : 1; + // State for callback processing. subchannel_batch_data* recv_initial_metadata_ready_deferred_batch = nullptr; grpc_error* recv_initial_metadata_error = GRPC_ERROR_NONE; subchannel_batch_data* recv_message_ready_deferred_batch = nullptr; grpc_error* recv_message_error = GRPC_ERROR_NONE; subchannel_batch_data* recv_trailing_metadata_internal_batch = nullptr; - // State for callback processing. // NOTE: Do not move this next to the metadata bitfields above. That would // save space but will also result in a data race because compiler will // generate a 2 byte store which overwrites the meta-data fields upon @@ -908,12 +471,12 @@ struct subchannel_call_retry_state { }; // Pending batches stored in call data. -typedef struct { +struct pending_batch { // The pending batch. If nullptr, this slot is empty. grpc_transport_stream_op_batch* batch; // Indicates whether payload for send ops has been cached in call data. bool send_ops_cached; -} pending_batch; +}; /** Call data. Holds a pointer to grpc_subchannel_call and the associated machinery to create such a pointer. @@ -950,11 +513,8 @@ struct call_data { for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches); ++i) { GPR_ASSERT(pending_batches[i].batch == nullptr); } - for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { - if (pick.subchannel_call_context[i].value != nullptr) { - pick.subchannel_call_context[i].destroy( - pick.subchannel_call_context[i].value); - } + if (have_request) { + request.Destroy(); } } @@ -981,12 +541,11 @@ struct call_data { // Set when we get a cancel_stream op. grpc_error* cancel_error = GRPC_ERROR_NONE; - grpc_core::LoadBalancingPolicy::PickState pick; + grpc_core::ManualConstructor<grpc_core::RequestRouter::Request> request; + bool have_request = false; grpc_closure pick_closure; - grpc_closure pick_cancel_closure; grpc_polling_entity* pollent = nullptr; - bool pollent_added_to_interested_parties = false; // Batches are added to this list when received from above. // They are removed when we are done handling the batch (i.e., when @@ -1036,6 +595,7 @@ struct call_data { grpc_linked_mdelem* send_trailing_metadata_storage = nullptr; grpc_metadata_batch send_trailing_metadata; }; + } // namespace // Forward declarations. @@ -1438,8 +998,9 @@ static void do_retry(grpc_call_element* elem, "client_channel_call_retry"); calld->subchannel_call = nullptr; } - if (calld->pick.connected_subchannel != nullptr) { - calld->pick.connected_subchannel.reset(); + if (calld->have_request) { + calld->have_request = false; + calld->request.Destroy(); } // Compute backoff delay. grpc_millis next_attempt_time; @@ -1588,6 +1149,7 @@ static bool maybe_retry(grpc_call_element* elem, // namespace { + subchannel_batch_data::subchannel_batch_data(grpc_call_element* elem, call_data* calld, int refcount, bool set_on_complete) @@ -1628,6 +1190,7 @@ void subchannel_batch_data::destroy() { call_data* calld = static_cast<call_data*>(elem->call_data); GRPC_CALL_STACK_UNREF(calld->owning_call, "batch_data"); } + } // namespace // Creates a subchannel_batch_data object on the call's arena with the @@ -2644,17 +2207,18 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { const size_t parent_data_size = calld->enable_retries ? sizeof(subchannel_call_retry_state) : 0; const grpc_core::ConnectedSubchannel::CallArgs call_args = { - calld->pollent, // pollent - calld->path, // path - calld->call_start_time, // start_time - calld->deadline, // deadline - calld->arena, // arena - calld->pick.subchannel_call_context, // context - calld->call_combiner, // call_combiner - parent_data_size // parent_data_size + calld->pollent, // pollent + calld->path, // path + calld->call_start_time, // start_time + calld->deadline, // deadline + calld->arena, // arena + calld->request->pick()->subchannel_call_context, // context + calld->call_combiner, // call_combiner + parent_data_size // parent_data_size }; - grpc_error* new_error = calld->pick.connected_subchannel->CreateCall( - call_args, &calld->subchannel_call); + grpc_error* new_error = + calld->request->pick()->connected_subchannel->CreateCall( + call_args, &calld->subchannel_call); if (grpc_client_channel_trace.enabled()) { gpr_log(GPR_INFO, "chand=%p calld=%p: create subchannel_call=%p: error=%s", chand, calld, calld->subchannel_call, grpc_error_string(new_error)); @@ -2666,7 +2230,8 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { if (parent_data_size > 0) { new (grpc_connected_subchannel_call_get_parent_data( calld->subchannel_call)) - subchannel_call_retry_state(calld->pick.subchannel_call_context); + subchannel_call_retry_state( + calld->request->pick()->subchannel_call_context); } pending_batches_resume(elem); } @@ -2678,7 +2243,7 @@ static void pick_done(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); channel_data* chand = static_cast<channel_data*>(elem->channel_data); call_data* calld = static_cast<call_data*>(elem->call_data); - if (GPR_UNLIKELY(calld->pick.connected_subchannel == nullptr)) { + if (GPR_UNLIKELY(calld->request->pick()->connected_subchannel == nullptr)) { // Failed to create subchannel. // If there was no error, this is an LB policy drop, in which case // we return an error; otherwise, we may retry. @@ -2707,135 +2272,27 @@ static void pick_done(void* arg, grpc_error* error) { } } -static void maybe_add_call_to_channel_interested_parties_locked( - grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (!calld->pollent_added_to_interested_parties) { - calld->pollent_added_to_interested_parties = true; - grpc_polling_entity_add_to_pollset_set(calld->pollent, - chand->interested_parties); - } -} - -static void maybe_del_call_from_channel_interested_parties_locked( - grpc_call_element* elem) { +// If the channel is in TRANSIENT_FAILURE and the call is not +// wait_for_ready=true, fails the call and returns true. +static bool fail_call_if_in_transient_failure(grpc_call_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); call_data* calld = static_cast<call_data*>(elem->call_data); - if (calld->pollent_added_to_interested_parties) { - calld->pollent_added_to_interested_parties = false; - grpc_polling_entity_del_from_pollset_set(calld->pollent, - chand->interested_parties); + grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch; + if (chand->request_router->GetConnectivityState() == + GRPC_CHANNEL_TRANSIENT_FAILURE && + (batch->payload->send_initial_metadata.send_initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) { + pending_batches_fail( + elem, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "channel is in state TRANSIENT_FAILURE"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), + true /* yield_call_combiner */); + return true; } + return false; } -// Invoked when a pick is completed to leave the client_channel combiner -// and continue processing in the call combiner. -// If needed, removes the call's polling entity from chand->interested_parties. -static void pick_done_locked(grpc_call_element* elem, grpc_error* error) { - call_data* calld = static_cast<call_data*>(elem->call_data); - maybe_del_call_from_channel_interested_parties_locked(elem); - GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_SCHED(&calld->pick_closure, error); -} - -namespace grpc_core { - -// Performs subchannel pick via LB policy. -class LbPicker { - public: - // Starts a pick on chand->lb_policy. - static void StartLocked(grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: starting pick on lb_policy=%p", - chand, calld, chand->lb_policy.get()); - } - // If this is a retry, use the send_initial_metadata payload that - // we've cached; otherwise, use the pending batch. The - // send_initial_metadata batch will be the first pending batch in the - // list, as set by get_batch_index() above. - calld->pick.initial_metadata = - calld->seen_send_initial_metadata - ? &calld->send_initial_metadata - : calld->pending_batches[0] - .batch->payload->send_initial_metadata.send_initial_metadata; - calld->pick.initial_metadata_flags = - calld->seen_send_initial_metadata - ? calld->send_initial_metadata_flags - : calld->pending_batches[0] - .batch->payload->send_initial_metadata - .send_initial_metadata_flags; - GRPC_CLOSURE_INIT(&calld->pick_closure, &LbPicker::DoneLocked, elem, - grpc_combiner_scheduler(chand->combiner)); - calld->pick.on_complete = &calld->pick_closure; - GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback"); - grpc_error* error = GRPC_ERROR_NONE; - const bool pick_done = chand->lb_policy->PickLocked(&calld->pick, &error); - if (GPR_LIKELY(pick_done)) { - // Pick completed synchronously. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed synchronously", - chand, calld); - } - pick_done_locked(elem, error); - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback"); - } else { - // Pick will be returned asynchronously. - // Add the polling entity from call_data to the channel_data's - // interested_parties, so that the I/O of the LB policy can be done - // under it. It will be removed in pick_done_locked(). - maybe_add_call_to_channel_interested_parties_locked(elem); - // Request notification on call cancellation. - GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback_cancel"); - grpc_call_combiner_set_notify_on_cancel( - calld->call_combiner, - GRPC_CLOSURE_INIT(&calld->pick_cancel_closure, - &LbPicker::CancelLocked, elem, - grpc_combiner_scheduler(chand->combiner))); - } - } - - private: - // Callback invoked by LoadBalancingPolicy::PickLocked() for async picks. - // Unrefs the LB policy and invokes pick_done_locked(). - static void DoneLocked(void* arg, grpc_error* error) { - grpc_call_element* elem = static_cast<grpc_call_element*>(arg); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed asynchronously", - chand, calld); - } - pick_done_locked(elem, GRPC_ERROR_REF(error)); - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback"); - } - - // Note: This runs under the client_channel combiner, but will NOT be - // holding the call combiner. - static void CancelLocked(void* arg, grpc_error* error) { - grpc_call_element* elem = static_cast<grpc_call_element*>(arg); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - // Note: chand->lb_policy may have changed since we started our pick, - // in which case we will be cancelling the pick on a policy other than - // the one we started it on. However, this will just be a no-op. - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE && chand->lb_policy != nullptr)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: cancelling pick from LB policy %p", chand, - calld, chand->lb_policy.get()); - } - chand->lb_policy->CancelPickLocked(&calld->pick, GRPC_ERROR_REF(error)); - } - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback_cancel"); - } -}; - -} // namespace grpc_core - // Applies service config to the call. Must be invoked once we know // that the resolver has returned results to the channel. static void apply_service_config_to_call_locked(grpc_call_element* elem) { @@ -2892,224 +2349,66 @@ static void apply_service_config_to_call_locked(grpc_call_element* elem) { } } -// If the channel is in TRANSIENT_FAILURE and the call is not -// wait_for_ready=true, fails the call and returns true. -static bool fail_call_if_in_transient_failure(grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch; - if (grpc_connectivity_state_check(&chand->state_tracker) == - GRPC_CHANNEL_TRANSIENT_FAILURE && - (batch->payload->send_initial_metadata.send_initial_metadata_flags & - GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) { - pending_batches_fail( - elem, - grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "channel is in state TRANSIENT_FAILURE"), - GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), - true /* yield_call_combiner */); - return true; - } - return false; -} - // Invoked once resolver results are available. -static void process_service_config_and_start_lb_pick_locked( - grpc_call_element* elem) { +static bool maybe_apply_service_config_to_call_locked(void* arg) { + grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); // Only get service config data on the first attempt. if (GPR_LIKELY(calld->num_attempts_completed == 0)) { apply_service_config_to_call_locked(elem); // Check this after applying service config, since it may have // affected the call's wait_for_ready value. - if (fail_call_if_in_transient_failure(elem)) return; + if (fail_call_if_in_transient_failure(elem)) return false; } - // Start LB pick. - grpc_core::LbPicker::StartLocked(elem); + return true; } -namespace grpc_core { - -// Handles waiting for a resolver result. -// Used only for the first call on an idle channel. -class ResolverResultWaiter { - public: - explicit ResolverResultWaiter(grpc_call_element* elem) : elem_(elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: deferring pick pending resolver result", - chand, calld); - } - // Add closure to be run when a resolver result is available. - GRPC_CLOSURE_INIT(&done_closure_, &ResolverResultWaiter::DoneLocked, this, - grpc_combiner_scheduler(chand->combiner)); - AddToWaitingList(); - // Set cancellation closure, so that we abort if the call is cancelled. - GRPC_CLOSURE_INIT(&cancel_closure_, &ResolverResultWaiter::CancelLocked, - this, grpc_combiner_scheduler(chand->combiner)); - grpc_call_combiner_set_notify_on_cancel(calld->call_combiner, - &cancel_closure_); - } - - private: - // Adds closure_ to chand->waiting_for_resolver_result_closures. - void AddToWaitingList() { - channel_data* chand = static_cast<channel_data*>(elem_->channel_data); - grpc_closure_list_append(&chand->waiting_for_resolver_result_closures, - &done_closure_, GRPC_ERROR_NONE); - } - - // Invoked when a resolver result is available. - static void DoneLocked(void* arg, grpc_error* error) { - ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); - // If CancelLocked() has already run, delete ourselves without doing - // anything. Note that the call stack may have already been destroyed, - // so it's not safe to access anything in elem_. - if (GPR_UNLIKELY(self->finished_)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "call cancelled before resolver result"); - } - Delete(self); - return; - } - // Otherwise, process the resolver result. - grpc_call_element* elem = self->elem_; - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver failed to return data", - chand, calld); - } - pick_done_locked(elem, GRPC_ERROR_REF(error)); - } else if (GPR_UNLIKELY(chand->resolver == nullptr)) { - // Shutting down. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver disconnected", chand, - calld); - } - pick_done_locked(elem, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); - } else if (GPR_UNLIKELY(chand->lb_policy == nullptr)) { - // Transient resolver failure. - // If call has wait_for_ready=true, try again; otherwise, fail. - uint32_t send_initial_metadata_flags = - calld->seen_send_initial_metadata - ? calld->send_initial_metadata_flags - : calld->pending_batches[0] - .batch->payload->send_initial_metadata - .send_initial_metadata_flags; - if (send_initial_metadata_flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: resolver returned but no LB policy; " - "wait_for_ready=true; trying again", - chand, calld); - } - // Re-add ourselves to the waiting list. - self->AddToWaitingList(); - // Return early so that we don't set finished_ to true below. - return; - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: resolver returned but no LB policy; " - "wait_for_ready=false; failing", - chand, calld); - } - pick_done_locked( - elem, - grpc_error_set_int( - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"), - GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); - } - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver returned, doing LB pick", - chand, calld); - } - process_service_config_and_start_lb_pick_locked(elem); - } - self->finished_ = true; - } - - // Invoked when the call is cancelled. - // Note: This runs under the client_channel combiner, but will NOT be - // holding the call combiner. - static void CancelLocked(void* arg, grpc_error* error) { - ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); - // If DoneLocked() has already run, delete ourselves without doing anything. - if (GPR_LIKELY(self->finished_)) { - Delete(self); - return; - } - // If we are being cancelled, immediately invoke pick_done_locked() - // to propagate the error back to the caller. - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { - grpc_call_element* elem = self->elem_; - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: cancelling call waiting for name " - "resolution", - chand, calld); - } - // Note: Although we are not in the call combiner here, we are - // basically stealing the call combiner from the pending pick, so - // it's safe to call pick_done_locked() here -- we are essentially - // calling it here instead of calling it in DoneLocked(). - pick_done_locked(elem, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Pick cancelled", &error, 1)); - } - self->finished_ = true; - } - - grpc_call_element* elem_; - grpc_closure done_closure_; - grpc_closure cancel_closure_; - bool finished_ = false; -}; - -} // namespace grpc_core - static void start_pick_locked(void* arg, grpc_error* ignored) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); - GPR_ASSERT(calld->pick.connected_subchannel == nullptr); + GPR_ASSERT(!calld->have_request); GPR_ASSERT(calld->subchannel_call == nullptr); - if (GPR_LIKELY(chand->lb_policy != nullptr)) { - // We already have resolver results, so process the service config - // and start an LB pick. - process_service_config_and_start_lb_pick_locked(elem); - } else if (GPR_UNLIKELY(chand->resolver == nullptr)) { - pick_done_locked(elem, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); - } else { - // We do not yet have an LB policy, so wait for a resolver result. - if (GPR_UNLIKELY(!chand->started_resolving)) { - start_resolving_locked(chand); - } else { - // Normally, we want to do this check in - // process_service_config_and_start_lb_pick_locked(), so that we - // can honor the wait_for_ready setting in the service config. - // However, if the channel is in TRANSIENT_FAILURE at this point, that - // means that the resolver has returned a failure, so we're not going - // to get a service config right away. In that case, we fail the - // call now based on the wait_for_ready value passed in from the - // application. - if (fail_call_if_in_transient_failure(elem)) return; - } - // Create a new waiter, which will delete itself when done. - grpc_core::New<grpc_core::ResolverResultWaiter>(elem); - // Add the polling entity from call_data to the channel_data's - // interested_parties, so that the I/O of the resolver can be done - // under it. It will be removed in pick_done_locked(). - maybe_add_call_to_channel_interested_parties_locked(elem); + // Normally, we want to do this check until after we've processed the + // service config, so that we can honor the wait_for_ready setting in + // the service config. However, if the channel is in TRANSIENT_FAILURE + // and we don't have an LB policy at this point, that means that the + // resolver has returned a failure, so we're not going to get a service + // config right away. In that case, we fail the call now based on the + // wait_for_ready value passed in from the application. + if (chand->request_router->lb_policy() == nullptr && + fail_call_if_in_transient_failure(elem)) { + return; } + // If this is a retry, use the send_initial_metadata payload that + // we've cached; otherwise, use the pending batch. The + // send_initial_metadata batch will be the first pending batch in the + // list, as set by get_batch_index() above. + // TODO(roth): What if the LB policy needs to add something to the + // call's initial metadata, and then there's a retry? We don't want + // the new metadata to be added twice. We might need to somehow + // allocate the subchannel batch earlier so that we can give the + // subchannel's copy of the metadata batch (which is copied for each + // attempt) to the LB policy instead the one from the parent channel. + grpc_metadata_batch* initial_metadata = + calld->seen_send_initial_metadata + ? &calld->send_initial_metadata + : calld->pending_batches[0] + .batch->payload->send_initial_metadata.send_initial_metadata; + uint32_t* initial_metadata_flags = + calld->seen_send_initial_metadata + ? &calld->send_initial_metadata_flags + : &calld->pending_batches[0] + .batch->payload->send_initial_metadata + .send_initial_metadata_flags; + GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem, + grpc_schedule_on_exec_ctx); + calld->request.Init(calld->owning_call, calld->call_combiner, calld->pollent, + initial_metadata, initial_metadata_flags, + maybe_apply_service_config_to_call_locked, elem, + &calld->pick_closure); + calld->have_request = true; + chand->request_router->RouteCallLocked(calld->request.get()); } // @@ -3249,23 +2548,10 @@ const grpc_channel_filter grpc_client_channel_filter = { "client-channel", }; -static void try_to_connect_locked(void* arg, grpc_error* error_ignored) { - channel_data* chand = static_cast<channel_data*>(arg); - if (chand->lb_policy != nullptr) { - chand->lb_policy->ExitIdleLocked(); - } else { - chand->exit_idle_when_lb_policy_arrives = true; - if (!chand->started_resolving && chand->resolver != nullptr) { - start_resolving_locked(chand); - } - } - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect"); -} - void grpc_client_channel_set_channelz_node( grpc_channel_element* elem, grpc_core::channelz::ClientChannelNode* node) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - chand->channelz_channel = node; + chand->request_router->set_channelz_node(node); } void grpc_client_channel_populate_child_refs( @@ -3273,17 +2559,22 @@ void grpc_client_channel_populate_child_refs( grpc_core::channelz::ChildRefsList* child_subchannels, grpc_core::channelz::ChildRefsList* child_channels) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - if (chand->lb_policy != nullptr) { - chand->lb_policy->FillChildRefsForChannelz(child_subchannels, - child_channels); + if (chand->request_router->lb_policy() != nullptr) { + chand->request_router->lb_policy()->FillChildRefsForChannelz( + child_subchannels, child_channels); } } +static void try_to_connect_locked(void* arg, grpc_error* error_ignored) { + channel_data* chand = static_cast<channel_data*>(arg); + chand->request_router->ExitIdleLocked(); + GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect"); +} + grpc_connectivity_state grpc_client_channel_check_connectivity_state( grpc_channel_element* elem, int try_to_connect) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - grpc_connectivity_state out = - grpc_connectivity_state_check(&chand->state_tracker); + grpc_connectivity_state out = chand->request_router->GetConnectivityState(); if (out == GRPC_CHANNEL_IDLE && try_to_connect) { GRPC_CHANNEL_STACK_REF(chand->owning_stack, "try_to_connect"); GRPC_CLOSURE_SCHED( @@ -3328,19 +2619,19 @@ static void external_connectivity_watcher_list_append( } static void external_connectivity_watcher_list_remove( - channel_data* chand, external_connectivity_watcher* too_remove) { + channel_data* chand, external_connectivity_watcher* to_remove) { GPR_ASSERT( - lookup_external_connectivity_watcher(chand, too_remove->on_complete)); + lookup_external_connectivity_watcher(chand, to_remove->on_complete)); gpr_mu_lock(&chand->external_connectivity_watcher_list_mu); - if (too_remove == chand->external_connectivity_watcher_list_head) { - chand->external_connectivity_watcher_list_head = too_remove->next; + if (to_remove == chand->external_connectivity_watcher_list_head) { + chand->external_connectivity_watcher_list_head = to_remove->next; gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); return; } external_connectivity_watcher* w = chand->external_connectivity_watcher_list_head; while (w != nullptr) { - if (w->next == too_remove) { + if (w->next == to_remove) { w->next = w->next->next; gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); return; @@ -3392,15 +2683,15 @@ static void watch_connectivity_state_locked(void* arg, GRPC_CLOSURE_RUN(w->watcher_timer_init, GRPC_ERROR_NONE); GRPC_CLOSURE_INIT(&w->my_closure, on_external_watch_complete_locked, w, grpc_combiner_scheduler(w->chand->combiner)); - grpc_connectivity_state_notify_on_state_change(&w->chand->state_tracker, - w->state, &w->my_closure); + w->chand->request_router->NotifyOnConnectivityStateChange(w->state, + &w->my_closure); } else { GPR_ASSERT(w->watcher_timer_init == nullptr); found = lookup_external_connectivity_watcher(w->chand, w->on_complete); if (found) { GPR_ASSERT(found->on_complete == w->on_complete); - grpc_connectivity_state_notify_on_state_change( - &found->chand->state_tracker, nullptr, &found->my_closure); + found->chand->request_router->NotifyOnConnectivityStateChange( + nullptr, &found->my_closure); } grpc_polling_entity_del_from_pollset_set(&w->pollent, w->chand->interested_parties); diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h index 6b76fe5d5d..293d8e960c 100644 --- a/src/core/ext/filters/client_channel/lb_policy.h +++ b/src/core/ext/filters/client_channel/lb_policy.h @@ -65,10 +65,10 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> { struct PickState { /// Initial metadata associated with the picking call. grpc_metadata_batch* initial_metadata = nullptr; - /// Bitmask used for selective cancelling. See + /// Pointer to bitmask used for selective cancelling. See /// \a CancelMatchingPicksLocked() and \a GRPC_INITIAL_METADATA_* in /// grpc_types.h. - uint32_t initial_metadata_flags = 0; + uint32_t* initial_metadata_flags = nullptr; /// Storage for LB token in \a initial_metadata, or nullptr if not used. grpc_linked_mdelem lb_token_mdelem_storage; /// Closure to run when pick is complete, if not completed synchronously. @@ -88,6 +88,9 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> { LoadBalancingPolicy(const LoadBalancingPolicy&) = delete; LoadBalancingPolicy& operator=(const LoadBalancingPolicy&) = delete; + /// Returns the name of the LB policy. + virtual const char* name() const GRPC_ABSTRACT; + /// Updates the policy with a new set of \a args and a new \a lb_config from /// the resolver. Note that the LB policy gets the set of addresses from the /// GRPC_ARG_SERVER_ADDRESS_LIST channel arg. @@ -205,12 +208,6 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> { grpc_pollset_set* interested_parties_; /// Callback to force a re-resolution. grpc_closure* request_reresolution_; - - // Dummy classes needed for alignment issues. - // See https://github.com/grpc/grpc/issues/16032 for context. - // TODO(ncteisen): remove this as soon as the issue is resolved. - channelz::ChildRefsList dummy_list_foo; - channelz::ChildRefsList dummy_list_bar; }; } // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc index a9a5965ed1..ba40febd53 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc @@ -122,10 +122,14 @@ TraceFlag grpc_lb_glb_trace(false, "glb"); namespace { +constexpr char kGrpclb[] = "grpclb"; + class GrpcLb : public LoadBalancingPolicy { public: explicit GrpcLb(const Args& args); + const char* name() const override { return kGrpclb; } + void UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; @@ -361,7 +365,9 @@ void lb_token_destroy(void* token) { } } int lb_token_cmp(void* token1, void* token2) { - return GPR_ICMP(token1, token2); + // Always indicate a match, since we don't want this channel arg to + // affect the subchannel's key in the index. + return 0; } const grpc_arg_pointer_vtable lb_token_arg_vtable = { lb_token_copy, lb_token_destroy, lb_token_cmp}; @@ -422,7 +428,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { grpc_resolved_address addr; ParseServer(server, &addr); // LB token processing. - void* lb_token; + grpc_mdelem lb_token; if (server->has_load_balance_token) { const size_t lb_token_max_length = GPR_ARRAY_SIZE(server->load_balance_token); @@ -430,9 +436,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { strnlen(server->load_balance_token, lb_token_max_length); grpc_slice lb_token_mdstr = grpc_slice_from_copied_buffer( server->load_balance_token, lb_token_length); - lb_token = - (void*)grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr) - .payload; + lb_token = grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr); } else { char* uri = grpc_sockaddr_to_uri(&addr); gpr_log(GPR_INFO, @@ -440,14 +444,16 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { "be used instead", uri); gpr_free(uri); - lb_token = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload; + lb_token = GRPC_MDELEM_LB_TOKEN_EMPTY; } // Add address. grpc_arg arg = grpc_channel_arg_pointer_create( - const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), lb_token, - &lb_token_arg_vtable); + const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), + (void*)lb_token.payload, &lb_token_arg_vtable); grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); addresses.emplace_back(addr, args); + // Clean up. + GRPC_MDELEM_UNREF(lb_token); } return addresses; } @@ -525,8 +531,7 @@ void GrpcLb::BalancerCallState::Orphan() { void GrpcLb::BalancerCallState::StartQuery() { GPR_ASSERT(lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, - "[grpclb %p] Starting LB call (lb_calld: %p, lb_call: %p)", + gpr_log(GPR_INFO, "[grpclb %p] lb_calld=%p: Starting LB call %p", grpclb_policy_.get(), this, lb_call_); } // Create the ops. @@ -670,8 +675,9 @@ void GrpcLb::BalancerCallState::SendClientLoadReportLocked() { grpc_call_error call_error = grpc_call_start_batch_and_execute( lb_call_, &op, 1, &client_load_report_closure_); if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { - gpr_log(GPR_ERROR, "[grpclb %p] call_error=%d", grpclb_policy_.get(), - call_error); + gpr_log(GPR_ERROR, + "[grpclb %p] lb_calld=%p call_error=%d sending client load report", + grpclb_policy_.get(), this, call_error); GPR_ASSERT(GRPC_CALL_OK == call_error); } } @@ -732,15 +738,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( &initial_response->client_stats_report_interval)); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; " - "client load reporting interval = %" PRId64 " milliseconds", - grpclb_policy, lb_calld->client_stats_report_interval_); + "[grpclb %p] lb_calld=%p: Received initial LB response " + "message; client load reporting interval = %" PRId64 + " milliseconds", + grpclb_policy, lb_calld, + lb_calld->client_stats_report_interval_); } } else if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; client load " - "reporting NOT enabled", - grpclb_policy); + "[grpclb %p] lb_calld=%p: Received initial LB response message; " + "client load reporting NOT enabled", + grpclb_policy, lb_calld); } grpc_grpclb_initial_response_destroy(initial_response); lb_calld->seen_initial_response_ = true; @@ -750,15 +758,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( GPR_ASSERT(lb_calld->lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Serverlist with %" PRIuPTR " servers received", - grpclb_policy, serverlist->num_servers); + "[grpclb %p] lb_calld=%p: Serverlist with %" PRIuPTR + " servers received", + grpclb_policy, lb_calld, serverlist->num_servers); for (size_t i = 0; i < serverlist->num_servers; ++i) { grpc_resolved_address addr; ParseServer(serverlist->servers[i], &addr); char* ipport; grpc_sockaddr_to_string(&ipport, &addr, false); - gpr_log(GPR_INFO, "[grpclb %p] Serverlist[%" PRIuPTR "]: %s", - grpclb_policy, i, ipport); + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Serverlist[%" PRIuPTR "]: %s", + grpclb_policy, lb_calld, i, ipport); gpr_free(ipport); } } @@ -778,9 +788,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( if (grpc_grpclb_serverlist_equals(grpclb_policy->serverlist_, serverlist)) { if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Incoming server list identical to current, " - "ignoring.", - grpclb_policy); + "[grpclb %p] lb_calld=%p: Incoming server list identical to " + "current, ignoring.", + grpclb_policy, lb_calld); } grpc_grpclb_destroy_serverlist(serverlist); } else { // New serverlist. @@ -806,8 +816,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( char* response_slice_str = grpc_dump_slice(response_slice, GPR_DUMP_ASCII | GPR_DUMP_HEX); gpr_log(GPR_ERROR, - "[grpclb %p] Invalid LB response received: '%s'. Ignoring.", - grpclb_policy, response_slice_str); + "[grpclb %p] lb_calld=%p: Invalid LB response received: '%s'. " + "Ignoring.", + grpclb_policy, lb_calld, response_slice_str); gpr_free(response_slice_str); } grpc_slice_unref_internal(response_slice); @@ -838,9 +849,9 @@ void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked( char* status_details = grpc_slice_to_c_string(lb_calld->lb_call_status_details_); gpr_log(GPR_INFO, - "[grpclb %p] Status from LB server received. Status = %d, details " - "= '%s', (lb_calld: %p, lb_call: %p), error '%s'", - grpclb_policy, lb_calld->lb_call_status_, status_details, lb_calld, + "[grpclb %p] lb_calld=%p: Status from LB server received. " + "Status = %d, details = '%s', (lb_call: %p), error '%s'", + grpclb_policy, lb_calld, lb_calld->lb_call_status_, status_details, lb_calld->lb_call_, grpc_error_string(error)); gpr_free(status_details); } @@ -1129,7 +1140,7 @@ void GrpcLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pp != nullptr) { PendingPick* next = pp->next; - if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { // Note: pp is deleted in this callback. GRPC_CLOSURE_SCHED(&pp->on_complete, @@ -1592,6 +1603,10 @@ void GrpcLb::CreateRoundRobinPolicyLocked(const Args& args) { this); return; } + if (grpc_lb_glb_trace.enabled()) { + gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, + rr_policy_.get()); + } // TODO(roth): We currently track this ref manually. Once the new // ClosureRef API is done, pass the RefCountedPtr<> along with the closure. auto self = Ref(DEBUG_LOCATION, "on_rr_reresolution_requested"); @@ -1685,10 +1700,6 @@ void GrpcLb::CreateOrUpdateRoundRobinPolicyLocked() { lb_policy_args.client_channel_factory = client_channel_factory(); lb_policy_args.args = args; CreateRoundRobinPolicyLocked(lb_policy_args); - if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, - rr_policy_.get()); - } } grpc_channel_args_destroy(args); } @@ -1812,7 +1823,7 @@ class GrpcLbFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<GrpcLb>(args)); } - const char* name() const override { return "grpclb"; } + const char* name() const override { return kGrpclb; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc index 6e8fbdcab7..657ff69312 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc @@ -88,22 +88,18 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc index 74c17612a2..d6ff74ec7f 100644 --- a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc +++ b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc @@ -43,10 +43,14 @@ namespace { // pick_first LB policy // +constexpr char kPickFirst[] = "pick_first"; + class PickFirst : public LoadBalancingPolicy { public: explicit PickFirst(const Args& args); + const char* name() const override { return kPickFirst; } + void UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; @@ -234,7 +238,7 @@ void PickFirst::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pick != nullptr) { PickState* next = pick->next; - if ((pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { GRPC_CLOSURE_SCHED(pick->on_complete, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( @@ -622,7 +626,7 @@ class PickFirstFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<PickFirst>(args)); } - const char* name() const override { return "pick_first"; } + const char* name() const override { return kPickFirst; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc index 63089afbd7..3bcb33ef11 100644 --- a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc +++ b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc @@ -53,10 +53,14 @@ namespace { // round_robin LB policy // +constexpr char kRoundRobin[] = "round_robin"; + class RoundRobin : public LoadBalancingPolicy { public: explicit RoundRobin(const Args& args); + const char* name() const override { return kRoundRobin; } + void UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; @@ -291,7 +295,7 @@ void RoundRobin::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pick != nullptr) { PickState* next = pick->next; - if ((pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { pick->connected_subchannel.reset(); GRPC_CLOSURE_SCHED(pick->on_complete, @@ -700,7 +704,7 @@ class RoundRobinFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<RoundRobin>(args)); } - const char* name() const override { return "round_robin"; } + const char* name() const override { return kRoundRobin; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc index 3c25de2386..8787f5bcc2 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc @@ -115,10 +115,14 @@ TraceFlag grpc_lb_xds_trace(false, "xds"); namespace { +constexpr char kXds[] = "xds_experimental"; + class XdsLb : public LoadBalancingPolicy { public: explicit XdsLb(const Args& args); + const char* name() const override { return kXds; } + void UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; @@ -1053,7 +1057,7 @@ void XdsLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pp != nullptr) { PendingPick* next = pp->next; - if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { // Note: pp is deleted in this callback. GRPC_CLOSURE_SCHED(&pp->on_complete, @@ -1651,7 +1655,7 @@ class XdsFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<XdsLb>(args)); } - const char* name() const override { return "xds_experimental"; } + const char* name() const override { return kXds; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc index 9a11f8e39f..55c646e6ee 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc @@ -87,22 +87,18 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/filters/client_channel/request_routing.cc b/src/core/ext/filters/client_channel/request_routing.cc new file mode 100644 index 0000000000..f9a7e164e7 --- /dev/null +++ b/src/core/ext/filters/client_channel/request_routing.cc @@ -0,0 +1,936 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/request_routing.h" + +#include <inttypes.h> +#include <limits.h> +#include <stdbool.h> +#include <stdio.h> +#include <string.h> + +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> +#include <grpc/support/sync.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/retry_throttle.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/ext/filters/deadline/deadline_filter.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/service_config.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" + +namespace grpc_core { + +// +// RequestRouter::Request::ResolverResultWaiter +// + +// Handles waiting for a resolver result. +// Used only for the first call on an idle channel. +class RequestRouter::Request::ResolverResultWaiter { + public: + explicit ResolverResultWaiter(Request* request) + : request_router_(request->request_router_), + request_(request), + tracer_enabled_(request_router_->tracer_->enabled()) { + if (tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: deferring pick pending resolver " + "result", + request_router_, request); + } + // Add closure to be run when a resolver result is available. + GRPC_CLOSURE_INIT(&done_closure_, &DoneLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + AddToWaitingList(); + // Set cancellation closure, so that we abort if the call is cancelled. + GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + grpc_call_combiner_set_notify_on_cancel(request->call_combiner_, + &cancel_closure_); + } + + private: + // Adds done_closure_ to + // request_router_->waiting_for_resolver_result_closures_. + void AddToWaitingList() { + grpc_closure_list_append( + &request_router_->waiting_for_resolver_result_closures_, &done_closure_, + GRPC_ERROR_NONE); + } + + // Invoked when a resolver result is available. + static void DoneLocked(void* arg, grpc_error* error) { + ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); + RequestRouter* request_router = self->request_router_; + // If CancelLocked() has already run, delete ourselves without doing + // anything. Note that the call stack may have already been destroyed, + // so it's not safe to access anything in state_. + if (GPR_UNLIKELY(self->finished_)) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p: call cancelled before resolver result", + request_router); + } + Delete(self); + return; + } + // Otherwise, process the resolver result. + Request* request = self->request_; + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver failed to return data", + request_router, request); + } + GRPC_CLOSURE_RUN(request->on_route_done_, GRPC_ERROR_REF(error)); + } else if (GPR_UNLIKELY(request_router->resolver_ == nullptr)) { + // Shutting down. + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, "request_router=%p request=%p: resolver disconnected", + request_router, request); + } + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); + } else if (GPR_UNLIKELY(request_router->lb_policy_ == nullptr)) { + // Transient resolver failure. + // If call has wait_for_ready=true, try again; otherwise, fail. + if (*request->pick_.initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned but no LB " + "policy; wait_for_ready=true; trying again", + request_router, request); + } + // Re-add ourselves to the waiting list. + self->AddToWaitingList(); + // Return early so that we don't set finished_ to true below. + return; + } else { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned but no LB " + "policy; wait_for_ready=false; failing", + request_router, request); + } + GRPC_CLOSURE_RUN( + request->on_route_done_, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } else { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned, doing LB " + "pick", + request_router, request); + } + request->ProcessServiceConfigAndStartLbPickLocked(); + } + self->finished_ = true; + } + + // Invoked when the call is cancelled. + // Note: This runs under the client_channel combiner, but will NOT be + // holding the call combiner. + static void CancelLocked(void* arg, grpc_error* error) { + ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); + RequestRouter* request_router = self->request_router_; + // If DoneLocked() has already run, delete ourselves without doing anything. + if (self->finished_) { + Delete(self); + return; + } + Request* request = self->request_; + // If we are being cancelled, immediately invoke on_route_done_ + // to propagate the error back to the caller. + if (error != GRPC_ERROR_NONE) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: cancelling call waiting for " + "name resolution", + request_router, request); + } + // Note: Although we are not in the call combiner here, we are + // basically stealing the call combiner from the pending pick, so + // it's safe to run on_route_done_ here -- we are essentially + // calling it here instead of calling it in DoneLocked(). + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Pick cancelled", &error, 1)); + } + self->finished_ = true; + } + + RequestRouter* request_router_; + Request* request_; + const bool tracer_enabled_; + grpc_closure done_closure_; + grpc_closure cancel_closure_; + bool finished_ = false; +}; + +// +// RequestRouter::Request::AsyncPickCanceller +// + +// Handles the call combiner cancellation callback for an async LB pick. +class RequestRouter::Request::AsyncPickCanceller { + public: + explicit AsyncPickCanceller(Request* request) + : request_router_(request->request_router_), + request_(request), + tracer_enabled_(request_router_->tracer_->enabled()) { + GRPC_CALL_STACK_REF(request->owning_call_, "pick_callback_cancel"); + // Set cancellation closure, so that we abort if the call is cancelled. + GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + grpc_call_combiner_set_notify_on_cancel(request->call_combiner_, + &cancel_closure_); + } + + void MarkFinishedLocked() { + finished_ = true; + GRPC_CALL_STACK_UNREF(request_->owning_call_, "pick_callback_cancel"); + } + + private: + // Invoked when the call is cancelled. + // Note: This runs under the client_channel combiner, but will NOT be + // holding the call combiner. + static void CancelLocked(void* arg, grpc_error* error) { + AsyncPickCanceller* self = static_cast<AsyncPickCanceller*>(arg); + Request* request = self->request_; + RequestRouter* request_router = self->request_router_; + if (!self->finished_) { + // Note: request_router->lb_policy_ may have changed since we started our + // pick, in which case we will be cancelling the pick on a policy other + // than the one we started it on. However, this will just be a no-op. + if (error != GRPC_ERROR_NONE && request_router->lb_policy_ != nullptr) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: cancelling pick from LB " + "policy %p", + request_router, request, request_router->lb_policy_.get()); + } + request_router->lb_policy_->CancelPickLocked(&request->pick_, + GRPC_ERROR_REF(error)); + } + request->pick_canceller_ = nullptr; + GRPC_CALL_STACK_UNREF(request->owning_call_, "pick_callback_cancel"); + } + Delete(self); + } + + RequestRouter* request_router_; + Request* request_; + const bool tracer_enabled_; + grpc_closure cancel_closure_; + bool finished_ = false; +}; + +// +// RequestRouter::Request +// + +RequestRouter::Request::Request(grpc_call_stack* owning_call, + grpc_call_combiner* call_combiner, + grpc_polling_entity* pollent, + grpc_metadata_batch* send_initial_metadata, + uint32_t* send_initial_metadata_flags, + ApplyServiceConfigCallback apply_service_config, + void* apply_service_config_user_data, + grpc_closure* on_route_done) + : owning_call_(owning_call), + call_combiner_(call_combiner), + pollent_(pollent), + apply_service_config_(apply_service_config), + apply_service_config_user_data_(apply_service_config_user_data), + on_route_done_(on_route_done) { + pick_.initial_metadata = send_initial_metadata; + pick_.initial_metadata_flags = send_initial_metadata_flags; +} + +RequestRouter::Request::~Request() { + if (pick_.connected_subchannel != nullptr) { + pick_.connected_subchannel.reset(); + } + for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { + if (pick_.subchannel_call_context[i].destroy != nullptr) { + pick_.subchannel_call_context[i].destroy( + pick_.subchannel_call_context[i].value); + } + } +} + +// Invoked once resolver results are available. +void RequestRouter::Request::ProcessServiceConfigAndStartLbPickLocked() { + // Get service config data if needed. + if (!apply_service_config_(apply_service_config_user_data_)) return; + // Start LB pick. + StartLbPickLocked(); +} + +void RequestRouter::Request::MaybeAddCallToInterestedPartiesLocked() { + if (!pollent_added_to_interested_parties_) { + pollent_added_to_interested_parties_ = true; + grpc_polling_entity_add_to_pollset_set( + pollent_, request_router_->interested_parties_); + } +} + +void RequestRouter::Request::MaybeRemoveCallFromInterestedPartiesLocked() { + if (pollent_added_to_interested_parties_) { + pollent_added_to_interested_parties_ = false; + grpc_polling_entity_del_from_pollset_set( + pollent_, request_router_->interested_parties_); + } +} + +// Starts a pick on the LB policy. +void RequestRouter::Request::StartLbPickLocked() { + if (request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: starting pick on lb_policy=%p", + request_router_, this, request_router_->lb_policy_.get()); + } + GRPC_CLOSURE_INIT(&on_pick_done_, &LbPickDoneLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + pick_.on_complete = &on_pick_done_; + GRPC_CALL_STACK_REF(owning_call_, "pick_callback"); + grpc_error* error = GRPC_ERROR_NONE; + const bool pick_done = + request_router_->lb_policy_->PickLocked(&pick_, &error); + if (pick_done) { + // Pick completed synchronously. + if (request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: pick completed synchronously", + request_router_, this); + } + GRPC_CLOSURE_RUN(on_route_done_, error); + GRPC_CALL_STACK_UNREF(owning_call_, "pick_callback"); + } else { + // Pick will be returned asynchronously. + // Add the request's polling entity to the request_router's + // interested_parties, so that the I/O of the LB policy can be done + // under it. It will be removed in LbPickDoneLocked(). + MaybeAddCallToInterestedPartiesLocked(); + // Request notification on call cancellation. + // We allocate a separate object to track cancellation, since the + // cancellation closure might still be pending when we need to reuse + // the memory in which this Request object is stored for a subsequent + // retry attempt. + pick_canceller_ = New<AsyncPickCanceller>(this); + } +} + +// Callback invoked by LoadBalancingPolicy::PickLocked() for async picks. +// Unrefs the LB policy and invokes on_route_done_. +void RequestRouter::Request::LbPickDoneLocked(void* arg, grpc_error* error) { + Request* self = static_cast<Request*>(arg); + RequestRouter* request_router = self->request_router_; + if (request_router->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: pick completed asynchronously", + request_router, self); + } + self->MaybeRemoveCallFromInterestedPartiesLocked(); + if (self->pick_canceller_ != nullptr) { + self->pick_canceller_->MarkFinishedLocked(); + } + GRPC_CLOSURE_RUN(self->on_route_done_, GRPC_ERROR_REF(error)); + GRPC_CALL_STACK_UNREF(self->owning_call_, "pick_callback"); +} + +// +// RequestRouter::LbConnectivityWatcher +// + +class RequestRouter::LbConnectivityWatcher { + public: + LbConnectivityWatcher(RequestRouter* request_router, + grpc_connectivity_state state, + LoadBalancingPolicy* lb_policy, + grpc_channel_stack* owning_stack, + grpc_combiner* combiner) + : request_router_(request_router), + state_(state), + lb_policy_(lb_policy), + owning_stack_(owning_stack) { + GRPC_CHANNEL_STACK_REF(owning_stack_, "LbConnectivityWatcher"); + GRPC_CLOSURE_INIT(&on_changed_, &OnLbPolicyStateChangedLocked, this, + grpc_combiner_scheduler(combiner)); + lb_policy_->NotifyOnStateChangeLocked(&state_, &on_changed_); + } + + ~LbConnectivityWatcher() { + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "LbConnectivityWatcher"); + } + + private: + static void OnLbPolicyStateChangedLocked(void* arg, grpc_error* error) { + LbConnectivityWatcher* self = static_cast<LbConnectivityWatcher*>(arg); + // If the notification is not for the current policy, we're stale, + // so delete ourselves. + if (self->lb_policy_ != self->request_router_->lb_policy_.get()) { + Delete(self); + return; + } + // Otherwise, process notification. + if (self->request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: lb_policy=%p state changed to %s", + self->request_router_, self->lb_policy_, + grpc_connectivity_state_name(self->state_)); + } + self->request_router_->SetConnectivityStateLocked( + self->state_, GRPC_ERROR_REF(error), "lb_changed"); + // If shutting down, terminate watch. + if (self->state_ == GRPC_CHANNEL_SHUTDOWN) { + Delete(self); + return; + } + // Renew watch. + self->lb_policy_->NotifyOnStateChangeLocked(&self->state_, + &self->on_changed_); + } + + RequestRouter* request_router_; + grpc_connectivity_state state_; + // LB policy address. No ref held, so not safe to dereference unless + // it happens to match request_router->lb_policy_. + LoadBalancingPolicy* lb_policy_; + grpc_channel_stack* owning_stack_; + grpc_closure on_changed_; +}; + +// +// RequestRounter::ReresolutionRequestHandler +// + +class RequestRouter::ReresolutionRequestHandler { + public: + ReresolutionRequestHandler(RequestRouter* request_router, + LoadBalancingPolicy* lb_policy, + grpc_channel_stack* owning_stack, + grpc_combiner* combiner) + : request_router_(request_router), + lb_policy_(lb_policy), + owning_stack_(owning_stack) { + GRPC_CHANNEL_STACK_REF(owning_stack_, "ReresolutionRequestHandler"); + GRPC_CLOSURE_INIT(&closure_, &OnRequestReresolutionLocked, this, + grpc_combiner_scheduler(combiner)); + lb_policy_->SetReresolutionClosureLocked(&closure_); + } + + private: + static void OnRequestReresolutionLocked(void* arg, grpc_error* error) { + ReresolutionRequestHandler* self = + static_cast<ReresolutionRequestHandler*>(arg); + RequestRouter* request_router = self->request_router_; + // If this invocation is for a stale LB policy, treat it as an LB shutdown + // signal. + if (self->lb_policy_ != request_router->lb_policy_.get() || + error != GRPC_ERROR_NONE || request_router->resolver_ == nullptr) { + GRPC_CHANNEL_STACK_UNREF(request_router->owning_stack_, + "ReresolutionRequestHandler"); + Delete(self); + return; + } + if (request_router->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: started name re-resolving", + request_router); + } + request_router->resolver_->RequestReresolutionLocked(); + // Give back the closure to the LB policy. + self->lb_policy_->SetReresolutionClosureLocked(&self->closure_); + } + + RequestRouter* request_router_; + // LB policy address. No ref held, so not safe to dereference unless + // it happens to match request_router->lb_policy_. + LoadBalancingPolicy* lb_policy_; + grpc_channel_stack* owning_stack_; + grpc_closure closure_; +}; + +// +// RequestRouter +// + +RequestRouter::RequestRouter( + grpc_channel_stack* owning_stack, grpc_combiner* combiner, + grpc_client_channel_factory* client_channel_factory, + grpc_pollset_set* interested_parties, TraceFlag* tracer, + ProcessResolverResultCallback process_resolver_result, + void* process_resolver_result_user_data, const char* target_uri, + const grpc_channel_args* args, grpc_error** error) + : owning_stack_(owning_stack), + combiner_(combiner), + client_channel_factory_(client_channel_factory), + interested_parties_(interested_parties), + tracer_(tracer), + process_resolver_result_(process_resolver_result), + process_resolver_result_user_data_(process_resolver_result_user_data) { + GRPC_CLOSURE_INIT(&on_resolver_result_changed_, + &RequestRouter::OnResolverResultChangedLocked, this, + grpc_combiner_scheduler(combiner)); + grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE, + "request_router"); + grpc_channel_args* new_args = nullptr; + if (process_resolver_result == nullptr) { + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast<char*>(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION), 0); + new_args = grpc_channel_args_copy_and_add(args, &arg, 1); + } + resolver_ = ResolverRegistry::CreateResolver( + target_uri, (new_args == nullptr ? args : new_args), interested_parties_, + combiner_); + grpc_channel_args_destroy(new_args); + if (resolver_ == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed"); + } +} + +RequestRouter::~RequestRouter() { + if (resolver_ != nullptr) { + // The only way we can get here is if we never started resolving, + // because we take a ref to the channel stack when we start + // resolving and do not release it until the resolver callback is + // invoked after the resolver shuts down. + resolver_.reset(); + } + if (lb_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + if (client_channel_factory_ != nullptr) { + grpc_client_channel_factory_unref(client_channel_factory_); + } + grpc_connectivity_state_destroy(&state_tracker_); +} + +namespace { + +const char* GetChannelConnectivityStateChangeString( + grpc_connectivity_state state) { + switch (state) { + case GRPC_CHANNEL_IDLE: + return "Channel state change to IDLE"; + case GRPC_CHANNEL_CONNECTING: + return "Channel state change to CONNECTING"; + case GRPC_CHANNEL_READY: + return "Channel state change to READY"; + case GRPC_CHANNEL_TRANSIENT_FAILURE: + return "Channel state change to TRANSIENT_FAILURE"; + case GRPC_CHANNEL_SHUTDOWN: + return "Channel state change to SHUTDOWN"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +} // namespace + +void RequestRouter::SetConnectivityStateLocked(grpc_connectivity_state state, + grpc_error* error, + const char* reason) { + if (lb_policy_ != nullptr) { + if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + // Cancel picks with wait_for_ready=false. + lb_policy_->CancelMatchingPicksLocked( + /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY, + /* check= */ 0, GRPC_ERROR_REF(error)); + } else if (state == GRPC_CHANNEL_SHUTDOWN) { + // Cancel all picks. + lb_policy_->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0, + GRPC_ERROR_REF(error)); + } + } + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: setting connectivity state to %s", + this, grpc_connectivity_state_name(state)); + } + if (channelz_node_ != nullptr) { + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string( + GetChannelConnectivityStateChangeString(state))); + } + grpc_connectivity_state_set(&state_tracker_, state, error, reason); +} + +void RequestRouter::StartResolvingLocked() { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: starting name resolution", this); + } + GPR_ASSERT(!started_resolving_); + started_resolving_ = true; + GRPC_CHANNEL_STACK_REF(owning_stack_, "resolver"); + resolver_->NextLocked(&resolver_result_, &on_resolver_result_changed_); +} + +// Invoked from the resolver NextLocked() callback when the resolver +// is shutting down. +void RequestRouter::OnResolverShutdownLocked(grpc_error* error) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down", this); + } + if (lb_policy_ != nullptr) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this, + lb_policy_.get()); + } + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + if (resolver_ != nullptr) { + // This should never happen; it can only be triggered by a resolver + // implementation spotaneously deciding to report shutdown without + // being orphaned. This code is included just to be defensive. + if (tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p: spontaneous shutdown from resolver %p", this, + resolver_.get()); + } + resolver_.reset(); + SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Resolver spontaneous shutdown", &error, 1), + "resolver_spontaneous_shutdown"); + } + grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Channel disconnected", &error, 1)); + GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_); + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "resolver"); + grpc_channel_args_destroy(resolver_result_); + resolver_result_ = nullptr; + GRPC_ERROR_UNREF(error); +} + +// Creates a new LB policy, replacing any previous one. +// If the new policy is created successfully, sets *connectivity_state and +// *connectivity_error to its initial connectivity state; otherwise, +// leaves them unchanged. +void RequestRouter::CreateNewLbPolicyLocked( + const char* lb_policy_name, grpc_json* lb_config, + grpc_connectivity_state* connectivity_state, + grpc_error** connectivity_error, TraceStringVector* trace_strings) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.combiner = combiner_; + lb_policy_args.client_channel_factory = client_channel_factory_; + lb_policy_args.args = resolver_result_; + lb_policy_args.lb_config = lb_config; + OrphanablePtr<LoadBalancingPolicy> new_lb_policy = + LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(lb_policy_name, + lb_policy_args); + if (GPR_UNLIKELY(new_lb_policy == nullptr)) { + gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name); + if (channelz_node_ != nullptr) { + char* str; + gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name); + trace_strings->push_back(str); + } + } else { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: created new LB policy \"%s\" (%p)", + this, lb_policy_name, new_lb_policy.get()); + } + if (channelz_node_ != nullptr) { + char* str; + gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name); + trace_strings->push_back(str); + } + // Swap out the LB policy and update the fds in interested_parties_. + if (lb_policy_ != nullptr) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this, + lb_policy_.get()); + } + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_->HandOffPendingPicksLocked(new_lb_policy.get()); + } + lb_policy_ = std::move(new_lb_policy); + grpc_pollset_set_add_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + // Create re-resolution request handler for the new LB policy. It + // will delete itself when no longer needed. + New<ReresolutionRequestHandler>(this, lb_policy_.get(), owning_stack_, + combiner_); + // Get the new LB policy's initial connectivity state and start a + // connectivity watch. + GRPC_ERROR_UNREF(*connectivity_error); + *connectivity_state = + lb_policy_->CheckConnectivityLocked(connectivity_error); + if (exit_idle_when_lb_policy_arrives_) { + lb_policy_->ExitIdleLocked(); + exit_idle_when_lb_policy_arrives_ = false; + } + // Create new watcher. It will delete itself when done. + New<LbConnectivityWatcher>(this, *connectivity_state, lb_policy_.get(), + owning_stack_, combiner_); + } +} + +void RequestRouter::MaybeAddTraceMessagesForAddressChangesLocked( + TraceStringVector* trace_strings) { + const ServerAddressList* addresses = + FindServerAddressListChannelArg(resolver_result_); + const bool resolution_contains_addresses = + addresses != nullptr && addresses->size() > 0; + if (!resolution_contains_addresses && + previous_resolution_contained_addresses_) { + trace_strings->push_back(gpr_strdup("Address list became empty")); + } else if (resolution_contains_addresses && + !previous_resolution_contained_addresses_) { + trace_strings->push_back(gpr_strdup("Address list became non-empty")); + } + previous_resolution_contained_addresses_ = resolution_contains_addresses; +} + +void RequestRouter::ConcatenateAndAddChannelTraceLocked( + TraceStringVector* trace_strings) const { + if (!trace_strings->empty()) { + gpr_strvec v; + gpr_strvec_init(&v); + gpr_strvec_add(&v, gpr_strdup("Resolution event: ")); + bool is_first = 1; + for (size_t i = 0; i < trace_strings->size(); ++i) { + if (!is_first) gpr_strvec_add(&v, gpr_strdup(", ")); + is_first = false; + gpr_strvec_add(&v, (*trace_strings)[i]); + } + char* flat; + size_t flat_len = 0; + flat = gpr_strvec_flatten(&v, &flat_len); + channelz_node_->AddTraceEvent( + grpc_core::channelz::ChannelTrace::Severity::Info, + grpc_slice_new(flat, flat_len, gpr_free)); + gpr_strvec_destroy(&v); + } +} + +// Callback invoked when a resolver result is available. +void RequestRouter::OnResolverResultChangedLocked(void* arg, + grpc_error* error) { + RequestRouter* self = static_cast<RequestRouter*>(arg); + if (self->tracer_->enabled()) { + const char* disposition = + self->resolver_result_ != nullptr + ? "" + : (error == GRPC_ERROR_NONE ? " (transient error)" + : " (resolver shutdown)"); + gpr_log(GPR_INFO, + "request_router=%p: got resolver result: resolver_result=%p " + "error=%s%s", + self, self->resolver_result_, grpc_error_string(error), + disposition); + } + // Handle shutdown. + if (error != GRPC_ERROR_NONE || self->resolver_ == nullptr) { + self->OnResolverShutdownLocked(GRPC_ERROR_REF(error)); + return; + } + // Data used to set the channel's connectivity state. + bool set_connectivity_state = true; + // We only want to trace the address resolution in the follow cases: + // (a) Address resolution resulted in service config change. + // (b) Address resolution that causes number of backends to go from + // zero to non-zero. + // (c) Address resolution that causes number of backends to go from + // non-zero to zero. + // (d) Address resolution that causes a new LB policy to be created. + // + // we track a list of strings to eventually be concatenated and traced. + TraceStringVector trace_strings; + grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; + grpc_error* connectivity_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy"); + // resolver_result_ will be null in the case of a transient + // resolution error. In that case, we don't have any new result to + // process, which means that we keep using the previous result (if any). + if (self->resolver_result_ == nullptr) { + if (self->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: resolver transient failure", self); + } + // Don't override connectivity state if we already have an LB policy. + if (self->lb_policy_ != nullptr) set_connectivity_state = false; + } else { + // Parse the resolver result. + const char* lb_policy_name = nullptr; + grpc_json* lb_policy_config = nullptr; + const bool service_config_changed = self->process_resolver_result_( + self->process_resolver_result_user_data_, *self->resolver_result_, + &lb_policy_name, &lb_policy_config); + GPR_ASSERT(lb_policy_name != nullptr); + // Check to see if we're already using the right LB policy. + const bool lb_policy_name_changed = + self->lb_policy_ == nullptr || + strcmp(self->lb_policy_->name(), lb_policy_name) != 0; + if (self->lb_policy_ != nullptr && !lb_policy_name_changed) { + // Continue using the same LB policy. Update with new addresses. + if (self->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p: updating existing LB policy \"%s\" (%p)", + self, lb_policy_name, self->lb_policy_.get()); + } + self->lb_policy_->UpdateLocked(*self->resolver_result_, lb_policy_config); + // No need to set the channel's connectivity state; the existing + // watch on the LB policy will take care of that. + set_connectivity_state = false; + } else { + // Instantiate new LB policy. + self->CreateNewLbPolicyLocked(lb_policy_name, lb_policy_config, + &connectivity_state, &connectivity_error, + &trace_strings); + } + // Add channel trace event. + if (self->channelz_node_ != nullptr) { + if (service_config_changed) { + // TODO(ncteisen): might be worth somehow including a snippet of the + // config in the trace, at the risk of bloating the trace logs. + trace_strings.push_back(gpr_strdup("Service config changed")); + } + self->MaybeAddTraceMessagesForAddressChangesLocked(&trace_strings); + self->ConcatenateAndAddChannelTraceLocked(&trace_strings); + } + // Clean up. + grpc_channel_args_destroy(self->resolver_result_); + self->resolver_result_ = nullptr; + } + // Set the channel's connectivity state if needed. + if (set_connectivity_state) { + self->SetConnectivityStateLocked(connectivity_state, connectivity_error, + "resolver_result"); + } else { + GRPC_ERROR_UNREF(connectivity_error); + } + // Invoke closures that were waiting for results and renew the watch. + GRPC_CLOSURE_LIST_SCHED(&self->waiting_for_resolver_result_closures_); + self->resolver_->NextLocked(&self->resolver_result_, + &self->on_resolver_result_changed_); +} + +void RequestRouter::RouteCallLocked(Request* request) { + GPR_ASSERT(request->pick_.connected_subchannel == nullptr); + request->request_router_ = this; + if (lb_policy_ != nullptr) { + // We already have resolver results, so process the service config + // and start an LB pick. + request->ProcessServiceConfigAndStartLbPickLocked(); + } else if (resolver_ == nullptr) { + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); + } else { + // We do not yet have an LB policy, so wait for a resolver result. + if (!started_resolving_) { + StartResolvingLocked(); + } + // Create a new waiter, which will delete itself when done. + New<Request::ResolverResultWaiter>(request); + // Add the request's polling entity to the request_router's + // interested_parties, so that the I/O of the resolver can be done + // under it. It will be removed in LbPickDoneLocked(). + request->MaybeAddCallToInterestedPartiesLocked(); + } +} + +void RequestRouter::ShutdownLocked(grpc_error* error) { + if (resolver_ != nullptr) { + SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN, GRPC_ERROR_REF(error), + "disconnect"); + resolver_.reset(); + if (!started_resolving_) { + grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_, + GRPC_ERROR_REF(error)); + GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_); + } + if (lb_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + } + GRPC_ERROR_UNREF(error); +} + +grpc_connectivity_state RequestRouter::GetConnectivityState() { + return grpc_connectivity_state_check(&state_tracker_); +} + +void RequestRouter::NotifyOnConnectivityStateChange( + grpc_connectivity_state* state, grpc_closure* closure) { + grpc_connectivity_state_notify_on_state_change(&state_tracker_, state, + closure); +} + +void RequestRouter::ExitIdleLocked() { + if (lb_policy_ != nullptr) { + lb_policy_->ExitIdleLocked(); + } else { + exit_idle_when_lb_policy_arrives_ = true; + if (!started_resolving_ && resolver_ != nullptr) { + StartResolvingLocked(); + } + } +} + +void RequestRouter::ResetConnectionBackoffLocked() { + if (resolver_ != nullptr) { + resolver_->ResetBackoffLocked(); + resolver_->RequestReresolutionLocked(); + } + if (lb_policy_ != nullptr) { + lb_policy_->ResetBackoffLocked(); + } +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/request_routing.h b/src/core/ext/filters/client_channel/request_routing.h new file mode 100644 index 0000000000..0c671229c8 --- /dev/null +++ b/src/core/ext/filters/client_channel/request_routing.h @@ -0,0 +1,177 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H +#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/client_channel_channelz.h" +#include "src/core/ext/filters/client_channel/client_channel_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/iomgr/call_combiner.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/metadata_batch.h" + +namespace grpc_core { + +class RequestRouter { + public: + class Request { + public: + // Synchronous callback that applies the service config to a call. + // Returns false if the call should be failed. + typedef bool (*ApplyServiceConfigCallback)(void* user_data); + + Request(grpc_call_stack* owning_call, grpc_call_combiner* call_combiner, + grpc_polling_entity* pollent, + grpc_metadata_batch* send_initial_metadata, + uint32_t* send_initial_metadata_flags, + ApplyServiceConfigCallback apply_service_config, + void* apply_service_config_user_data, grpc_closure* on_route_done); + + ~Request(); + + // TODO(roth): It seems a bit ugly to expose this member in a + // non-const way. Find a better API to avoid this. + LoadBalancingPolicy::PickState* pick() { return &pick_; } + + private: + friend class RequestRouter; + + class ResolverResultWaiter; + class AsyncPickCanceller; + + void ProcessServiceConfigAndStartLbPickLocked(); + void StartLbPickLocked(); + static void LbPickDoneLocked(void* arg, grpc_error* error); + + void MaybeAddCallToInterestedPartiesLocked(); + void MaybeRemoveCallFromInterestedPartiesLocked(); + + // Populated by caller. + grpc_call_stack* owning_call_; + grpc_call_combiner* call_combiner_; + grpc_polling_entity* pollent_; + ApplyServiceConfigCallback apply_service_config_; + void* apply_service_config_user_data_; + grpc_closure* on_route_done_; + LoadBalancingPolicy::PickState pick_; + + // Internal state. + RequestRouter* request_router_ = nullptr; + bool pollent_added_to_interested_parties_ = false; + grpc_closure on_pick_done_; + AsyncPickCanceller* pick_canceller_ = nullptr; + }; + + // Synchronous callback that takes the service config JSON string and + // LB policy name. + // Returns true if the service config has changed since the last result. + typedef bool (*ProcessResolverResultCallback)(void* user_data, + const grpc_channel_args& args, + const char** lb_policy_name, + grpc_json** lb_policy_config); + + RequestRouter(grpc_channel_stack* owning_stack, grpc_combiner* combiner, + grpc_client_channel_factory* client_channel_factory, + grpc_pollset_set* interested_parties, TraceFlag* tracer, + ProcessResolverResultCallback process_resolver_result, + void* process_resolver_result_user_data, const char* target_uri, + const grpc_channel_args* args, grpc_error** error); + + ~RequestRouter(); + + void set_channelz_node(channelz::ClientChannelNode* channelz_node) { + channelz_node_ = channelz_node; + } + + void RouteCallLocked(Request* request); + + // TODO(roth): Add methods to cancel picks. + + void ShutdownLocked(grpc_error* error); + + void ExitIdleLocked(); + void ResetConnectionBackoffLocked(); + + grpc_connectivity_state GetConnectivityState(); + void NotifyOnConnectivityStateChange(grpc_connectivity_state* state, + grpc_closure* closure); + + LoadBalancingPolicy* lb_policy() const { return lb_policy_.get(); } + + private: + using TraceStringVector = grpc_core::InlinedVector<char*, 3>; + + class ReresolutionRequestHandler; + class LbConnectivityWatcher; + + void StartResolvingLocked(); + void OnResolverShutdownLocked(grpc_error* error); + void CreateNewLbPolicyLocked(const char* lb_policy_name, grpc_json* lb_config, + grpc_connectivity_state* connectivity_state, + grpc_error** connectivity_error, + TraceStringVector* trace_strings); + void MaybeAddTraceMessagesForAddressChangesLocked( + TraceStringVector* trace_strings); + void ConcatenateAndAddChannelTraceLocked( + TraceStringVector* trace_strings) const; + static void OnResolverResultChangedLocked(void* arg, grpc_error* error); + + void SetConnectivityStateLocked(grpc_connectivity_state state, + grpc_error* error, const char* reason); + + // Passed in from caller at construction time. + grpc_channel_stack* owning_stack_; + grpc_combiner* combiner_; + grpc_client_channel_factory* client_channel_factory_; + grpc_pollset_set* interested_parties_; + TraceFlag* tracer_; + + channelz::ClientChannelNode* channelz_node_ = nullptr; + + // Resolver and associated state. + OrphanablePtr<Resolver> resolver_; + ProcessResolverResultCallback process_resolver_result_; + void* process_resolver_result_user_data_; + bool started_resolving_ = false; + grpc_channel_args* resolver_result_ = nullptr; + bool previous_resolution_contained_addresses_ = false; + grpc_closure_list waiting_for_resolver_result_closures_; + grpc_closure on_resolver_result_changed_; + + // LB policy and associated state. + OrphanablePtr<LoadBalancingPolicy> lb_policy_; + bool exit_idle_when_lb_policy_arrives_ = false; + + grpc_connectivity_state_tracker state_tracker_; +}; + +} // namespace grpc_core + +#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc index c8425ae336..abacd0c960 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc @@ -170,7 +170,7 @@ AresDnsResolver::AresDnsResolver(const ResolverArgs& args) } AresDnsResolver::~AresDnsResolver() { - gpr_log(GPR_DEBUG, "destroying AresDnsResolver"); + GRPC_CARES_TRACE_LOG("resolver:%p destroying AresDnsResolver", this); if (resolved_result_ != nullptr) { grpc_channel_args_destroy(resolved_result_); } @@ -182,7 +182,8 @@ AresDnsResolver::~AresDnsResolver() { void AresDnsResolver::NextLocked(grpc_channel_args** target_result, grpc_closure* on_complete) { - gpr_log(GPR_DEBUG, "AresDnsResolver::NextLocked() is called."); + GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::NextLocked() is called.", + this); GPR_ASSERT(next_completion_ == nullptr); next_completion_ = on_complete; target_result_ = target_result; @@ -225,12 +226,14 @@ void AresDnsResolver::ShutdownLocked() { void AresDnsResolver::OnNextResolutionLocked(void* arg, grpc_error* error) { AresDnsResolver* r = static_cast<AresDnsResolver*>(arg); GRPC_CARES_TRACE_LOG( - "%p re-resolution timer fired. error: %s. shutdown_initiated_: %d", r, - grpc_error_string(error), r->shutdown_initiated_); + "resolver:%p re-resolution timer fired. error: %s. shutdown_initiated_: " + "%d", + r, grpc_error_string(error), r->shutdown_initiated_); r->have_next_resolution_timer_ = false; if (error == GRPC_ERROR_NONE && !r->shutdown_initiated_) { if (!r->resolving_) { - GRPC_CARES_TRACE_LOG("%p start resolving due to re-resolution timer", r); + GRPC_CARES_TRACE_LOG( + "resolver:%p start resolving due to re-resolution timer", r); r->StartResolvingLocked(); } } @@ -327,8 +330,8 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { service_config_string = ChooseServiceConfig(r->service_config_json_); gpr_free(r->service_config_json_); if (service_config_string != nullptr) { - gpr_log(GPR_INFO, "selected service config choice: %s", - service_config_string); + GRPC_CARES_TRACE_LOG("resolver:%p selected service config choice: %s", + r, service_config_string); args_to_remove[num_args_to_remove++] = GRPC_ARG_SERVICE_CONFIG; args_to_add[num_args_to_add++] = grpc_channel_arg_string_create( (char*)GRPC_ARG_SERVICE_CONFIG, service_config_string); @@ -344,11 +347,11 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { r->backoff_.Reset(); } else if (!r->shutdown_initiated_) { const char* msg = grpc_error_string(error); - gpr_log(GPR_DEBUG, "dns resolution failed: %s", msg); + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed: %s", r, msg); grpc_millis next_try = r->backoff_.NextAttemptTime(); grpc_millis timeout = next_try - ExecCtx::Get()->Now(); - gpr_log(GPR_INFO, "dns resolution failed (will retry): %s", - grpc_error_string(error)); + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed (will retry): %s", + r, grpc_error_string(error)); GPR_ASSERT(!r->have_next_resolution_timer_); r->have_next_resolution_timer_ = true; // TODO(roth): We currently deal with this ref manually. Once the @@ -357,9 +360,10 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { RefCountedPtr<Resolver> self = r->Ref(DEBUG_LOCATION, "retry-timer"); self.release(); if (timeout > 0) { - gpr_log(GPR_DEBUG, "retrying in %" PRId64 " milliseconds", timeout); + GRPC_CARES_TRACE_LOG("resolver:%p retrying in %" PRId64 " milliseconds", + r, timeout); } else { - gpr_log(GPR_DEBUG, "retrying immediately"); + GRPC_CARES_TRACE_LOG("resolver:%p retrying immediately", r); } grpc_timer_init(&r->next_resolution_timer_, next_try, &r->on_next_resolution_); @@ -385,10 +389,10 @@ void AresDnsResolver::MaybeStartResolvingLocked() { if (ms_until_next_resolution > 0) { const grpc_millis last_resolution_ago = grpc_core::ExecCtx::Get()->Now() - last_resolution_timestamp_; - gpr_log(GPR_DEBUG, - "In cooldown from last resolution (from %" PRId64 - " ms ago). Will resolve again in %" PRId64 " ms", - last_resolution_ago, ms_until_next_resolution); + GRPC_CARES_TRACE_LOG( + "resolver:%p In cooldown from last resolution (from %" PRId64 + " ms ago). Will resolve again in %" PRId64 " ms", + this, last_resolution_ago, ms_until_next_resolution); have_next_resolution_timer_ = true; // TODO(roth): We currently deal with this ref manually. Once the // new closure API is done, find a way to track this ref with the timer @@ -405,7 +409,6 @@ void AresDnsResolver::MaybeStartResolvingLocked() { } void AresDnsResolver::StartResolvingLocked() { - gpr_log(GPR_DEBUG, "Start resolving."); // TODO(roth): We currently deal with this ref manually. Once the // new closure API is done, find a way to track this ref with the timer // callback as part of the type system. @@ -420,6 +423,8 @@ void AresDnsResolver::StartResolvingLocked() { request_service_config_ ? &service_config_json_ : nullptr, query_timeout_ms_, combiner()); last_resolution_timestamp_ = grpc_core::ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG("resolver:%p Started resolving. pending_request_:%p", + this, pending_request_); } void AresDnsResolver::MaybeFinishNextLocked() { @@ -427,7 +432,8 @@ void AresDnsResolver::MaybeFinishNextLocked() { *target_result_ = resolved_result_ == nullptr ? nullptr : grpc_channel_args_copy(resolved_result_); - gpr_log(GPR_DEBUG, "AresDnsResolver::MaybeFinishNextLocked()"); + GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::MaybeFinishNextLocked()", + this); GRPC_CLOSURE_SCHED(next_completion_, GRPC_ERROR_NONE); next_completion_ = nullptr; published_version_ = resolved_version_; @@ -465,11 +471,16 @@ static grpc_error* blocking_resolve_address_ares( static grpc_address_resolver_vtable ares_resolver = { grpc_resolve_address_ares, blocking_resolve_address_ares}; +static bool should_use_ares(const char* resolver_env) { + return resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0; +} + void grpc_resolver_dns_ares_init() { char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER"); /* TODO(zyc): Turn on c-ares based resolver by default after the address sorter and the CNAME support are added. */ - if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + if (should_use_ares(resolver_env)) { + gpr_log(GPR_DEBUG, "Using ares dns resolver"); address_sorting_init(); grpc_error* error = grpc_ares_init(); if (error != GRPC_ERROR_NONE) { @@ -489,7 +500,7 @@ void grpc_resolver_dns_ares_init() { void grpc_resolver_dns_ares_shutdown() { char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER"); - if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + if (should_use_ares(resolver_env)) { address_sorting_shutdown(); grpc_ares_cleanup(); } diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc index 8abc34c6ed..d99c2e3004 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc @@ -90,15 +90,18 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver); static grpc_ares_ev_driver* grpc_ares_ev_driver_ref( grpc_ares_ev_driver* ev_driver) { - gpr_log(GPR_DEBUG, "Ref ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p Ref ev_driver %p", ev_driver->request, + ev_driver); gpr_ref(&ev_driver->refs); return ev_driver; } static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) { - gpr_log(GPR_DEBUG, "Unref ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p Unref ev_driver %p", ev_driver->request, + ev_driver); if (gpr_unref(&ev_driver->refs)) { - gpr_log(GPR_DEBUG, "destroy ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p destroy ev_driver %p", ev_driver->request, + ev_driver); GPR_ASSERT(ev_driver->fds == nullptr); GRPC_COMBINER_UNREF(ev_driver->combiner, "free ares event driver"); ares_destroy(ev_driver->channel); @@ -108,7 +111,8 @@ static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) { } static void fd_node_destroy_locked(fd_node* fdn) { - gpr_log(GPR_DEBUG, "delete fd: %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p delete fd: %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); GPR_ASSERT(!fdn->readable_registered); GPR_ASSERT(!fdn->writable_registered); GPR_ASSERT(fdn->already_shutdown); @@ -136,7 +140,7 @@ grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver, memset(&opts, 0, sizeof(opts)); opts.flags |= ARES_FLAG_STAYOPEN; int status = ares_init_options(&(*ev_driver)->channel, &opts, ARES_OPT_FLAGS); - gpr_log(GPR_DEBUG, "grpc_ares_ev_driver_create_locked"); + GRPC_CARES_TRACE_LOG("request:%p grpc_ares_ev_driver_create_locked", request); if (status != ARES_SUCCESS) { char* err_msg; gpr_asprintf(&err_msg, "Failed to init ares channel. C-ares error: %s", @@ -203,8 +207,9 @@ static fd_node* pop_fd_node_locked(fd_node** head, ares_socket_t as) { static void on_timeout_locked(void* arg, grpc_error* error) { grpc_ares_ev_driver* driver = static_cast<grpc_ares_ev_driver*>(arg); GRPC_CARES_TRACE_LOG( - "ev_driver=%p on_timeout_locked. driver->shutting_down=%d. err=%s", - driver, driver->shutting_down, grpc_error_string(error)); + "request:%p ev_driver=%p on_timeout_locked. driver->shutting_down=%d. " + "err=%s", + driver->request, driver, driver->shutting_down, grpc_error_string(error)); if (!driver->shutting_down && error == GRPC_ERROR_NONE) { grpc_ares_ev_driver_shutdown_locked(driver); } @@ -216,7 +221,8 @@ static void on_readable_locked(void* arg, grpc_error* error) { grpc_ares_ev_driver* ev_driver = fdn->ev_driver; const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); fdn->readable_registered = false; - gpr_log(GPR_DEBUG, "readable on %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p readable on %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); if (error == GRPC_ERROR_NONE) { do { ares_process_fd(ev_driver->channel, as, ARES_SOCKET_BAD); @@ -239,7 +245,8 @@ static void on_writable_locked(void* arg, grpc_error* error) { grpc_ares_ev_driver* ev_driver = fdn->ev_driver; const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); fdn->writable_registered = false; - gpr_log(GPR_DEBUG, "writable on %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p writable on %s", ev_driver->request, + fdn->grpc_polled_fd->GetName()); if (error == GRPC_ERROR_NONE) { ares_process_fd(ev_driver->channel, ARES_SOCKET_BAD, as); } else { @@ -278,7 +285,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { fdn->grpc_polled_fd = ev_driver->polled_fd_factory->NewGrpcPolledFdLocked( socks[i], ev_driver->pollset_set, ev_driver->combiner); - gpr_log(GPR_DEBUG, "new fd: %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p new fd: %s", ev_driver->request, + fdn->grpc_polled_fd->GetName()); fdn->ev_driver = ev_driver; fdn->readable_registered = false; fdn->writable_registered = false; @@ -295,8 +303,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { if (ARES_GETSOCK_READABLE(socks_bitmask, i) && !fdn->readable_registered) { grpc_ares_ev_driver_ref(ev_driver); - gpr_log(GPR_DEBUG, "notify read on: %s", - fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p notify read on: %s", + ev_driver->request, + fdn->grpc_polled_fd->GetName()); fdn->grpc_polled_fd->RegisterForOnReadableLocked(&fdn->read_closure); fdn->readable_registered = true; } @@ -304,8 +313,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { // has not been registered with this socket. if (ARES_GETSOCK_WRITABLE(socks_bitmask, i) && !fdn->writable_registered) { - gpr_log(GPR_DEBUG, "notify write on: %s", - fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p notify write on: %s", + ev_driver->request, + fdn->grpc_polled_fd->GetName()); grpc_ares_ev_driver_ref(ev_driver); fdn->grpc_polled_fd->RegisterForOnWriteableLocked( &fdn->write_closure); @@ -332,7 +342,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { // If the ev driver has no working fd, all the tasks are done. if (new_list == nullptr) { ev_driver->working = false; - gpr_log(GPR_DEBUG, "ev driver stop working"); + GRPC_CARES_TRACE_LOG("request:%p ev driver stop working", + ev_driver->request); } } @@ -345,9 +356,9 @@ void grpc_ares_ev_driver_start_locked(grpc_ares_ev_driver* ev_driver) { ? GRPC_MILLIS_INF_FUTURE : ev_driver->query_timeout_ms + grpc_core::ExecCtx::Get()->Now(); GRPC_CARES_TRACE_LOG( - "ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in %" PRId64 - " ms", - ev_driver, timeout); + "request:%p ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in " + "%" PRId64 " ms", + ev_driver->request, ev_driver, timeout); grpc_ares_ev_driver_ref(ev_driver); grpc_timer_init(&ev_driver->query_timeout, timeout, &ev_driver->on_timeout_locked); diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc index 1b1c2303da..1a7e5d0626 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc @@ -96,11 +96,11 @@ static void log_address_sorting_list(const ServerAddressList& addresses, for (size_t i = 0; i < addresses.size(); i++) { char* addr_str; if (grpc_sockaddr_to_string(&addr_str, &addresses[i].address(), true)) { - gpr_log(GPR_DEBUG, "c-ares address sorting: %s[%" PRIuPTR "]=%s", + gpr_log(GPR_INFO, "c-ares address sorting: %s[%" PRIuPTR "]=%s", input_output_str, i, addr_str); gpr_free(addr_str); } else { - gpr_log(GPR_DEBUG, + gpr_log(GPR_INFO, "c-ares address sorting: %s[%" PRIuPTR "]=<unprintable>", input_output_str, i); } @@ -209,10 +209,10 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts, addresses.emplace_back(&addr, addr_len, args); char output[INET6_ADDRSTRLEN]; ares_inet_ntop(AF_INET6, &addr.sin6_addr, output, INET6_ADDRSTRLEN); - gpr_log(GPR_DEBUG, - "c-ares resolver gets a AF_INET6 result: \n" - " addr: %s\n port: %d\n sin6_scope_id: %d\n", - output, ntohs(hr->port), addr.sin6_scope_id); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET6 result: \n" + " addr: %s\n port: %d\n sin6_scope_id: %d\n", + r, output, ntohs(hr->port), addr.sin6_scope_id); break; } case AF_INET: { @@ -226,10 +226,10 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts, addresses.emplace_back(&addr, addr_len, args); char output[INET_ADDRSTRLEN]; ares_inet_ntop(AF_INET, &addr.sin_addr, output, INET_ADDRSTRLEN); - gpr_log(GPR_DEBUG, - "c-ares resolver gets a AF_INET result: \n" - " addr: %s\n port: %d\n", - output, ntohs(hr->port)); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET result: \n" + " addr: %s\n port: %d\n", + r, output, ntohs(hr->port)); break; } } @@ -252,9 +252,9 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts, static void on_srv_query_done_locked(void* arg, int status, int timeouts, unsigned char* abuf, int alen) { grpc_ares_request* r = static_cast<grpc_ares_request*>(arg); - gpr_log(GPR_DEBUG, "on_query_srv_done_locked"); + GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked", r); if (status == ARES_SUCCESS) { - gpr_log(GPR_DEBUG, "on_query_srv_done_locked ARES_SUCCESS"); + GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked ARES_SUCCESS", r); struct ares_srv_reply* reply; const int parse_status = ares_parse_srv_reply(abuf, alen, &reply); if (parse_status == ARES_SUCCESS) { @@ -297,9 +297,9 @@ static const char g_service_config_attribute_prefix[] = "grpc_config="; static void on_txt_done_locked(void* arg, int status, int timeouts, unsigned char* buf, int len) { - gpr_log(GPR_DEBUG, "on_txt_done_locked"); char* error_msg; grpc_ares_request* r = static_cast<grpc_ares_request*>(arg); + GRPC_CARES_TRACE_LOG("request:%p on_txt_done_locked", r); const size_t prefix_len = sizeof(g_service_config_attribute_prefix) - 1; struct ares_txt_ext* result = nullptr; struct ares_txt_ext* reply = nullptr; @@ -332,7 +332,8 @@ static void on_txt_done_locked(void* arg, int status, int timeouts, service_config_len += result->length; } (*r->service_config_json_out)[service_config_len] = '\0'; - gpr_log(GPR_INFO, "found service config: %s", *r->service_config_json_out); + GRPC_CARES_TRACE_LOG("request:%p found service config: %s", r, + *r->service_config_json_out); } // Clean up. ares_free_data(reply); @@ -358,12 +359,6 @@ void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( grpc_error* error = GRPC_ERROR_NONE; grpc_ares_hostbyname_request* hr = nullptr; ares_channel* channel = nullptr; - /* TODO(zyc): Enable tracing after #9603 is checked in */ - /* if (grpc_dns_trace) { - gpr_log(GPR_DEBUG, "resolve_address (blocking): name=%s, default_port=%s", - name, default_port); - } */ - /* parse name, splitting it into host and port parts */ char* host; char* port; @@ -388,7 +383,7 @@ void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( channel = grpc_ares_ev_driver_get_channel_locked(r->ev_driver); // If dns_server is specified, use it. if (dns_server != nullptr) { - gpr_log(GPR_INFO, "Using DNS server %s", dns_server); + GRPC_CARES_TRACE_LOG("request:%p Using DNS server %s", r, dns_server); grpc_resolved_address addr; if (grpc_parse_ipv4_hostport(dns_server, &addr, false /* log_errors */)) { r->dns_server_addr.family = AF_INET; @@ -510,6 +505,28 @@ static bool resolve_as_ip_literal_locked( return out; } +static bool target_matches_localhost_inner(const char* name, char** host, + char** port) { + if (!gpr_split_host_port(name, host, port)) { + gpr_log(GPR_ERROR, "Unable to split host and port for name: %s", name); + return false; + } + if (gpr_stricmp(*host, "localhost") == 0) { + return true; + } else { + return false; + } +} + +static bool target_matches_localhost(const char* name) { + char* host = nullptr; + char* port = nullptr; + bool out = target_matches_localhost_inner(name, &host, &port); + gpr_free(host); + gpr_free(port); + return out; +} + static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, @@ -525,6 +542,10 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( r->success = false; r->error = GRPC_ERROR_NONE; r->pending_queries = 0; + GRPC_CARES_TRACE_LOG( + "request:%p c-ares grpc_dns_lookup_ares_locked_impl name=%s, " + "default_port=%s", + r, name, default_port); // Early out if the target is an ipv4 or ipv6 literal. if (resolve_as_ip_literal_locked(name, default_port, addrs)) { GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE); @@ -536,6 +557,13 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE); return r; } + // Don't query for SRV and TXT records if the target is "localhost", so + // as to cut down on lookups over the network, especially in tests: + // https://github.com/grpc/proposal/pull/79 + if (target_matches_localhost(name)) { + check_grpclb = false; + r->service_config_json_out = nullptr; + } // Look up name using c-ares lib. grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( r, dns_server, name, default_port, interested_parties, check_grpclb, diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.cc b/src/core/ext/filters/client_channel/resolver_result_parsing.cc index 22b06db45c..9a0122e8ec 100644 --- a/src/core/ext/filters/client_channel/resolver_result_parsing.cc +++ b/src/core/ext/filters/client_channel/resolver_result_parsing.cc @@ -43,16 +43,16 @@ namespace grpc_core { namespace internal { ProcessedResolverResult::ProcessedResolverResult( - const grpc_channel_args* resolver_result, bool parse_retry) { + const grpc_channel_args& resolver_result, bool parse_retry) { ProcessServiceConfig(resolver_result, parse_retry); // If no LB config was found above, just find the LB policy name then. if (lb_policy_name_ == nullptr) ProcessLbPolicyName(resolver_result); } void ProcessedResolverResult::ProcessServiceConfig( - const grpc_channel_args* resolver_result, bool parse_retry) { + const grpc_channel_args& resolver_result, bool parse_retry) { const grpc_arg* channel_arg = - grpc_channel_args_find(resolver_result, GRPC_ARG_SERVICE_CONFIG); + grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVICE_CONFIG); const char* service_config_json = grpc_channel_arg_get_string(channel_arg); if (service_config_json != nullptr) { service_config_json_.reset(gpr_strdup(service_config_json)); @@ -60,7 +60,7 @@ void ProcessedResolverResult::ProcessServiceConfig( if (service_config_ != nullptr) { if (parse_retry) { channel_arg = - grpc_channel_args_find(resolver_result, GRPC_ARG_SERVER_URI); + grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVER_URI); const char* server_uri = grpc_channel_arg_get_string(channel_arg); GPR_ASSERT(server_uri != nullptr); grpc_uri* uri = grpc_uri_parse(server_uri, true); @@ -78,7 +78,7 @@ void ProcessedResolverResult::ProcessServiceConfig( } void ProcessedResolverResult::ProcessLbPolicyName( - const grpc_channel_args* resolver_result) { + const grpc_channel_args& resolver_result) { // Prefer the LB policy name found in the service config. Note that this is // checking the deprecated loadBalancingPolicy field, rather than the new // loadBalancingConfig field. @@ -96,13 +96,13 @@ void ProcessedResolverResult::ProcessLbPolicyName( // Otherwise, find the LB policy name set by the client API. if (lb_policy_name_ == nullptr) { const grpc_arg* channel_arg = - grpc_channel_args_find(resolver_result, GRPC_ARG_LB_POLICY_NAME); + grpc_channel_args_find(&resolver_result, GRPC_ARG_LB_POLICY_NAME); lb_policy_name_.reset(gpr_strdup(grpc_channel_arg_get_string(channel_arg))); } // Special case: If at least one balancer address is present, we use // the grpclb policy, regardless of what the resolver has returned. const ServerAddressList* addresses = - FindServerAddressListChannelArg(resolver_result); + FindServerAddressListChannelArg(&resolver_result); if (addresses != nullptr) { bool found_balancer_address = false; for (size_t i = 0; i < addresses->size(); ++i) { diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.h b/src/core/ext/filters/client_channel/resolver_result_parsing.h index f1fb7406bc..98a9d26c46 100644 --- a/src/core/ext/filters/client_channel/resolver_result_parsing.h +++ b/src/core/ext/filters/client_channel/resolver_result_parsing.h @@ -36,8 +36,7 @@ namespace internal { class ClientChannelMethodParams; // A table mapping from a method name to its method parameters. -typedef grpc_core::SliceHashTable< - grpc_core::RefCountedPtr<ClientChannelMethodParams>> +typedef SliceHashTable<RefCountedPtr<ClientChannelMethodParams>> ClientChannelMethodParamsTable; // A container of processed fields from the resolver result. Simplifies the @@ -47,33 +46,30 @@ class ProcessedResolverResult { // Processes the resolver result and populates the relative members // for later consumption. Tries to parse retry parameters only if parse_retry // is true. - ProcessedResolverResult(const grpc_channel_args* resolver_result, + ProcessedResolverResult(const grpc_channel_args& resolver_result, bool parse_retry); // Getters. Any managed object's ownership is transferred. - grpc_core::UniquePtr<char> service_config_json() { + UniquePtr<char> service_config_json() { return std::move(service_config_json_); } - grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data() { + RefCountedPtr<ServerRetryThrottleData> retry_throttle_data() { return std::move(retry_throttle_data_); } - grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> - method_params_table() { + RefCountedPtr<ClientChannelMethodParamsTable> method_params_table() { return std::move(method_params_table_); } - grpc_core::UniquePtr<char> lb_policy_name() { - return std::move(lb_policy_name_); - } + UniquePtr<char> lb_policy_name() { return std::move(lb_policy_name_); } grpc_json* lb_policy_config() { return lb_policy_config_; } private: // Finds the service config; extracts LB config and (maybe) retry throttle // params from it. - void ProcessServiceConfig(const grpc_channel_args* resolver_result, + void ProcessServiceConfig(const grpc_channel_args& resolver_result, bool parse_retry); // Finds the LB policy name (when no LB config was found). - void ProcessLbPolicyName(const grpc_channel_args* resolver_result); + void ProcessLbPolicyName(const grpc_channel_args& resolver_result); // Parses the service config. Intended to be used by // ServiceConfig::ParseGlobalParams. @@ -85,16 +81,16 @@ class ProcessedResolverResult { void ParseRetryThrottleParamsFromServiceConfig(const grpc_json* field); // Service config. - grpc_core::UniquePtr<char> service_config_json_; - grpc_core::UniquePtr<grpc_core::ServiceConfig> service_config_; + UniquePtr<char> service_config_json_; + UniquePtr<grpc_core::ServiceConfig> service_config_; // LB policy. grpc_json* lb_policy_config_ = nullptr; - grpc_core::UniquePtr<char> lb_policy_name_; + UniquePtr<char> lb_policy_name_; // Retry throttle data. char* server_name_ = nullptr; - grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data_; + RefCountedPtr<ServerRetryThrottleData> retry_throttle_data_; // Method params table. - grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> method_params_table_; + RefCountedPtr<ClientChannelMethodParamsTable> method_params_table_; }; // The parameters of a method. diff --git a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc index e73eee4353..9612698e96 100644 --- a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc +++ b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc @@ -110,14 +110,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( grpc_channel_args* args_with_authority = grpc_channel_args_copy_and_add(args->args, args_to_add, num_args_to_add); grpc_uri_destroy(server_uri); - grpc_channel_security_connector* subchannel_security_connector = nullptr; // Create the security connector using the credentials and target name. grpc_channel_args* new_args_from_connector = nullptr; - const grpc_security_status security_status = - grpc_channel_credentials_create_security_connector( - channel_credentials, authority.get(), args_with_authority, - &subchannel_security_connector, &new_args_from_connector); - if (security_status != GRPC_SECURITY_OK) { + grpc_core::RefCountedPtr<grpc_channel_security_connector> + subchannel_security_connector = + channel_credentials->create_security_connector( + /*call_creds=*/nullptr, authority.get(), args_with_authority, + &new_args_from_connector); + if (subchannel_security_connector == nullptr) { gpr_log(GPR_ERROR, "Failed to create secure subchannel for secure name '%s'", authority.get()); @@ -125,15 +125,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( return nullptr; } grpc_arg new_security_connector_arg = - grpc_security_connector_to_arg(&subchannel_security_connector->base); + grpc_security_connector_to_arg(subchannel_security_connector.get()); grpc_channel_args* new_args = grpc_channel_args_copy_and_add( new_args_from_connector != nullptr ? new_args_from_connector : args_with_authority, &new_security_connector_arg, 1); - GRPC_SECURITY_CONNECTOR_UNREF(&subchannel_security_connector->base, - "lb_channel_create"); + subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create"); if (new_args_from_connector != nullptr) { grpc_channel_args_destroy(new_args_from_connector); } diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc index 6689a17da6..98fdb62070 100644 --- a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc +++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc @@ -31,6 +31,7 @@ #include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/surface/api_trace.h" @@ -40,9 +41,8 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, grpc_server_credentials* creds) { grpc_core::ExecCtx exec_ctx; grpc_error* err = GRPC_ERROR_NONE; - grpc_server_security_connector* sc = nullptr; + grpc_core::RefCountedPtr<grpc_server_security_connector> sc; int port_num = 0; - grpc_security_status status; grpc_channel_args* args = nullptr; GRPC_API_TRACE( "grpc_server_add_secure_http2_port(" @@ -54,30 +54,27 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, "No credentials specified for secure server port (creds==NULL)"); goto done; } - status = grpc_server_credentials_create_security_connector(creds, &sc); - if (status != GRPC_SECURITY_OK) { + sc = creds->create_security_connector(); + if (sc == nullptr) { char* msg; gpr_asprintf(&msg, "Unable to create secure server with credentials of type %s.", - creds->type); - err = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg), - GRPC_ERROR_INT_SECURITY_STATUS, status); + creds->type()); + err = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); gpr_free(msg); goto done; } // Create channel args. grpc_arg args_to_add[2]; args_to_add[0] = grpc_server_credentials_to_arg(creds); - args_to_add[1] = grpc_security_connector_to_arg(&sc->base); + args_to_add[1] = grpc_security_connector_to_arg(sc.get()); args = grpc_channel_args_copy_and_add(grpc_server_get_channel_args(server), args_to_add, GPR_ARRAY_SIZE(args_to_add)); // Add server port. err = grpc_chttp2_server_add_port(server, addr, args, &port_num); done: - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server"); - } + sc.reset(DEBUG_LOCATION, "server"); if (err != GRPC_ERROR_NONE) { const char* msg = grpc_error_string(err); diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index 9b6574b612..7f4627fa77 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -170,7 +170,12 @@ grpc_chttp2_transport::~grpc_chttp2_transport() { grpc_slice_buffer_destroy_internal(&outbuf); grpc_chttp2_hpack_compressor_destroy(&hpack_compressor); - grpc_core::ContextList::Execute(cl, nullptr, GRPC_ERROR_NONE); + grpc_error* error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed"); + // ContextList::Execute follows semantics of a callback function and does not + // take a ref on error + grpc_core::ContextList::Execute(cl, nullptr, error); + GRPC_ERROR_UNREF(error); cl = nullptr; grpc_slice_buffer_destroy_internal(&read_buffer); diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc index f30d41c332..df09809067 100644 --- a/src/core/ext/transport/chttp2/transport/context_list.cc +++ b/src/core/ext/transport/chttp2/transport/context_list.cc @@ -21,31 +21,47 @@ #include "src/core/ext/transport/chttp2/transport/context_list.h" namespace { -void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*) = nullptr; -} +void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*, + grpc_error* error) = nullptr; +void* (*get_copied_context_fn_g)(void*) = nullptr; +} // namespace namespace grpc_core { +void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) { + if (get_copied_context_fn_g == nullptr || + write_timestamps_callback_g == nullptr) { + return; + } + /* Create a new element in the list and add it at the front */ + ContextList* elem = grpc_core::New<ContextList>(); + elem->trace_context_ = get_copied_context_fn_g(s->context); + elem->byte_offset_ = s->byte_counter; + elem->next_ = *head; + *head = elem; +} + void ContextList::Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error) { ContextList* head = static_cast<ContextList*>(arg); ContextList* to_be_freed; while (head != nullptr) { - if (error == GRPC_ERROR_NONE && ts != nullptr) { - if (write_timestamps_callback_g) { - ts->byte_offset = static_cast<uint32_t>(head->byte_offset_); - write_timestamps_callback_g(head->s_->context, ts); - } + if (write_timestamps_callback_g) { + ts->byte_offset = static_cast<uint32_t>(head->byte_offset_); + write_timestamps_callback_g(head->trace_context_, ts, error); } - GRPC_CHTTP2_STREAM_UNREF(static_cast<grpc_chttp2_stream*>(head->s_), - "timestamp"); to_be_freed = head; head = head->next_; grpc_core::Delete(to_be_freed); } } -void grpc_http2_set_write_timestamps_callback( - void (*fn)(void*, grpc_core::Timestamps*)) { +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)) { write_timestamps_callback_g = fn; } + +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) { + get_copied_context_fn_g = fn; +} } /* namespace grpc_core */ diff --git a/src/core/ext/transport/chttp2/transport/context_list.h b/src/core/ext/transport/chttp2/transport/context_list.h index d870107749..5b9d2ab378 100644 --- a/src/core/ext/transport/chttp2/transport/context_list.h +++ b/src/core/ext/transport/chttp2/transport/context_list.h @@ -31,42 +31,23 @@ class ContextList { public: /* Creates a new element with \a context as the value and appends it to the * list. */ - static void Append(ContextList** head, grpc_chttp2_stream* s) { - /* Make sure context is not already present */ - GRPC_CHTTP2_STREAM_REF(s, "timestamp"); - -#ifndef NDEBUG - ContextList* ptr = *head; - while (ptr != nullptr) { - if (ptr->s_ == s) { - GPR_ASSERT( - false && - "Trying to append a stream that is already present in the list"); - } - ptr = ptr->next_; - } -#endif - - /* Create a new element in the list and add it at the front */ - ContextList* elem = grpc_core::New<ContextList>(); - elem->s_ = s; - elem->byte_offset_ = s->byte_counter; - elem->next_ = *head; - *head = elem; - } + static void Append(ContextList** head, grpc_chttp2_stream* s); /* Executes a function \a fn with each context in the list and \a ts. It also - * frees up the entire list after this operation. */ + * frees up the entire list after this operation. It is intended as a callback + * and hence does not take a ref on \a error */ static void Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error); private: - grpc_chttp2_stream* s_ = nullptr; + void* trace_context_ = nullptr; ContextList* next_ = nullptr; size_t byte_offset_ = 0; }; -void grpc_http2_set_write_timestamps_callback( - void (*fn)(void*, grpc_core::Timestamps*)); +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)); +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)); } /* namespace grpc_core */ #endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H */ diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h index 1ed5d584c7..19f38d7f01 100644 --- a/src/core/lib/gprpp/ref_counted_ptr.h +++ b/src/core/lib/gprpp/ref_counted_ptr.h @@ -50,7 +50,7 @@ class RefCountedPtr { } template <typename Y> RefCountedPtr(RefCountedPtr<Y>&& other) { - value_ = other.value_; + value_ = static_cast<T*>(other.value_); other.value_ = nullptr; } @@ -77,7 +77,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor<T>::value, "T does not have a virtual dtor"); if (other.value_ != nullptr) other.value_->IncrementRefCount(); - value_ = other.value_; + value_ = static_cast<T*>(other.value_); } // Copy assignment. @@ -118,7 +118,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor<T>::value, "T does not have a virtual dtor"); if (value_ != nullptr) value_->Unref(); - value_ = value; + value_ = static_cast<T*>(value); } template <typename Y> void reset(const DebugLocation& location, const char* reason, @@ -126,7 +126,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor<T>::value, "T does not have a virtual dtor"); if (value_ != nullptr) value_->Unref(location, reason); - value_ = value; + value_ = static_cast<T*>(value); } // TODO(roth): This method exists solely as a transition mechanism to allow diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 1c798d368b..fdea7511cc 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -29,119 +29,125 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/ssl_utils.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/tsi/ssl_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* handshaker_factory; - char* secure_peer_name; -} grpc_httpcli_ssl_channel_security_connector; - -static void httpcli_ssl_destroy(grpc_security_connector* sc) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - if (c->handshaker_factory != nullptr) { - tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory); - c->handshaker_factory = nullptr; +class grpc_httpcli_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name) + : grpc_channel_security_connector( + /*url_scheme=*/nullptr, + /*channel_creds=*/nullptr, + /*request_metadata_creds=*/nullptr), + secure_peer_name_(secure_peer_name) {} + + ~grpc_httpcli_ssl_channel_security_connector() override { + if (handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(handshaker_factory_); + } + if (secure_peer_name_ != nullptr) { + gpr_free(secure_peer_name_); + } + } + + tsi_result InitHandshakerFactory(const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store) { + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + return tsi_create_ssl_client_handshaker_factory_with_options( + &options, &handshaker_factory_); } - if (c->secure_peer_name != nullptr) gpr_free(c->secure_peer_name); - gpr_free(sc); -} -static void httpcli_ssl_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - tsi_handshaker* handshaker = nullptr; - if (c->handshaker_factory != nullptr) { - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->handshaker_factory, c->secure_peer_name, &handshaker); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + tsi_handshaker* handshaker = nullptr; + if (handshaker_factory_ != nullptr) { + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + handshaker_factory_, secure_peer_name_, &handshaker); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } } + grpc_handshake_manager_add( + handshake_mgr, grpc_security_handshaker_create(handshaker, this)); } - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(handshaker, &sc->base)); -} -static void httpcli_ssl_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - grpc_error* error = GRPC_ERROR_NONE; - - /* Check the peer name. */ - if (c->secure_peer_name != nullptr && - !tsi_ssl_peer_matches_name(&peer, c->secure_peer_name)) { - char* msg; - gpr_asprintf(&msg, "Peer name %s is not in peer certificate", - c->secure_peer_name); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); + tsi_ssl_client_handshaker_factory* handshaker_factory() const { + return handshaker_factory_; } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static int httpcli_ssl_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_httpcli_ssl_channel_security_connector* c1 = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc1); - grpc_httpcli_ssl_channel_security_connector* c2 = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc2); - return strcmp(c1->secure_peer_name, c2->secure_peer_name); -} + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* /*auth_context*/, + grpc_closure* on_peer_checked) override { + grpc_error* error = GRPC_ERROR_NONE; + + /* Check the peer name. */ + if (secure_peer_name_ != nullptr && + !tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) { + char* msg; + gpr_asprintf(&msg, "Peer name %s is not in peer certificate", + secure_peer_name_); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); + } -static grpc_security_connector_vtable httpcli_ssl_vtable = { - httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp}; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_httpcli_ssl_channel_security_connector*>( + other_sc); + return strcmp(secure_peer_name_, other->secure_peer_name_); + } -static grpc_security_status httpcli_ssl_channel_security_connector_create( - const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, - const char* secure_peer_name, grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_httpcli_ssl_channel_security_connector* c; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + *error = GRPC_ERROR_NONE; + return true; + } - if (secure_peer_name != nullptr && pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, - "Cannot assert a secure peer name without a trust root."); - return GRPC_SECURITY_ERROR; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); } - c = static_cast<grpc_httpcli_ssl_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_httpcli_ssl_channel_security_connector))); + const char* secure_peer_name() const { return secure_peer_name_; } - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &httpcli_ssl_vtable; - if (secure_peer_name != nullptr) { - c->secure_peer_name = gpr_strdup(secure_peer_name); + private: + tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr; + char* secure_peer_name_; +}; + +static grpc_core::RefCountedPtr<grpc_channel_security_connector> +httpcli_ssl_channel_security_connector_create( + const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, + const char* secure_peer_name) { + if (secure_peer_name != nullptr && pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, + "Cannot assert a secure peer name without a trust root."); + return nullptr; } - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.pem_root_certs = pem_root_certs; - options.root_store = root_store; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->handshaker_factory); + grpc_core::RefCountedPtr<grpc_httpcli_ssl_channel_security_connector> c = + grpc_core::MakeRefCounted<grpc_httpcli_ssl_channel_security_connector>( + secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name)); + tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - httpcli_ssl_destroy(&c->base.base); - *sc = nullptr; - return GRPC_SECURITY_ERROR; + return nullptr; } - // We don't actually need a channel credentials object in this case, - // but we set it to a non-nullptr address so that we don't trigger - // assertions in grpc_channel_security_connector_cmp(). - c->base.channel_creds = (grpc_channel_credentials*)1; - c->base.add_handshakers = httpcli_ssl_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return c; } /* handshaker */ @@ -186,10 +192,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, } c->func = on_done; c->arg = arg; - grpc_channel_security_connector* sc = nullptr; - GPR_ASSERT(httpcli_ssl_channel_security_connector_create( - pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK); - grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base); + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store, + host); + GPR_ASSERT(sc != nullptr); + grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_channel_args args = {1, &channel_arg}; c->handshake_mgr = grpc_handshake_manager_create(); grpc_handshakers_add(HANDSHAKER_CLIENT, &args, @@ -197,7 +204,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, grpc_handshake_manager_do_handshake( c->handshake_mgr, tcp, nullptr /* channel_args */, deadline, nullptr /* acceptor */, on_handshake_done, c /* user_data */); - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "httpcli"); + sc.reset(DEBUG_LOCATION, "httpcli"); } const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake}; diff --git a/src/core/lib/http/parser.h b/src/core/lib/http/parser.h index 1d2e13e831..a8f47c96c8 100644 --- a/src/core/lib/http/parser.h +++ b/src/core/lib/http/parser.h @@ -70,13 +70,13 @@ typedef struct grpc_http_request { /* A response */ typedef struct grpc_http_response { /* HTTP status code */ - int status; + int status = 0; /* Headers: count and key/values */ - size_t hdr_count; - grpc_http_header* hdrs; + size_t hdr_count = 0; + grpc_http_header* hdrs = nullptr; /* Body: length and contents; contents are NOT null-terminated */ - size_t body_length; - char* body; + size_t body_length = 0; + char* body = nullptr; } grpc_http_response; typedef struct { diff --git a/src/core/lib/iomgr/resource_quota.cc b/src/core/lib/iomgr/resource_quota.cc index 7e4b3c9b2f..61c366098e 100644 --- a/src/core/lib/iomgr/resource_quota.cc +++ b/src/core/lib/iomgr/resource_quota.cc @@ -665,6 +665,7 @@ void grpc_resource_quota_unref_internal(grpc_resource_quota* resource_quota) { GPR_ASSERT(resource_quota->num_threads_allocated == 0); GRPC_COMBINER_UNREF(resource_quota->combiner, "resource_quota"); gpr_free(resource_quota->name); + gpr_mu_destroy(&resource_quota->thread_count_mu); gpr_free(resource_quota); } } diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc index cfcb190d60..c268c18664 100644 --- a/src/core/lib/iomgr/tcp_posix.cc +++ b/src/core/lib/iomgr/tcp_posix.cc @@ -126,6 +126,7 @@ struct grpc_tcp { int bytes_counter; bool socket_ts_enabled; /* True if timestamping options are set on the socket */ + bool ts_capable; /* Cache whether we can set timestamping options */ gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified on errors anymore */ @@ -589,7 +590,7 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) { */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, grpc_error** error); + ssize_t* sent_length); /** The callback function to be invoked when we get an error on the socket. */ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); @@ -597,13 +598,11 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); #ifdef GRPC_LINUX_ERRQUEUE static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { if (!tcp->socket_ts_enabled) { uint32_t opt = grpc_core::kTimestampingSocketOptions; if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING, static_cast<void*>(&opt), sizeof(opt)) != 0) { - *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "setsockopt"), tcp); grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); if (grpc_tcp_trace.enabled()) { gpr_log(GPR_ERROR, "Failed to set timestamping options on the socket."); @@ -784,8 +783,7 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { #else /* GRPC_LINUX_ERRQUEUE */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform"); GPR_ASSERT(0); return false; @@ -804,7 +802,7 @@ void tcp_shutdown_buffer_list(grpc_tcp* tcp) { gpr_mu_lock(&tcp->tb_mu); grpc_core::TracedBuffer::Shutdown( &tcp->tb_head, tcp->outgoing_buffer_arg, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed")); + GRPC_ERROR_CREATE_FROM_STATIC_STRING("TracedBuffer list shutdown")); gpr_mu_unlock(&tcp->tb_mu); tcp->outgoing_buffer_arg = nullptr; } @@ -820,7 +818,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { struct msghdr msg; struct iovec iov[MAX_WRITE_IOVEC]; msg_iovlen_type iov_size; - ssize_t sent_length; + ssize_t sent_length = 0; size_t sending_length; size_t trailing; size_t unwind_slice_idx; @@ -855,13 +853,19 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { msg.msg_iov = iov; msg.msg_iovlen = iov_size; msg.msg_flags = 0; + bool tried_sending_message = false; if (tcp->outgoing_buffer_arg != nullptr) { - if (!tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length, - error)) { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. */ + tcp->ts_capable = false; tcp_shutdown_buffer_list(tcp); - return true; /* something went wrong with timestamps */ + } else { + tried_sending_message = true; } - } else { + } + if (!tried_sending_message) { msg.msg_control = nullptr; msg.msg_controllen = 0; @@ -1117,6 +1121,7 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, tcp->is_first_read = true; tcp->bytes_counter = -1; tcp->socket_ts_enabled = false; + tcp->ts_capable = true; tcp->outgoing_buffer_arg = nullptr; /* paired with unref in grpc_tcp_destroy */ gpr_ref_init(&tcp->refcount, 1); diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc index 16f40b4f55..8443ee0695 100644 --- a/src/core/lib/security/context/security_context.cc +++ b/src/core/lib/security/context/security_context.cc @@ -23,6 +23,8 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/arena.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/surface/api_trace.h" #include "src/core/lib/surface/call.h" @@ -50,13 +52,11 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call, ctx = static_cast<grpc_client_security_context*>( grpc_call_context_get(call, GRPC_CONTEXT_SECURITY)); if (ctx == nullptr) { - ctx = grpc_client_security_context_create(grpc_call_get_arena(call)); - ctx->creds = grpc_call_credentials_ref(creds); + ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds); grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx, grpc_client_security_context_destroy); } else { - grpc_call_credentials_unref(ctx->creds); - ctx->creds = grpc_call_credentials_ref(creds); + ctx->creds = creds != nullptr ? creds->Ref() : nullptr; } return GRPC_CALL_OK; @@ -66,33 +66,45 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) { void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY); GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call)); if (sec_ctx == nullptr) return nullptr; - return grpc_call_is_client(call) - ? GRPC_AUTH_CONTEXT_REF( - ((grpc_client_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context client") - : GRPC_AUTH_CONTEXT_REF( - ((grpc_server_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context server"); + if (grpc_call_is_client(call)) { + auto* sc = static_cast<grpc_client_security_context*>(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context client") + .release(); + } + } else { + auto* sc = static_cast<grpc_server_security_context*>(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context server") + .release(); + } + } } void grpc_auth_context_release(grpc_auth_context* context) { GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context)); - GRPC_AUTH_CONTEXT_UNREF(context, "grpc_auth_context_unref"); + if (context == nullptr) return; + context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref"); } /* --- grpc_client_security_context --- */ grpc_client_security_context::~grpc_client_security_context() { - grpc_call_credentials_unref(creds); - GRPC_AUTH_CONTEXT_UNREF(auth_context, "client_security_context"); + auth_context.reset(DEBUG_LOCATION, "client_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } } grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena) { + gpr_arena* arena, grpc_call_credentials* creds) { return new (gpr_arena_alloc(arena, sizeof(grpc_client_security_context))) - grpc_client_security_context(); + grpc_client_security_context(creds != nullptr ? creds->Ref() : nullptr); } void grpc_client_security_context_destroy(void* ctx) { @@ -104,7 +116,7 @@ void grpc_client_security_context_destroy(void* ctx) { /* --- grpc_server_security_context --- */ grpc_server_security_context::~grpc_server_security_context() { - GRPC_AUTH_CONTEXT_UNREF(auth_context, "server_security_context"); + auth_context.reset(DEBUG_LOCATION, "server_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } @@ -126,69 +138,11 @@ void grpc_server_security_context_destroy(void* ctx) { static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr}; -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained) { - grpc_auth_context* ctx = - static_cast<grpc_auth_context*>(gpr_zalloc(sizeof(grpc_auth_context))); - gpr_ref_init(&ctx->refcount, 1); - if (chained != nullptr) { - ctx->chained = GRPC_AUTH_CONTEXT_REF(chained, "chained"); - ctx->peer_identity_property_name = - ctx->chained->peer_identity_property_name; - } - return ctx; -} - -#ifndef NDEBUG -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx, - const char* file, int line, - const char* reason) { - if (ctx == nullptr) return nullptr; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val + 1, reason); - } -#else -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx) { - if (ctx == nullptr) return nullptr; -#endif - gpr_ref(&ctx->refcount); - return ctx; -} - -#ifndef NDEBUG -void grpc_auth_context_unref(grpc_auth_context* ctx, const char* file, int line, - const char* reason) { - if (ctx == nullptr) return; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val - 1, reason); - } -#else -void grpc_auth_context_unref(grpc_auth_context* ctx) { - if (ctx == nullptr) return; -#endif - if (gpr_unref(&ctx->refcount)) { - size_t i; - GRPC_AUTH_CONTEXT_UNREF(ctx->chained, "chained"); - if (ctx->properties.array != nullptr) { - for (i = 0; i < ctx->properties.count; i++) { - grpc_auth_property_reset(&ctx->properties.array[i]); - } - gpr_free(ctx->properties.array); - } - gpr_free(ctx); - } -} - const char* grpc_auth_context_peer_identity_property_name( const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name; + return ctx->peer_identity_property_name(); } int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, @@ -204,13 +158,13 @@ int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, name != nullptr ? name : "NULL"); return 0; } - ctx->peer_identity_property_name = prop->name; + ctx->set_peer_identity_property_name(prop->name); return 1; } int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name == nullptr ? 0 : 1; + return ctx->is_authenticated(); } grpc_auth_property_iterator grpc_auth_context_property_iterator( @@ -226,16 +180,17 @@ const grpc_auth_property* grpc_auth_property_iterator_next( grpc_auth_property_iterator* it) { GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it)); if (it == nullptr || it->ctx == nullptr) return nullptr; - while (it->index == it->ctx->properties.count) { - if (it->ctx->chained == nullptr) return nullptr; - it->ctx = it->ctx->chained; + while (it->index == it->ctx->properties().count) { + if (it->ctx->chained() == nullptr) return nullptr; + it->ctx = it->ctx->chained(); it->index = 0; } if (it->name == nullptr) { - return &it->ctx->properties.array[it->index++]; + return &it->ctx->properties().array[it->index++]; } else { - while (it->index < it->ctx->properties.count) { - const grpc_auth_property* prop = &it->ctx->properties.array[it->index++]; + while (it->index < it->ctx->properties().count) { + const grpc_auth_property* prop = + &it->ctx->properties().array[it->index++]; GPR_ASSERT(prop->name != nullptr); if (strcmp(it->name, prop->name) == 0) { return prop; @@ -262,49 +217,56 @@ grpc_auth_property_iterator grpc_auth_context_peer_identity( GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx)); if (ctx == nullptr) return empty_iterator; return grpc_auth_context_find_properties_by_name( - ctx, ctx->peer_identity_property_name); + ctx, ctx->peer_identity_property_name()); } -static void ensure_auth_context_capacity(grpc_auth_context* ctx) { - if (ctx->properties.count == ctx->properties.capacity) { - ctx->properties.capacity = - GPR_MAX(ctx->properties.capacity + 8, ctx->properties.capacity * 2); - ctx->properties.array = static_cast<grpc_auth_property*>( - gpr_realloc(ctx->properties.array, - ctx->properties.capacity * sizeof(grpc_auth_property))); +void grpc_auth_context::ensure_capacity() { + if (properties_.count == properties_.capacity) { + properties_.capacity = + GPR_MAX(properties_.capacity + 8, properties_.capacity * 2); + properties_.array = static_cast<grpc_auth_property*>(gpr_realloc( + properties_.array, properties_.capacity * sizeof(grpc_auth_property))); } } +void grpc_auth_context::add_property(const char* name, const char* value, + size_t value_length) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; + prop->name = gpr_strdup(name); + prop->value = static_cast<char*>(gpr_malloc(value_length + 1)); + memcpy(prop->value, value, value_length); + prop->value[value_length] = '\0'; + prop->value_length = value_length; +} + void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name, const char* value, size_t value_length) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, " "value_length=%lu)", 6, (ctx, name, (int)value_length, (int)value_length, value, (unsigned long)value_length)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; + ctx->add_property(name, value, value_length); +} + +void grpc_auth_context::add_cstring_property(const char* name, + const char* value) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; prop->name = gpr_strdup(name); - prop->value = static_cast<char*>(gpr_malloc(value_length + 1)); - memcpy(prop->value, value, value_length); - prop->value[value_length] = '\0'; - prop->value_length = value_length; + prop->value = gpr_strdup(value); + prop->value_length = strlen(value); } void grpc_auth_context_add_cstring_property(grpc_auth_context* ctx, const char* name, const char* value) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3, (ctx, name, value)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; - prop->name = gpr_strdup(name); - prop->value = gpr_strdup(value); - prop->value_length = strlen(value); + ctx->add_cstring_property(name, value); } void grpc_auth_property_reset(grpc_auth_property* property) { @@ -314,12 +276,17 @@ void grpc_auth_property_reset(grpc_auth_property* property) { } static void auth_context_pointer_arg_destroy(void* p) { - GRPC_AUTH_CONTEXT_UNREF((grpc_auth_context*)p, "auth_context_pointer_arg"); + if (p != nullptr) { + static_cast<grpc_auth_context*>(p)->Unref(DEBUG_LOCATION, + "auth_context_pointer_arg"); + } } static void* auth_context_pointer_arg_copy(void* p) { - return GRPC_AUTH_CONTEXT_REF((grpc_auth_context*)p, - "auth_context_pointer_arg"); + auto* ctx = static_cast<grpc_auth_context*>(p); + return ctx == nullptr + ? nullptr + : ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release(); } static int auth_context_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h index e45415f63b..b43ee5e62d 100644 --- a/src/core/lib/security/context/security_context.h +++ b/src/core/lib/security/context/security_context.h @@ -21,6 +21,8 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/security/credentials/credentials.h" @@ -40,39 +42,59 @@ struct grpc_auth_property_array { size_t capacity = 0; }; -struct grpc_auth_context { - grpc_auth_context() { gpr_ref_init(&refcount, 0); } +void grpc_auth_property_reset(grpc_auth_property* property); - struct grpc_auth_context* chained = nullptr; - grpc_auth_property_array properties; - gpr_refcount refcount; - const char* peer_identity_property_name = nullptr; - grpc_pollset* pollset = nullptr; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_auth_context + : public grpc_core::RefCounted<grpc_auth_context, + grpc_core::NonPolymorphicRefCount> { + public: + explicit grpc_auth_context( + grpc_core::RefCountedPtr<grpc_auth_context> chained) + : grpc_core::RefCounted<grpc_auth_context, + grpc_core::NonPolymorphicRefCount>( + &grpc_trace_auth_context_refcount), + chained_(std::move(chained)) { + if (chained_ != nullptr) { + peer_identity_property_name_ = chained_->peer_identity_property_name_; + } + } + + ~grpc_auth_context() { + chained_.reset(DEBUG_LOCATION, "chained"); + if (properties_.array != nullptr) { + for (size_t i = 0; i < properties_.count; i++) { + grpc_auth_property_reset(&properties_.array[i]); + } + gpr_free(properties_.array); + } + } + + const grpc_auth_context* chained() const { return chained_.get(); } + const grpc_auth_property_array& properties() const { return properties_; } + + bool is_authenticated() const { + return peer_identity_property_name_ != nullptr; + } + const char* peer_identity_property_name() const { + return peer_identity_property_name_; + } + void set_peer_identity_property_name(const char* name) { + peer_identity_property_name_ = name; + } + + void ensure_capacity(); + void add_property(const char* name, const char* value, size_t value_length); + void add_cstring_property(const char* name, const char* value); + + private: + grpc_core::RefCountedPtr<grpc_auth_context> chained_; + grpc_auth_property_array properties_; + const char* peer_identity_property_name_ = nullptr; }; -/* Creation. */ -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained); - -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_AUTH_CONTEXT_REF(p, r) \ - grpc_auth_context_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) \ - grpc_auth_context_unref((p), __FILE__, __LINE__, (r)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy, - const char* file, int line, - const char* reason); -void grpc_auth_context_unref(grpc_auth_context* policy, const char* file, - int line, const char* reason); -#else -#define GRPC_AUTH_CONTEXT_REF(p, r) grpc_auth_context_ref((p)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) grpc_auth_context_unref((p)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy); -void grpc_auth_context_unref(grpc_auth_context* policy); -#endif - -void grpc_auth_property_reset(grpc_auth_property* property); - /* --- grpc_security_context_extension --- Extension to the security context that may be set in a filter and accessed @@ -88,16 +110,18 @@ struct grpc_security_context_extension { Internal client-side security context. */ struct grpc_client_security_context { - grpc_client_security_context() = default; + explicit grpc_client_security_context( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) + : creds(std::move(creds)) {} ~grpc_client_security_context(); - grpc_call_credentials* creds = nullptr; - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> creds; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; grpc_security_context_extension extension; }; grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena); + gpr_arena* arena, grpc_call_credentials* creds); void grpc_client_security_context_destroy(void* ctx); /* --- grpc_server_security_context --- @@ -108,7 +132,7 @@ struct grpc_server_security_context { grpc_server_security_context() = default; ~grpc_server_security_context(); - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; grpc_security_context_extension extension; }; diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc index 1fbef4ae0c..06546492bc 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.cc +++ b/src/core/lib/security/credentials/alts/alts_credentials.cc @@ -33,40 +33,47 @@ #define GRPC_CREDENTIALS_TYPE_ALTS "Alts" #define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal:8080" -static void alts_credentials_destruct(grpc_channel_credentials* creds) { - grpc_alts_credentials* alts_creds = - reinterpret_cast<grpc_alts_credentials*>(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); -} - -static void alts_server_credentials_destruct(grpc_server_credentials* creds) { - grpc_alts_server_credentials* alts_creds = - reinterpret_cast<grpc_alts_server_credentials*>(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); +grpc_alts_credentials::grpc_alts_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_alts_credentials::~grpc_alts_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); } -static grpc_security_status alts_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_alts_channel_security_connector_create( - creds, request_metadata_creds, target_name, sc); + this->Ref(), std::move(call_creds), target_name); } -static grpc_security_status alts_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_alts_server_security_connector_create(creds, sc); +grpc_alts_server_credentials::grpc_alts_server_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_credentials::create_security_connector() { + return grpc_alts_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable alts_credentials_vtable = { - alts_credentials_destruct, alts_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable alts_server_credentials_vtable = { - alts_server_credentials_destruct, alts_server_create_security_connector}; +grpc_alts_server_credentials::~grpc_alts_server_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); +} grpc_channel_credentials* grpc_alts_credentials_create_customized( const grpc_alts_credentials_options* options, @@ -74,17 +81,7 @@ grpc_channel_credentials* grpc_alts_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast<grpc_alts_credentials*>( - gpr_zalloc(sizeof(grpc_alts_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_alts_credentials>(options, handshaker_service_url); } grpc_server_credentials* grpc_alts_server_credentials_create_customized( @@ -93,17 +90,8 @@ grpc_server_credentials* grpc_alts_server_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast<grpc_alts_server_credentials*>( - gpr_zalloc(sizeof(grpc_alts_server_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_alts_server_credentials>(options, + handshaker_service_url); } grpc_channel_credentials* grpc_alts_credentials_create( diff --git a/src/core/lib/security/credentials/alts/alts_credentials.h b/src/core/lib/security/credentials/alts/alts_credentials.h index 810117f2be..cc6d5222b1 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.h +++ b/src/core/lib/security/credentials/alts/alts_credentials.h @@ -27,18 +27,45 @@ #include "src/core/lib/security/credentials/credentials.h" /* Main struct for grpc ALTS channel credential. */ -typedef struct grpc_alts_credentials { - grpc_channel_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_credentials; +class grpc_alts_credentials final : public grpc_channel_credentials { + public: + grpc_alts_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_credentials() override; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /* Main struct for grpc ALTS server credential. */ -typedef struct grpc_alts_server_credentials { - grpc_server_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_server_credentials; +class grpc_alts_server_credentials final : public grpc_server_credentials { + public: + grpc_alts_server_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_server_credentials() override; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /** * This method creates an ALTS channel credential object with customized diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc index b8f409260f..586bbed778 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.cc +++ b/src/core/lib/security/credentials/composite/composite_credentials.cc @@ -20,8 +20,10 @@ #include "src/core/lib/security/credentials/composite/composite_credentials.h" -#include <string.h> +#include <cstring> +#include <new> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/surface/api_trace.h" @@ -31,35 +33,44 @@ /* -- Composite call credentials. -- */ -typedef struct { +static void composite_call_metadata_cb(void* arg, grpc_error* error); + +namespace { +struct grpc_composite_call_credentials_metadata_context { + grpc_composite_call_credentials_metadata_context( + grpc_composite_call_credentials* composite_creds, + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata) + : composite_creds(composite_creds), + pollent(pollent), + auth_md_context(auth_md_context), + md_array(md_array), + on_request_metadata(on_request_metadata) { + GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb, + this, grpc_schedule_on_exec_ctx); + } + grpc_composite_call_credentials* composite_creds; - size_t creds_index; + size_t creds_index = 0; grpc_polling_entity* pollent; grpc_auth_metadata_context auth_md_context; grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_closure internal_on_request_metadata; -} grpc_composite_call_credentials_metadata_context; - -static void composite_call_destruct(grpc_call_credentials* creds) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); - for (size_t i = 0; i < c->inner.num_creds; i++) { - grpc_call_credentials_unref(c->inner.creds_array[i]); - } - gpr_free(c->inner.creds_array); -} +}; +} // namespace static void composite_call_metadata_cb(void* arg, grpc_error* error) { grpc_composite_call_credentials_metadata_context* ctx = static_cast<grpc_composite_call_credentials_metadata_context*>(arg); if (error == GRPC_ERROR_NONE) { + const grpc_composite_call_credentials::CallCredentialsList& inner = + ctx->composite_creds->inner(); /* See if we need to get some more metadata. */ - if (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, + if (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, &ctx->internal_on_request_metadata, &error)) { // Synchronous response, so call ourselves recursively. composite_call_metadata_cb(arg, error); @@ -73,29 +84,18 @@ static void composite_call_metadata_cb(void* arg, grpc_error* error) { gpr_free(ctx); } -static bool composite_call_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context auth_md_context, +bool grpc_composite_call_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, grpc_error** error) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); grpc_composite_call_credentials_metadata_context* ctx; - ctx = static_cast<grpc_composite_call_credentials_metadata_context*>( - gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context))); - ctx->composite_creds = c; - ctx->pollent = pollent; - ctx->auth_md_context = auth_md_context; - ctx->md_array = md_array; - ctx->on_request_metadata = on_request_metadata; - GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata, - composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx); + ctx = grpc_core::New<grpc_composite_call_credentials_metadata_context>( + this, pollent, auth_md_context, md_array, on_request_metadata); bool synchronous = true; - while (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, + const CallCredentialsList& inner = ctx->composite_creds->inner(); + while (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, &ctx->internal_on_request_metadata, error)) { if (*error != GRPC_ERROR_NONE) break; } else { @@ -103,46 +103,66 @@ static bool composite_call_get_request_metadata( break; } } - if (synchronous) gpr_free(ctx); + if (synchronous) grpc_core::Delete(ctx); return synchronous; } -static void composite_call_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); - for (size_t i = 0; i < c->inner.num_creds; ++i) { - grpc_call_credentials_cancel_get_request_metadata( - c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error)); +void grpc_composite_call_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + for (size_t i = 0; i < inner_.size(); ++i) { + inner_[i]->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable composite_call_credentials_vtable = { - composite_call_destruct, composite_call_get_request_metadata, - composite_call_cancel_get_request_metadata}; +static size_t get_creds_array_size(const grpc_call_credentials* creds, + bool is_composite) { + return is_composite + ? static_cast<const grpc_composite_call_credentials*>(creds) + ->inner() + .size() + : 1; +} -static grpc_call_credentials_array get_creds_array( - grpc_call_credentials** creds_addr) { - grpc_call_credentials_array result; - grpc_call_credentials* creds = *creds_addr; - result.creds_array = creds_addr; - result.num_creds = 1; - if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - result = *grpc_composite_call_credentials_get_credentials(creds); +void grpc_composite_call_credentials::push_to_inner( + grpc_core::RefCountedPtr<grpc_call_credentials> creds, bool is_composite) { + if (!is_composite) { + inner_.push_back(std::move(creds)); + return; + } + auto composite_creds = + static_cast<grpc_composite_call_credentials*>(creds.get()); + for (size_t i = 0; i < composite_creds->inner().size(); ++i) { + inner_.push_back(std::move(composite_creds->inner_[i])); } - return result; +} + +grpc_composite_call_credentials::grpc_composite_call_credentials( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) { + const bool creds1_is_composite = + strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const bool creds2_is_composite = + strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) + + get_creds_array_size(creds2.get(), creds2_is_composite); + inner_.reserve(size); + push_to_inner(std::move(creds1), creds1_is_composite); + push_to_inner(std::move(creds2), creds2_is_composite); +} + +static grpc_core::RefCountedPtr<grpc_call_credentials> +composite_call_credentials_create( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2) { + return grpc_core::MakeRefCounted<grpc_composite_call_credentials>( + std::move(creds1), std::move(creds2)); } grpc_call_credentials* grpc_composite_call_credentials_create( grpc_call_credentials* creds1, grpc_call_credentials* creds2, void* reserved) { - size_t i; - size_t creds_array_byte_size; - grpc_call_credentials_array creds1_array; - grpc_call_credentials_array creds2_array; - grpc_composite_call_credentials* c; GRPC_API_TRACE( "grpc_composite_call_credentials_create(creds1=%p, creds2=%p, " "reserved=%p)", @@ -150,120 +170,40 @@ grpc_call_credentials* grpc_composite_call_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(creds1 != nullptr); GPR_ASSERT(creds2 != nullptr); - c = static_cast<grpc_composite_call_credentials*>( - gpr_zalloc(sizeof(grpc_composite_call_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE; - c->base.vtable = &composite_call_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - creds1_array = get_creds_array(&creds1); - creds2_array = get_creds_array(&creds2); - c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds; - creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*); - c->inner.creds_array = - static_cast<grpc_call_credentials**>(gpr_zalloc(creds_array_byte_size)); - for (i = 0; i < creds1_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds1_array.creds_array[i]; - c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds); - } - for (i = 0; i < creds2_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds2_array.creds_array[i]; - c->inner.creds_array[i + creds1_array.num_creds] = - grpc_call_credentials_ref(cur_creds); - } - return &c->base; -} - -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) { - const grpc_composite_call_credentials* c = - reinterpret_cast<const grpc_composite_call_credentials*>(creds); - GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); - return &c->inner; -} -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds) { - size_t i; - if (strcmp(creds->type, type) == 0) { - if (composite_creds != nullptr) *composite_creds = nullptr; - return creds; - } else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - const grpc_call_credentials_array* inner_creds_array = - grpc_composite_call_credentials_get_credentials(creds); - for (i = 0; i < inner_creds_array->num_creds; i++) { - if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) { - if (composite_creds != nullptr) *composite_creds = creds; - return inner_creds_array->creds_array[i]; - } - } - } - return nullptr; + return composite_call_credentials_create(creds1->Ref(), creds2->Ref()) + .release(); } /* -- Composite channel credentials. -- */ -static void composite_channel_destruct(grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - grpc_channel_credentials_unref(c->inner_creds); - grpc_call_credentials_unref(c->call_creds); -} - -static grpc_security_status composite_channel_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_composite_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - grpc_security_status status = GRPC_SECURITY_ERROR; - - GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr && - c->inner_creds->vtable != nullptr && - c->inner_creds->vtable->create_security_connector != nullptr); + grpc_channel_args** new_args) { + GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr); /* If we are passed a call_creds, create a call composite to pass it downstream. */ if (call_creds != nullptr) { - grpc_call_credentials* composite_call_creds = - grpc_composite_call_credentials_create(c->call_creds, call_creds, - nullptr); - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, composite_call_creds, target, args, sc, new_args); - grpc_call_credentials_unref(composite_call_creds); + return inner_creds_->create_security_connector( + composite_call_credentials_create(call_creds_, std::move(call_creds)), + target, args, new_args); } else { - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, c->call_creds, target, args, sc, new_args); + return inner_creds_->create_security_connector(call_creds_, target, args, + new_args); } - return status; } -static grpc_channel_credentials* -composite_channel_duplicate_without_call_credentials( - grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - return grpc_channel_credentials_ref(c->inner_creds); -} - -static grpc_channel_credentials_vtable composite_channel_credentials_vtable = { - composite_channel_destruct, composite_channel_create_security_connector, - composite_channel_duplicate_without_call_credentials}; - grpc_channel_credentials* grpc_composite_channel_credentials_create( grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds, void* reserved) { - grpc_composite_channel_credentials* c = - static_cast<grpc_composite_channel_credentials*>(gpr_zalloc(sizeof(*c))); GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr && reserved == nullptr); GRPC_API_TRACE( "grpc_composite_channel_credentials_create(channel_creds=%p, " "call_creds=%p, reserved=%p)", 3, (channel_creds, call_creds, reserved)); - c->base.type = channel_creds->type; - c->base.vtable = &composite_channel_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->inner_creds = grpc_channel_credentials_ref(channel_creds); - c->call_creds = grpc_call_credentials_ref(call_creds); - return &c->base; + return grpc_core::New<grpc_composite_channel_credentials>( + channel_creds->Ref(), call_creds->Ref()); } diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h index a952ad57f1..7a1c7d5e42 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.h +++ b/src/core/lib/security/credentials/composite/composite_credentials.h @@ -21,39 +21,75 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials** creds_array; - size_t num_creds; -} grpc_call_credentials_array; +/* -- Composite channel credentials. -- */ -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials( - grpc_call_credentials* composite_creds); +class grpc_composite_channel_credentials : public grpc_channel_credentials { + public: + grpc_composite_channel_credentials( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds) + : grpc_channel_credentials(channel_creds->type()), + inner_creds_(std::move(channel_creds)), + call_creds_(std::move(call_creds)) {} -/* Returns creds if creds is of the specified type or the inner creds of the - specified type (if found), if the creds is of type COMPOSITE. - If composite_creds is not NULL, *composite_creds will point to creds if of - type COMPOSITE in case of success. */ -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds); + ~grpc_composite_channel_credentials() override = default; -/* -- Composite channel credentials. -- */ + grpc_core::RefCountedPtr<grpc_channel_credentials> + duplicate_without_call_credentials() override { + return inner_creds_; + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* inner_creds; - grpc_call_credentials* call_creds; -} grpc_composite_channel_credentials; + const grpc_channel_credentials* inner_creds() const { + return inner_creds_.get(); + } + const grpc_call_credentials* call_creds() const { return call_creds_.get(); } + grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); } + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> inner_creds_; + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds_; +}; /* -- Composite call credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_call_credentials_array inner; -} grpc_composite_call_credentials; +class grpc_composite_call_credentials : public grpc_call_credentials { + public: + using CallCredentialsList = + grpc_core::InlinedVector<grpc_core::RefCountedPtr<grpc_call_credentials>, + 2>; + + grpc_composite_call_credentials( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2); + ~grpc_composite_call_credentials() override = default; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const CallCredentialsList& inner() const { return inner_; } + + private: + void push_to_inner(grpc_core::RefCountedPtr<grpc_call_credentials> creds, + bool is_composite); + + CallCredentialsList inner_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_COMPOSITE_COMPOSITE_CREDENTIALS_H \ */ diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc index c43cb440eb..90452d68d6 100644 --- a/src/core/lib/security/credentials/credentials.cc +++ b/src/core/lib/security/credentials/credentials.cc @@ -39,120 +39,24 @@ /* -- Common. -- */ -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds) { - grpc_credentials_metadata_request* r = - static_cast<grpc_credentials_metadata_request*>( - gpr_zalloc(sizeof(grpc_credentials_metadata_request))); - r->creds = grpc_call_credentials_ref(creds); - return r; -} - -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r) { - grpc_call_credentials_unref(r->creds); - grpc_http_response_destroy(&r->response); - gpr_free(r); -} - -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_channel_credentials_unref(grpc_channel_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } -} - void grpc_channel_credentials_release(grpc_channel_credentials* creds) { GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_channel_credentials_unref(creds); -} - -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_call_credentials_unref(grpc_call_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } + if (creds) creds->Unref(); } void grpc_call_credentials_release(grpc_call_credentials* creds) { GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials_unref(creds); -} - -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - if (creds == nullptr || creds->vtable->get_request_metadata == nullptr) { - return true; - } - return creds->vtable->get_request_metadata(creds, pollent, context, md_array, - on_request_metadata, error); -} - -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - if (creds == nullptr || - creds->vtable->cancel_get_request_metadata == nullptr) { - return; - } - creds->vtable->cancel_get_request_metadata(creds, md_array, error); -} - -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* channel_creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args) { - *new_args = nullptr; - if (channel_creds == nullptr) { - return GRPC_SECURITY_ERROR; - } - GPR_ASSERT(channel_creds->vtable->create_security_connector != nullptr); - return channel_creds->vtable->create_security_connector( - channel_creds, nullptr, target, args, sc, new_args); -} - -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* channel_creds) { - if (channel_creds != nullptr && channel_creds->vtable != nullptr && - channel_creds->vtable->duplicate_without_call_credentials != nullptr) { - return channel_creds->vtable->duplicate_without_call_credentials( - channel_creds); - } else { - return grpc_channel_credentials_ref(channel_creds); - } + if (creds) creds->Unref(); } static void credentials_pointer_arg_destroy(void* p) { - grpc_channel_credentials_unref(static_cast<grpc_channel_credentials*>(p)); + static_cast<grpc_channel_credentials*>(p)->Unref(); } static void* credentials_pointer_arg_copy(void* p) { - return grpc_channel_credentials_ref( - static_cast<grpc_channel_credentials*>(p)); + return static_cast<grpc_channel_credentials*>(p)->Ref().release(); } static int credentials_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } @@ -191,63 +95,35 @@ grpc_channel_credentials* grpc_channel_credentials_find_in_args( return nullptr; } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_server_credentials_unref(grpc_server_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - gpr_free(creds); - } -} - void grpc_server_credentials_release(grpc_server_credentials* creds) { GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_server_credentials_unref(creds); + if (creds) creds->Unref(); } -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - if (creds == nullptr || creds->vtable->create_security_connector == nullptr) { - gpr_log(GPR_ERROR, "Server credentials cannot create security context."); - return GRPC_SECURITY_ERROR; - } - return creds->vtable->create_security_connector(creds, sc); -} - -void grpc_server_credentials_set_auth_metadata_processor( - grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { +void grpc_server_credentials::set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor) { GRPC_API_TRACE( "grpc_server_credentials_set_auth_metadata_processor(" "creds=%p, " "processor=grpc_auth_metadata_processor { process: %p, state: %p })", - 3, (creds, (void*)(intptr_t)processor.process, processor.state)); - if (creds == nullptr) return; - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - creds->processor = processor; + 3, (this, (void*)(intptr_t)processor.process, processor.state)); + DestroyProcessor(); + processor_ = processor; +} + +void grpc_server_credentials_set_auth_metadata_processor( + grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { + GPR_DEBUG_ASSERT(creds != nullptr); + creds->set_auth_metadata_processor(processor); } static void server_credentials_pointer_arg_destroy(void* p) { - grpc_server_credentials_unref(static_cast<grpc_server_credentials*>(p)); + static_cast<grpc_server_credentials*>(p)->Unref(); } static void* server_credentials_pointer_arg_copy(void* p) { - return grpc_server_credentials_ref(static_cast<grpc_server_credentials*>(p)); + return static_cast<grpc_server_credentials*>(p)->Ref().release(); } static int server_credentials_pointer_cmp(void* a, void* b) { diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index 3878958b38..4091ef3dfb 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -26,6 +26,7 @@ #include <grpc/support/sync.h> #include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/polling_entity.h" @@ -90,44 +91,46 @@ void grpc_override_well_known_credentials_path_getter( #define GRPC_ARG_CHANNEL_CREDENTIALS "grpc.channel_credentials" -typedef struct { - void (*destruct)(grpc_channel_credentials* c); - - grpc_security_status (*create_security_connector)( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_channel_credentials + : grpc_core::RefCounted<grpc_channel_credentials> { + public: + explicit grpc_channel_credentials(const char* type) : type_(type) {} + virtual ~grpc_channel_credentials() = default; + + // Creates a security connector for the channel. May also create new channel + // args for the channel to be used in place of the passed in const args if + // returned non NULL. In that case the caller is responsible for destroying + // new_args after channel creation. + virtual grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args); - - grpc_channel_credentials* (*duplicate_without_call_credentials)( - grpc_channel_credentials* c); -} grpc_channel_credentials_vtable; - -struct grpc_channel_credentials { - const grpc_channel_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; + grpc_channel_args** new_args) { + // Tell clang-tidy that call_creds cannot be passed as const-ref. + call_creds.reset(); + GRPC_ABSTRACT; + } + + // Creates a version of the channel credentials without any attached call + // credentials. This can be used in order to open a channel to a non-trusted + // gRPC load balancer. + virtual grpc_core::RefCountedPtr<grpc_channel_credentials> + duplicate_without_call_credentials() { + // By default we just increment the refcount. + return Ref(); + } + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds); -void grpc_channel_credentials_unref(grpc_channel_credentials* creds); - -/* Creates a security connector for the channel. May also create new channel - args for the channel to be used in place of the passed in const args if - returned non NULL. In that case the caller is responsible for destroying - new_args after channel creation. */ -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args); - -/* Creates a version of the channel credentials without any attached call - credentials. This can be used in order to open a channel to a non-trusted - gRPC load balancer. */ -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* creds); - /* Util to encapsulate the channel credentials in a channel arg. */ grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials); @@ -158,44 +161,39 @@ void grpc_credentials_mdelem_array_destroy(grpc_credentials_mdelem_array* list); /* --- grpc_call_credentials. --- */ -typedef struct { - void (*destruct)(grpc_call_credentials* c); - bool (*get_request_metadata)(grpc_call_credentials* c, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error); - void (*cancel_get_request_metadata)(grpc_call_credentials* c, - grpc_credentials_mdelem_array* md_array, - grpc_error* error); -} grpc_call_credentials_vtable; - -struct grpc_call_credentials { - const grpc_call_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_call_credentials + : public grpc_core::RefCounted<grpc_call_credentials> { + public: + explicit grpc_call_credentials(const char* type) : type_(type) {} + virtual ~grpc_call_credentials() = default; + + // Returns true if completed synchronously, in which case \a error will + // be set to indicate the result. Otherwise, \a on_request_metadata will + // be invoked asynchronously when complete. \a md_array will be populated + // with the resulting metadata once complete. + virtual bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) GRPC_ABSTRACT; + + // Cancels a pending asynchronous operation started by + // grpc_call_credentials_get_request_metadata() with the corresponding + // value of \a md_array. + virtual void cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) GRPC_ABSTRACT; + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds); -void grpc_call_credentials_unref(grpc_call_credentials* creds); - -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_request_metadata will -/// be invoked asynchronously when complete. \a md_array will be populated -/// with the resulting metadata once complete. -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error); - -/// Cancels a pending asynchronous operation started by -/// grpc_call_credentials_get_request_metadata() with the corresponding -/// value of \a md_array. -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error); - /* Metadata-only credentials with the specified key and value where asynchronicity can be simulated for testing. */ grpc_call_credentials* grpc_md_only_test_credentials_create( @@ -203,26 +201,40 @@ grpc_call_credentials* grpc_md_only_test_credentials_create( /* --- grpc_server_credentials. --- */ -typedef struct { - void (*destruct)(grpc_server_credentials* c); - grpc_security_status (*create_security_connector)( - grpc_server_credentials* c, grpc_server_security_connector** sc); -} grpc_server_credentials_vtable; - -struct grpc_server_credentials { - const grpc_server_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; - grpc_auth_metadata_processor processor; -}; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_server_credentials + : public grpc_core::RefCounted<grpc_server_credentials> { + public: + explicit grpc_server_credentials(const char* type) : type_(type) {} -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc); + virtual ~grpc_server_credentials() { DestroyProcessor(); } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds); + virtual grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() GRPC_ABSTRACT; -void grpc_server_credentials_unref(grpc_server_credentials* creds); + const char* type() const { return type_; } + + const grpc_auth_metadata_processor& auth_metadata_processor() const { + return processor_; + } + void set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor); + + GRPC_ABSTRACT_BASE_CLASS + + private: + void DestroyProcessor() { + if (processor_.destroy != nullptr && processor_.state != nullptr) { + processor_.destroy(processor_.state); + } + } + + const char* type_; + grpc_auth_metadata_processor processor_ = + grpc_auth_metadata_processor(); // Zero-initialize the C struct. +}; #define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials" @@ -233,15 +245,27 @@ grpc_server_credentials* grpc_find_server_credentials_in_args( /* -- Credentials Metadata Request. -- */ -typedef struct { - grpc_call_credentials* creds; +struct grpc_credentials_metadata_request { + explicit grpc_credentials_metadata_request( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) + : creds(std::move(creds)) {} + ~grpc_credentials_metadata_request() { + grpc_http_response_destroy(&response); + } + + grpc_core::RefCountedPtr<grpc_call_credentials> creds; grpc_http_response response; -} grpc_credentials_metadata_request; +}; -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds); +inline grpc_credentials_metadata_request* +grpc_credentials_metadata_request_create( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) { + return grpc_core::New<grpc_credentials_metadata_request>(std::move(creds)); +} -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r); +inline void grpc_credentials_metadata_request_destroy( + grpc_credentials_metadata_request* r) { + grpc_core::Delete(r); +} #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index d3e0e8c816..337dd7679f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -33,49 +33,45 @@ /* -- Fake transport security credentials. -- */ -static grpc_security_status fake_transport_security_create_security_connector( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, - const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - *sc = - grpc_fake_channel_security_connector_create(c, call_creds, target, args); - return GRPC_SECURITY_OK; -} - -static grpc_security_status -fake_transport_security_server_create_security_connector( - grpc_server_credentials* c, grpc_server_security_connector** sc) { - *sc = grpc_fake_server_security_connector_create(c); - return GRPC_SECURITY_OK; -} +namespace { +class grpc_fake_channel_credentials final : public grpc_channel_credentials { + public: + grpc_fake_channel_credentials() + : grpc_channel_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_channel_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override { + return grpc_fake_channel_security_connector_create( + this->Ref(), std::move(call_creds), target, args); + } +}; + +class grpc_fake_server_credentials final : public grpc_server_credentials { + public: + grpc_fake_server_credentials() + : grpc_server_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_server_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override { + return grpc_fake_server_security_connector_create(this->Ref()); + } +}; +} // namespace -static grpc_channel_credentials_vtable - fake_transport_security_credentials_vtable = { - nullptr, fake_transport_security_create_security_connector, nullptr}; - -static grpc_server_credentials_vtable - fake_transport_security_server_credentials_vtable = { - nullptr, fake_transport_security_server_create_security_connector}; - -grpc_channel_credentials* grpc_fake_transport_security_credentials_create( - void) { - grpc_channel_credentials* c = static_cast<grpc_channel_credentials*>( - gpr_zalloc(sizeof(grpc_channel_credentials))); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - c->vtable = &fake_transport_security_credentials_vtable; - gpr_ref_init(&c->refcount, 1); - return c; +grpc_channel_credentials* grpc_fake_transport_security_credentials_create() { + return grpc_core::New<grpc_fake_channel_credentials>(); } -grpc_server_credentials* grpc_fake_transport_security_server_credentials_create( - void) { - grpc_server_credentials* c = static_cast<grpc_server_credentials*>( - gpr_malloc(sizeof(grpc_server_credentials))); - memset(c, 0, sizeof(grpc_server_credentials)); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - gpr_ref_init(&c->refcount, 1); - c->vtable = &fake_transport_security_server_credentials_vtable; - return c; +grpc_server_credentials* +grpc_fake_transport_security_server_credentials_create() { + return grpc_core::New<grpc_fake_server_credentials>(); } grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) { @@ -92,46 +88,25 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only test credentials. -- */ -static void md_only_test_destruct(grpc_call_credentials* creds) { - grpc_md_only_test_credentials* c = - reinterpret_cast<grpc_md_only_test_credentials*>(creds); - GRPC_MDELEM_UNREF(c->md); -} - -static bool md_only_test_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_md_only_test_credentials* c = - reinterpret_cast<grpc_md_only_test_credentials*>(creds); - grpc_credentials_mdelem_array_add(md_array, c->md); - if (c->is_async) { +bool grpc_md_only_test_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, md_); + if (is_async_) { GRPC_CLOSURE_SCHED(on_request_metadata, GRPC_ERROR_NONE); return false; } return true; } -static void md_only_test_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_md_only_test_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable md_only_test_vtable = { - md_only_test_destruct, md_only_test_get_request_metadata, - md_only_test_cancel_get_request_metadata}; - grpc_call_credentials* grpc_md_only_test_credentials_create( const char* md_key, const char* md_value, bool is_async) { - grpc_md_only_test_credentials* c = - static_cast<grpc_md_only_test_credentials*>( - gpr_zalloc(sizeof(grpc_md_only_test_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &md_only_test_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->md = grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), - grpc_slice_from_copied_string(md_value)); - c->is_async = is_async; - return &c->base; + return grpc_core::New<grpc_md_only_test_credentials>(md_key, md_value, + is_async); } diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index e89e6e24cc..b7f6a1909f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -55,10 +55,28 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only Test credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_mdelem md; - bool is_async; -} grpc_md_only_test_credentials; +class grpc_md_only_test_credentials : public grpc_call_credentials { + public: + grpc_md_only_test_credentials(const char* md_key, const char* md_value, + bool is_async) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + md_(grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), + grpc_slice_from_copied_string(md_value))), + is_async_(is_async) {} + ~grpc_md_only_test_credentials() override { GRPC_MDELEM_UNREF(md_); } + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem md_; + bool is_async_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_FAKE_FAKE_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc index 0674540d01..a86a17d586 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.cc +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/load_file.h" @@ -72,20 +73,11 @@ typedef struct { grpc_http_response response; } metadata_server_detector; -static void google_default_credentials_destruct( - grpc_channel_credentials* creds) { - grpc_google_default_channel_credentials* c = - reinterpret_cast<grpc_google_default_channel_credentials*>(creds); - grpc_channel_credentials_unref(c->alts_creds); - grpc_channel_credentials_unref(c->ssl_creds); -} - -static grpc_security_status google_default_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_google_default_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_google_default_channel_credentials* c = - reinterpret_cast<grpc_google_default_channel_credentials*>(creds); + grpc_channel_args** new_args) { bool is_grpclb_load_balancer = grpc_channel_arg_get_bool( grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER), false); @@ -95,22 +87,22 @@ static grpc_security_status google_default_create_security_connector( false); bool use_alts = is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer; - grpc_security_status status = GRPC_SECURITY_ERROR; /* Return failure if ALTS is selected but not running on GCE. */ if (use_alts && !g_is_on_gce) { gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE."); - goto end; + return nullptr; } - status = use_alts ? c->alts_creds->vtable->create_security_connector( - c->alts_creds, call_creds, target, args, sc, new_args) - : c->ssl_creds->vtable->create_security_connector( - c->ssl_creds, call_creds, target, args, sc, new_args); -/* grpclb-specific channel args are removed from the channel args set - * to ensure backends and fallback adresses will have the same set of channel - * args. By doing that, it guarantees the connections to backends will not be - * torn down and re-connected when switching in and out of fallback mode. - */ -end: + + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + use_alts ? alts_creds_->create_security_connector(call_creds, target, + args, new_args) + : ssl_creds_->create_security_connector(call_creds, target, args, + new_args); + /* grpclb-specific channel args are removed from the channel args set + * to ensure backends and fallback adresses will have the same set of channel + * args. By doing that, it guarantees the connections to backends will not be + * torn down and re-connected when switching in and out of fallback mode. + */ if (use_alts) { static const char* args_to_remove[] = { GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER, @@ -119,13 +111,9 @@ end: *new_args = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0); } - return status; + return sc; } -static grpc_channel_credentials_vtable google_default_credentials_vtable = { - google_default_credentials_destruct, - google_default_create_security_connector, nullptr}; - static void on_metadata_server_detection_http_response(void* user_data, grpc_error* error) { metadata_server_detector* detector = @@ -215,11 +203,11 @@ static int is_metadata_server_reachable() { /* Takes ownership of creds_path if not NULL. */ static grpc_error* create_default_creds_from_path( - char* creds_path, grpc_call_credentials** creds) { + char* creds_path, grpc_core::RefCountedPtr<grpc_call_credentials>* creds) { grpc_json* json = nullptr; grpc_auth_json_key key; grpc_auth_refresh_token token; - grpc_call_credentials* result = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> result; grpc_slice creds_data = grpc_empty_slice(); grpc_error* error = GRPC_ERROR_NONE; if (creds_path == nullptr) { @@ -276,9 +264,9 @@ end: return error; } -grpc_channel_credentials* grpc_google_default_credentials_create(void) { +grpc_channel_credentials* grpc_google_default_credentials_create() { grpc_channel_credentials* result = nullptr; - grpc_call_credentials* call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds; grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Failed to create Google credentials"); grpc_error* err; @@ -316,7 +304,8 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { gpr_mu_unlock(&g_state_mu); if (g_metadata_server_available) { - call_creds = grpc_google_compute_engine_credentials_create(nullptr); + call_creds = grpc_core::RefCountedPtr<grpc_call_credentials>( + grpc_google_compute_engine_credentials_create(nullptr)); if (call_creds == nullptr) { error = grpc_error_add_child( error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( @@ -327,23 +316,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { end: if (call_creds != nullptr) { /* Create google default credentials. */ - auto creds = static_cast<grpc_google_default_channel_credentials*>( - gpr_zalloc(sizeof(grpc_google_default_channel_credentials))); - creds->base.vtable = &google_default_credentials_vtable; - creds->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT; - gpr_ref_init(&creds->base.refcount, 1); - creds->ssl_creds = + grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); - GPR_ASSERT(creds->ssl_creds != nullptr); + GPR_ASSERT(ssl_creds != nullptr); grpc_alts_credentials_options* options = grpc_alts_credentials_client_options_create(); - creds->alts_creds = grpc_alts_credentials_create(options); + grpc_channel_credentials* alts_creds = + grpc_alts_credentials_create(options); grpc_alts_credentials_options_destroy(options); - result = grpc_composite_channel_credentials_create(&creds->base, call_creds, - nullptr); + auto creds = + grpc_core::MakeRefCounted<grpc_google_default_channel_credentials>( + alts_creds != nullptr ? alts_creds->Ref() : nullptr, + ssl_creds != nullptr ? ssl_creds->Ref() : nullptr); + if (ssl_creds) ssl_creds->Unref(); + if (alts_creds) alts_creds->Unref(); + result = grpc_composite_channel_credentials_create( + creds.get(), call_creds.get(), nullptr); GPR_ASSERT(result != nullptr); - grpc_channel_credentials_unref(&creds->base); - grpc_call_credentials_unref(call_creds); } else { gpr_log(GPR_ERROR, "Could not create google default credentials: %s", grpc_error_string(error)); diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.h b/src/core/lib/security/credentials/google_default/google_default_credentials.h index b9e2efb04f..bf00f7285a 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.h +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.h @@ -21,6 +21,7 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" #define GRPC_GOOGLE_CLOUD_SDK_CONFIG_DIRECTORY "gcloud" @@ -39,11 +40,33 @@ "/" GRPC_GOOGLE_WELL_KNOWN_CREDENTIALS_FILE #endif -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* alts_creds; - grpc_channel_credentials* ssl_creds; -} grpc_google_default_channel_credentials; +class grpc_google_default_channel_credentials + : public grpc_channel_credentials { + public: + grpc_google_default_channel_credentials( + grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds, + grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT), + alts_creds_(std::move(alts_creds)), + ssl_creds_(std::move(ssl_creds)) {} + + ~grpc_google_default_channel_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_channel_credentials* alts_creds() const { + return alts_creds_.get(); + } + const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); } + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds_; + grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds_; +}; namespace grpc_core { namespace internal { diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc index 5d92fa88c4..5cd561f676 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.cc +++ b/src/core/lib/security/credentials/iam/iam_credentials.cc @@ -22,6 +22,7 @@ #include <string.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include <grpc/support/alloc.h> @@ -29,32 +30,37 @@ #include <grpc/support/string_util.h> #include <grpc/support/sync.h> -static void iam_destruct(grpc_call_credentials* creds) { - grpc_google_iam_credentials* c = - reinterpret_cast<grpc_google_iam_credentials*>(creds); - grpc_credentials_mdelem_array_destroy(&c->md_array); +grpc_google_iam_credentials::~grpc_google_iam_credentials() { + grpc_credentials_mdelem_array_destroy(&md_array_); } -static bool iam_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_google_iam_credentials* c = - reinterpret_cast<grpc_google_iam_credentials*>(creds); - grpc_credentials_mdelem_array_append(md_array, &c->md_array); +bool grpc_google_iam_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_append(md_array, &md_array_); return true; } -static void iam_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_google_iam_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable iam_vtable = { - iam_destruct, iam_get_request_metadata, iam_cancel_get_request_metadata}; +grpc_google_iam_credentials::grpc_google_iam_credentials( + const char* token, const char* authority_selector) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM) { + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), + grpc_slice_from_copied_string(token)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); + md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), + grpc_slice_from_copied_string(authority_selector)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); +} grpc_call_credentials* grpc_google_iam_credentials_create( const char* token, const char* authority_selector, void* reserved) { @@ -66,21 +72,7 @@ grpc_call_credentials* grpc_google_iam_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(token != nullptr); GPR_ASSERT(authority_selector != nullptr); - grpc_google_iam_credentials* c = - static_cast<grpc_google_iam_credentials*>(gpr_zalloc(sizeof(*c))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_IAM; - c->base.vtable = &iam_vtable; - gpr_ref_init(&c->base.refcount, 1); - grpc_mdelem md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), - grpc_slice_from_copied_string(token)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), - grpc_slice_from_copied_string(authority_selector)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - - return &c->base; + return grpc_core::MakeRefCounted<grpc_google_iam_credentials>( + token, authority_selector) + .release(); } diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h index a45710fe0f..36f5ee8930 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.h +++ b/src/core/lib/security/credentials/iam/iam_credentials.h @@ -23,9 +23,23 @@ #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials base; - grpc_credentials_mdelem_array md_array; -} grpc_google_iam_credentials; +class grpc_google_iam_credentials : public grpc_call_credentials { + public: + grpc_google_iam_credentials(const char* token, + const char* authority_selector); + ~grpc_google_iam_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_credentials_mdelem_array md_array_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_IAM_IAM_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc index 05c08a68b0..f2591a1ea5 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.cc +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc @@ -23,6 +23,8 @@ #include <inttypes.h> #include <string.h> +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include <grpc/support/alloc.h> @@ -30,71 +32,66 @@ #include <grpc/support/string_util.h> #include <grpc/support/sync.h> -static void jwt_reset_cache(grpc_service_account_jwt_access_credentials* c) { - GRPC_MDELEM_UNREF(c->cached.jwt_md); - c->cached.jwt_md = GRPC_MDNULL; - if (c->cached.service_url != nullptr) { - gpr_free(c->cached.service_url); - c->cached.service_url = nullptr; +void grpc_service_account_jwt_access_credentials::reset_cache() { + GRPC_MDELEM_UNREF(cached_.jwt_md); + cached_.jwt_md = GRPC_MDNULL; + if (cached_.service_url != nullptr) { + gpr_free(cached_.service_url); + cached_.service_url = nullptr; } - c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); + cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); } -static void jwt_destruct(grpc_call_credentials* creds) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds); - grpc_auth_json_key_destruct(&c->key); - jwt_reset_cache(c); - gpr_mu_destroy(&c->cache_mu); +grpc_service_account_jwt_access_credentials:: + ~grpc_service_account_jwt_access_credentials() { + grpc_auth_json_key_destruct(&key_); + reset_cache(); + gpr_mu_destroy(&cache_mu_); } -static bool jwt_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds); +bool grpc_service_account_jwt_access_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { gpr_timespec refresh_threshold = gpr_time_from_seconds( GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN); /* See if we can return a cached jwt. */ grpc_mdelem jwt_md = GRPC_MDNULL; { - gpr_mu_lock(&c->cache_mu); - if (c->cached.service_url != nullptr && - strcmp(c->cached.service_url, context.service_url) == 0 && - !GRPC_MDISNULL(c->cached.jwt_md) && - (gpr_time_cmp(gpr_time_sub(c->cached.jwt_expiration, - gpr_now(GPR_CLOCK_REALTIME)), - refresh_threshold) > 0)) { - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + gpr_mu_lock(&cache_mu_); + if (cached_.service_url != nullptr && + strcmp(cached_.service_url, context.service_url) == 0 && + !GRPC_MDISNULL(cached_.jwt_md) && + (gpr_time_cmp( + gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)), + refresh_threshold) > 0)) { + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (GRPC_MDISNULL(jwt_md)) { char* jwt = nullptr; /* Generate a new jwt. */ - gpr_mu_lock(&c->cache_mu); - jwt_reset_cache(c); - jwt = grpc_jwt_encode_and_sign(&c->key, context.service_url, - c->jwt_lifetime, nullptr); + gpr_mu_lock(&cache_mu_); + reset_cache(); + jwt = grpc_jwt_encode_and_sign(&key_, context.service_url, jwt_lifetime_, + nullptr); if (jwt != nullptr) { char* md_value; gpr_asprintf(&md_value, "Bearer %s", jwt); gpr_free(jwt); - c->cached.jwt_expiration = - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c->jwt_lifetime); - c->cached.service_url = gpr_strdup(context.service_url); - c->cached.jwt_md = grpc_mdelem_from_slices( + cached_.jwt_expiration = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_); + cached_.service_url = gpr_strdup(context.service_url); + cached_.jwt_md = grpc_mdelem_from_slices( grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), grpc_slice_from_copied_string(md_value)); gpr_free(md_value); - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (!GRPC_MDISNULL(jwt_md)) { @@ -106,29 +103,15 @@ static bool jwt_get_request_metadata(grpc_call_credentials* creds, return true; } -static void jwt_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable jwt_vtable = { - jwt_destruct, jwt_get_request_metadata, jwt_cancel_get_request_metadata}; - -grpc_call_credentials* -grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key key, gpr_timespec token_lifetime) { - grpc_service_account_jwt_access_credentials* c; - if (!grpc_auth_json_key_is_valid(&key)) { - gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); - return nullptr; - } - c = static_cast<grpc_service_account_jwt_access_credentials*>( - gpr_zalloc(sizeof(grpc_service_account_jwt_access_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_JWT; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &jwt_vtable; - c->key = key; +grpc_service_account_jwt_access_credentials:: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) { gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime(); if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) { gpr_log(GPR_INFO, @@ -136,10 +119,20 @@ grpc_service_account_jwt_access_credentials_create_from_auth_json_key( static_cast<int>(max_token_lifetime.tv_sec)); token_lifetime = grpc_max_auth_token_lifetime(); } - c->jwt_lifetime = token_lifetime; - gpr_mu_init(&c->cache_mu); - jwt_reset_cache(c); - return &c->base; + jwt_lifetime_ = token_lifetime; + gpr_mu_init(&cache_mu_); + reset_cache(); +} + +grpc_core::RefCountedPtr<grpc_call_credentials> +grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key key, gpr_timespec token_lifetime) { + if (!grpc_auth_json_key_is_valid(&key)) { + gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); + return nullptr; + } + return grpc_core::MakeRefCounted<grpc_service_account_jwt_access_credentials>( + key, token_lifetime); } static char* redact_private_key(const char* json_key) { @@ -182,9 +175,7 @@ grpc_call_credentials* grpc_service_account_jwt_access_credentials_create( } GPR_ASSERT(reserved == nullptr); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials* creds = - grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key_create_from_string(json_key), token_lifetime); - - return creds; + return grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key_create_from_string(json_key), token_lifetime) + .release(); } diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h index 5c3d34aa56..5af909f44d 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.h +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h @@ -24,25 +24,44 @@ #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/jwt/json_token.h" -typedef struct { - grpc_call_credentials base; +class grpc_service_account_jwt_access_credentials + : public grpc_call_credentials { + public: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime); + ~grpc_service_account_jwt_access_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; } + const grpc_auth_json_key& key() const { return key_; } + + private: + void reset_cache(); // Have a simple cache for now with just 1 entry. We could have a map based on // the service_url for a more sophisticated one. - gpr_mu cache_mu; + gpr_mu cache_mu_; struct { - grpc_mdelem jwt_md; - char* service_url; + grpc_mdelem jwt_md = GRPC_MDNULL; + char* service_url = nullptr; gpr_timespec jwt_expiration; - } cached; + } cached_; - grpc_auth_json_key key; - gpr_timespec jwt_lifetime; -} grpc_service_account_jwt_access_credentials; + grpc_auth_json_key key_; + gpr_timespec jwt_lifetime_; +}; // Private constructor for jwt credentials from an already parsed json key. // Takes ownership of the key. -grpc_call_credentials* +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_service_account_jwt_access_credentials_create_from_auth_json_key( grpc_auth_json_key key, gpr_timespec token_lifetime); diff --git a/src/core/lib/security/credentials/jwt/jwt_verifier.cc b/src/core/lib/security/credentials/jwt/jwt_verifier.cc index c7d1b36ff0..cdef0f322a 100644 --- a/src/core/lib/security/credentials/jwt/jwt_verifier.cc +++ b/src/core/lib/security/credentials/jwt/jwt_verifier.cc @@ -31,7 +31,9 @@ #include <grpc/support/sync.h> extern "C" { +#include <openssl/bn.h> #include <openssl/pem.h> +#include <openssl/rsa.h> } #include "src/core/lib/gpr/string.h" diff --git a/src/core/lib/security/credentials/local/local_credentials.cc b/src/core/lib/security/credentials/local/local_credentials.cc index 3ccfa2b908..6f6f95a34a 100644 --- a/src/core/lib/security/credentials/local/local_credentials.cc +++ b/src/core/lib/security/credentials/local/local_credentials.cc @@ -29,49 +29,36 @@ #define GRPC_CREDENTIALS_TYPE_LOCAL "Local" -static void local_credentials_destruct(grpc_channel_credentials* creds) {} - -static void local_server_credentials_destruct(grpc_server_credentials* creds) {} - -static grpc_security_status local_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_local_channel_security_connector_create( - creds, request_metadata_creds, args, target_name, sc); + this->Ref(), std::move(request_metadata_creds), args, target_name); } -static grpc_security_status local_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_local_server_security_connector_create(creds, sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_credentials::create_security_connector() { + return grpc_local_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable local_credentials_vtable = { - local_credentials_destruct, local_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable local_server_credentials_vtable = { - local_server_credentials_destruct, local_server_create_security_connector}; +grpc_local_credentials::grpc_local_credentials( + grpc_local_connect_type connect_type) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} grpc_channel_credentials* grpc_local_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast<grpc_local_credentials*>( - gpr_zalloc(sizeof(grpc_local_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_local_credentials>(connect_type); } +grpc_local_server_credentials::grpc_local_server_credentials( + grpc_local_connect_type connect_type) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} + grpc_server_credentials* grpc_local_server_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast<grpc_local_server_credentials*>( - gpr_zalloc(sizeof(grpc_local_server_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_local_server_credentials>(connect_type); } diff --git a/src/core/lib/security/credentials/local/local_credentials.h b/src/core/lib/security/credentials/local/local_credentials.h index 47358b04bc..60a8a4f64c 100644 --- a/src/core/lib/security/credentials/local/local_credentials.h +++ b/src/core/lib/security/credentials/local/local_credentials.h @@ -25,16 +25,37 @@ #include "src/core/lib/security/credentials/credentials.h" -/* Main struct for grpc local channel credential. */ -typedef struct grpc_local_credentials { - grpc_channel_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_credentials; - -/* Main struct for grpc local server credential. */ -typedef struct grpc_local_server_credentials { - grpc_server_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_server_credentials; +/* Main class for grpc local channel credential. */ +class grpc_local_credentials final : public grpc_channel_credentials { + public: + explicit grpc_local_credentials(grpc_local_connect_type connect_type); + ~grpc_local_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; + +/* Main class for grpc local server credential. */ +class grpc_local_server_credentials final : public grpc_server_credentials { + public: + explicit grpc_local_server_credentials(grpc_local_connect_type connect_type); + ~grpc_local_server_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_LOCAL_LOCAL_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc index 44b093557f..ad63b01e75 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc @@ -22,6 +22,7 @@ #include <string.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/util/json_util.h" #include "src/core/lib/surface/api_trace.h" @@ -105,13 +106,12 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) { // Oauth2 Token Fetcher credentials. // -static void oauth2_token_fetcher_destruct(grpc_call_credentials* creds) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); - GRPC_MDELEM_UNREF(c->access_token_md); - gpr_mu_destroy(&c->mu); - grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&c->pollent)); - grpc_httpcli_context_destroy(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials:: + ~grpc_oauth2_token_fetcher_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); + gpr_mu_destroy(&mu_); + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); + grpc_httpcli_context_destroy(&httpcli_context_); } grpc_credentials_status @@ -209,25 +209,29 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, grpc_credentials_metadata_request* r = static_cast<grpc_credentials_metadata_request*>(user_data); grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds); + reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds.get()); + c->on_http_response(r, error); +} + +void grpc_oauth2_token_fetcher_credentials::on_http_response( + grpc_credentials_metadata_request* r, grpc_error* error) { grpc_mdelem access_token_md = GRPC_MDNULL; grpc_millis token_lifetime; grpc_credentials_status status = grpc_oauth2_token_fetcher_credentials_parse_server_response( &r->response, &access_token_md, &token_lifetime); // Update cache and grab list of pending requests. - gpr_mu_lock(&c->mu); - c->token_fetch_pending = false; - c->access_token_md = GRPC_MDELEM_REF(access_token_md); - c->token_expiration = + gpr_mu_lock(&mu_); + token_fetch_pending_ = false; + access_token_md_ = GRPC_MDELEM_REF(access_token_md); + token_expiration_ = status == GRPC_CREDENTIALS_OK ? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_millis(token_lifetime, GPR_TIMESPAN)) : gpr_inf_past(GPR_CLOCK_MONOTONIC); - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; - c->pending_requests = nullptr; - gpr_mu_unlock(&c->mu); + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; + pending_requests_ = nullptr; + gpr_mu_unlock(&mu_); // Invoke callbacks for all pending requests. while (pending_request != nullptr) { if (status == GRPC_CREDENTIALS_OK) { @@ -239,42 +243,40 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, } GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, error); grpc_polling_entity_del_from_pollset_set( - pending_request->pollent, grpc_polling_entity_pollset_set(&c->pollent)); + pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_)); grpc_oauth2_pending_get_request_metadata* prev = pending_request; pending_request = pending_request->next; gpr_free(prev); } GRPC_MDELEM_UNREF(access_token_md); - grpc_call_credentials_unref(r->creds); + Unref(); grpc_credentials_metadata_request_destroy(r); } -static bool oauth2_token_fetcher_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); +bool grpc_oauth2_token_fetcher_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { // Check if we can use the cached token. grpc_millis refresh_threshold = GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC; grpc_mdelem cached_access_token_md = GRPC_MDNULL; - gpr_mu_lock(&c->mu); - if (!GRPC_MDISNULL(c->access_token_md) && + gpr_mu_lock(&mu_); + if (!GRPC_MDISNULL(access_token_md_) && gpr_time_cmp( - gpr_time_sub(c->token_expiration, gpr_now(GPR_CLOCK_MONOTONIC)), + gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)), gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN)) > 0) { - cached_access_token_md = GRPC_MDELEM_REF(c->access_token_md); + cached_access_token_md = GRPC_MDELEM_REF(access_token_md_); } if (!GRPC_MDISNULL(cached_access_token_md)) { - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); grpc_credentials_mdelem_array_add(md_array, cached_access_token_md); GRPC_MDELEM_UNREF(cached_access_token_md); return true; } // Couldn't get the token from the cache. - // Add request to c->pending_requests and start a new fetch if needed. + // Add request to pending_requests_ and start a new fetch if needed. grpc_oauth2_pending_get_request_metadata* pending_request = static_cast<grpc_oauth2_pending_get_request_metadata*>( gpr_malloc(sizeof(*pending_request))); @@ -282,41 +284,37 @@ static bool oauth2_token_fetcher_get_request_metadata( pending_request->on_request_metadata = on_request_metadata; pending_request->pollent = pollent; grpc_polling_entity_add_to_pollset_set( - pollent, grpc_polling_entity_pollset_set(&c->pollent)); - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; + pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_request->next = pending_requests_; + pending_requests_ = pending_request; bool start_fetch = false; - if (!c->token_fetch_pending) { - c->token_fetch_pending = true; + if (!token_fetch_pending_) { + token_fetch_pending_ = true; start_fetch = true; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); if (start_fetch) { - grpc_call_credentials_ref(creds); - c->fetch_func(grpc_credentials_metadata_request_create(creds), - &c->httpcli_context, &c->pollent, - on_oauth2_token_fetcher_http_response, - grpc_core::ExecCtx::Get()->Now() + refresh_threshold); + Ref().release(); + fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()), + &httpcli_context_, &pollent_, + on_oauth2_token_fetcher_http_response, + grpc_core::ExecCtx::Get()->Now() + refresh_threshold); } return false; } -static void oauth2_token_fetcher_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); - gpr_mu_lock(&c->mu); +void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); grpc_oauth2_pending_get_request_metadata* prev = nullptr; - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; while (pending_request != nullptr) { if (pending_request->md_array == md_array) { // Remove matching pending request from the list. if (prev != nullptr) { prev->next = pending_request->next; } else { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } // Invoke the callback immediately with an error. GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, @@ -327,96 +325,89 @@ static void oauth2_token_fetcher_cancel_get_request_metadata( prev = pending_request; pending_request = pending_request->next; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials* c, - grpc_fetch_oauth2_func fetch_func) { - memset(c, 0, sizeof(grpc_oauth2_token_fetcher_credentials)); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - gpr_ref_init(&c->base.refcount, 1); - gpr_mu_init(&c->mu); - c->token_expiration = gpr_inf_past(GPR_CLOCK_MONOTONIC); - c->fetch_func = fetch_func; - c->pollent = - grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create()); - grpc_httpcli_context_init(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials() + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)), + pollent_(grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set_create())) { + gpr_mu_init(&mu_); + grpc_httpcli_context_init(&httpcli_context_); } // // Google Compute Engine credentials. // -static grpc_call_credentials_vtable compute_engine_vtable = { - oauth2_token_fetcher_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; +namespace { + +class grpc_compute_engine_token_fetcher_credentials + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_compute_engine_token_fetcher_credentials() = default; + ~grpc_compute_engine_token_fetcher_credentials() override = default; + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* http_context, + grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, + grpc_millis deadline) override { + grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; + request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; + request.http.hdr_count = 1; + request.http.hdrs = &header; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. */ + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("oauth2_credentials"); + grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline, + GRPC_CLOSURE_CREATE(response_cb, metadata_req, + grpc_schedule_on_exec_ctx), + &metadata_req->response); + grpc_resource_quota_unref_internal(resource_quota); + } +}; -static void compute_engine_fetch_oauth2( - grpc_credentials_metadata_request* metadata_req, - grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, - grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; - grpc_httpcli_request request; - memset(&request, 0, sizeof(grpc_httpcli_request)); - request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; - request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; - request.http.hdr_count = 1; - request.http.hdrs = &header; - /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host - channel. This would allow us to cancel an authentication query when under - extreme memory pressure. */ - grpc_resource_quota* resource_quota = - grpc_resource_quota_create("oauth2_credentials"); - grpc_httpcli_get( - httpcli_context, pollent, resource_quota, &request, deadline, - GRPC_CLOSURE_CREATE(response_cb, metadata_req, grpc_schedule_on_exec_ctx), - &metadata_req->response); - grpc_resource_quota_unref_internal(resource_quota); -} +} // namespace grpc_call_credentials* grpc_google_compute_engine_credentials_create( void* reserved) { - grpc_oauth2_token_fetcher_credentials* c = - static_cast<grpc_oauth2_token_fetcher_credentials*>( - gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials))); GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - init_oauth2_token_fetcher(c, compute_engine_fetch_oauth2); - c->base.vtable = &compute_engine_vtable; - return &c->base; + return grpc_core::MakeRefCounted< + grpc_compute_engine_token_fetcher_credentials>() + .release(); } // // Google Refresh Token credentials. // -static void refresh_token_destruct(grpc_call_credentials* creds) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast<grpc_google_refresh_token_credentials*>(creds); - grpc_auth_refresh_token_destruct(&c->refresh_token); - oauth2_token_fetcher_destruct(&c->base.base); +grpc_google_refresh_token_credentials:: + ~grpc_google_refresh_token_credentials() { + grpc_auth_refresh_token_destruct(&refresh_token_); } -static grpc_call_credentials_vtable refresh_token_vtable = { - refresh_token_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; - -static void refresh_token_fetch_oauth2( +void grpc_google_refresh_token_credentials::fetch_oauth2( grpc_credentials_metadata_request* metadata_req, grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast<grpc_google_refresh_token_credentials*>( - metadata_req->creds); grpc_http_header header = {(char*)"Content-Type", (char*)"application/x-www-form-urlencoded"}; grpc_httpcli_request request; char* body = nullptr; gpr_asprintf(&body, GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING, - c->refresh_token.client_id, c->refresh_token.client_secret, - c->refresh_token.refresh_token); + refresh_token_.client_id, refresh_token_.client_secret, + refresh_token_.refresh_token); memset(&request, 0, sizeof(grpc_httpcli_request)); request.host = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_HOST; request.http.path = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH; @@ -437,20 +428,19 @@ static void refresh_token_fetch_oauth2( gpr_free(body); } -grpc_call_credentials* +grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials( + grpc_auth_refresh_token refresh_token) + : refresh_token_(refresh_token) {} + +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token refresh_token) { - grpc_google_refresh_token_credentials* c; if (!grpc_auth_refresh_token_is_valid(&refresh_token)) { gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation"); return nullptr; } - c = static_cast<grpc_google_refresh_token_credentials*>( - gpr_zalloc(sizeof(grpc_google_refresh_token_credentials))); - init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2); - c->base.base.vtable = &refresh_token_vtable; - c->refresh_token = refresh_token; - return &c->base.base; + return grpc_core::MakeRefCounted<grpc_google_refresh_token_credentials>( + refresh_token); } static char* create_loggable_refresh_token(grpc_auth_refresh_token* token) { @@ -478,59 +468,50 @@ grpc_call_credentials* grpc_google_refresh_token_credentials_create( gpr_free(loggable_token); } GPR_ASSERT(reserved == nullptr); - return grpc_refresh_token_credentials_create_from_auth_refresh_token(token); + return grpc_refresh_token_credentials_create_from_auth_refresh_token(token) + .release(); } // // Oauth2 Access Token credentials. // -static void access_token_destruct(grpc_call_credentials* creds) { - grpc_access_token_credentials* c = - reinterpret_cast<grpc_access_token_credentials*>(creds); - GRPC_MDELEM_UNREF(c->access_token_md); +grpc_access_token_credentials::~grpc_access_token_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); } -static bool access_token_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_access_token_credentials* c = - reinterpret_cast<grpc_access_token_credentials*>(creds); - grpc_credentials_mdelem_array_add(md_array, c->access_token_md); +bool grpc_access_token_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, access_token_md_); return true; } -static void access_token_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_access_token_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable access_token_vtable = { - access_token_destruct, access_token_get_request_metadata, - access_token_cancel_get_request_metadata}; +grpc_access_token_credentials::grpc_access_token_credentials( + const char* access_token) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) { + char* token_md_value; + gpr_asprintf(&token_md_value, "Bearer %s", access_token); + grpc_core::ExecCtx exec_ctx; + access_token_md_ = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_copied_string(token_md_value)); + gpr_free(token_md_value); +} grpc_call_credentials* grpc_access_token_credentials_create( const char* access_token, void* reserved) { - grpc_access_token_credentials* c = - static_cast<grpc_access_token_credentials*>( - gpr_zalloc(sizeof(grpc_access_token_credentials))); GRPC_API_TRACE( "grpc_access_token_credentials_create(access_token=<redacted>, " "reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &access_token_vtable; - gpr_ref_init(&c->base.refcount, 1); - char* token_md_value; - gpr_asprintf(&token_md_value, "Bearer %s", access_token); - grpc_core::ExecCtx exec_ctx; - c->access_token_md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), - grpc_slice_from_copied_string(token_md_value)); - - gpr_free(token_md_value); - return &c->base; + return grpc_core::MakeRefCounted<grpc_access_token_credentials>(access_token) + .release(); } diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h index 12a1d4484f..510a78b484 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h @@ -54,46 +54,91 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token); // This object is a base for credentials that need to acquire an oauth2 token // from an http service. -typedef void (*grpc_fetch_oauth2_func)(grpc_credentials_metadata_request* req, - grpc_httpcli_context* http_context, - grpc_polling_entity* pollent, - grpc_iomgr_cb_func cb, - grpc_millis deadline); - -typedef struct grpc_oauth2_pending_get_request_metadata { +struct grpc_oauth2_pending_get_request_metadata { grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_polling_entity* pollent; struct grpc_oauth2_pending_get_request_metadata* next; -} grpc_oauth2_pending_get_request_metadata; - -typedef struct { - grpc_call_credentials base; - gpr_mu mu; - grpc_mdelem access_token_md; - gpr_timespec token_expiration; - bool token_fetch_pending; - grpc_oauth2_pending_get_request_metadata* pending_requests; - grpc_httpcli_context httpcli_context; - grpc_fetch_oauth2_func fetch_func; - grpc_polling_entity pollent; -} grpc_oauth2_token_fetcher_credentials; +}; + +class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials { + public: + grpc_oauth2_token_fetcher_credentials(); + ~grpc_oauth2_token_fetcher_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + void on_http_response(grpc_credentials_metadata_request* r, + grpc_error* error); + + GRPC_ABSTRACT_BASE_CLASS + + protected: + virtual void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) GRPC_ABSTRACT; + + private: + gpr_mu mu_; + grpc_mdelem access_token_md_ = GRPC_MDNULL; + gpr_timespec token_expiration_; + bool token_fetch_pending_ = false; + grpc_oauth2_pending_get_request_metadata* pending_requests_ = nullptr; + grpc_httpcli_context httpcli_context_; + grpc_polling_entity pollent_; +}; // Google refresh token credentials. -typedef struct { - grpc_oauth2_token_fetcher_credentials base; - grpc_auth_refresh_token refresh_token; -} grpc_google_refresh_token_credentials; +class grpc_google_refresh_token_credentials final + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_google_refresh_token_credentials(grpc_auth_refresh_token refresh_token); + ~grpc_google_refresh_token_credentials() override; + + const grpc_auth_refresh_token& refresh_token() const { + return refresh_token_; + } + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) override; + + private: + grpc_auth_refresh_token refresh_token_; +}; // Access token credentials. -typedef struct { - grpc_call_credentials base; - grpc_mdelem access_token_md; -} grpc_access_token_credentials; +class grpc_access_token_credentials final : public grpc_call_credentials { + public: + grpc_access_token_credentials(const char* access_token); + ~grpc_access_token_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem access_token_md_; +}; // Private constructor for refresh token credentials from an already parsed // refresh token. Takes ownership of the refresh token. -grpc_call_credentials* +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token token); diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc index 4015124298..52982fdb8f 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.cc +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc @@ -35,20 +35,17 @@ grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials"); -static void plugin_destruct(grpc_call_credentials* creds) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); - gpr_mu_destroy(&c->mu); - if (c->plugin.state != nullptr && c->plugin.destroy != nullptr) { - c->plugin.destroy(c->plugin.state); +grpc_plugin_credentials::~grpc_plugin_credentials() { + gpr_mu_destroy(&mu_); + if (plugin_.state != nullptr && plugin_.destroy != nullptr) { + plugin_.destroy(plugin_.state); } } -static void pending_request_remove_locked( - grpc_plugin_credentials* c, - grpc_plugin_credentials_pending_request* pending_request) { +void grpc_plugin_credentials::pending_request_remove_locked( + pending_request* pending_request) { if (pending_request->prev == nullptr) { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } else { pending_request->prev->next = pending_request->next; } @@ -62,17 +59,17 @@ static void pending_request_remove_locked( // cancelled out from under us. // When this returns, r->cancelled indicates whether the request was // cancelled before completion. -static void pending_request_complete( - grpc_plugin_credentials_pending_request* r) { - gpr_mu_lock(&r->creds->mu); - if (!r->cancelled) pending_request_remove_locked(r->creds, r); - gpr_mu_unlock(&r->creds->mu); +void grpc_plugin_credentials::pending_request_complete(pending_request* r) { + GPR_DEBUG_ASSERT(r->creds == this); + gpr_mu_lock(&mu_); + if (!r->cancelled) pending_request_remove_locked(r); + gpr_mu_unlock(&mu_); // Ref to credentials not needed anymore. - grpc_call_credentials_unref(&r->creds->base); + Unref(); } static grpc_error* process_plugin_result( - grpc_plugin_credentials_pending_request* r, const grpc_metadata* md, + grpc_plugin_credentials::pending_request* r, const grpc_metadata* md, size_t num_md, grpc_status_code status, const char* error_details) { grpc_error* error = GRPC_ERROR_NONE; if (status != GRPC_STATUS_OK) { @@ -119,8 +116,8 @@ static void plugin_md_request_metadata_ready(void* request, /* called from application code */ grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED | GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP); - grpc_plugin_credentials_pending_request* r = - static_cast<grpc_plugin_credentials_pending_request*>(request); + grpc_plugin_credentials::pending_request* r = + static_cast<grpc_plugin_credentials::pending_request*>(request); if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " @@ -128,7 +125,7 @@ static void plugin_md_request_metadata_ready(void* request, r->creds, r); } // Remove request from pending list if not previously cancelled. - pending_request_complete(r); + r->creds->pending_request_complete(r); // If it has not been cancelled, process it. if (!r->cancelled) { grpc_error* error = @@ -143,65 +140,59 @@ static void plugin_md_request_metadata_ready(void* request, gpr_free(r); } -static bool plugin_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); +bool grpc_plugin_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { bool retval = true; // Synchronous return. - if (c->plugin.get_metadata != nullptr) { + if (plugin_.get_metadata != nullptr) { // Create pending_request object. - grpc_plugin_credentials_pending_request* pending_request = - static_cast<grpc_plugin_credentials_pending_request*>( - gpr_zalloc(sizeof(*pending_request))); - pending_request->creds = c; - pending_request->md_array = md_array; - pending_request->on_request_metadata = on_request_metadata; + pending_request* request = + static_cast<pending_request*>(gpr_zalloc(sizeof(*request))); + request->creds = this; + request->md_array = md_array; + request->on_request_metadata = on_request_metadata; // Add it to the pending list. - gpr_mu_lock(&c->mu); - if (c->pending_requests != nullptr) { - c->pending_requests->prev = pending_request; + gpr_mu_lock(&mu_); + if (pending_requests_ != nullptr) { + pending_requests_->prev = request; } - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; - gpr_mu_unlock(&c->mu); + request->next = pending_requests_; + pending_requests_ = request; + gpr_mu_unlock(&mu_); // Invoke the plugin. The callback holds a ref to us. if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin", - c, pending_request); + this, request); } - grpc_call_credentials_ref(creds); + Ref().release(); grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX]; size_t num_creds_md = 0; grpc_status_code status = GRPC_STATUS_OK; const char* error_details = nullptr; - if (!c->plugin.get_metadata(c->plugin.state, context, - plugin_md_request_metadata_ready, - pending_request, creds_md, &num_creds_md, - &status, &error_details)) { + if (!plugin_.get_metadata( + plugin_.state, context, plugin_md_request_metadata_ready, request, + creds_md, &num_creds_md, &status, &error_details)) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin will return " "asynchronously", - c, pending_request); + this, request); } return false; // Asynchronous return. } // Returned synchronously. // Remove request from pending list if not previously cancelled. - pending_request_complete(pending_request); + request->creds->pending_request_complete(request); // If the request was cancelled, the error will have been returned // asynchronously by plugin_cancel_get_request_metadata(), so return // false. Otherwise, process the result. - if (pending_request->cancelled) { + if (request->cancelled) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p was cancelled, error " "will be returned asynchronously", - c, pending_request); + this, request); } retval = false; } else { @@ -209,10 +200,10 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " "synchronously", - c, pending_request); + this, request); } - *error = process_plugin_result(pending_request, creds_md, num_creds_md, - status, error_details); + *error = process_plugin_result(request, creds_md, num_creds_md, status, + error_details); } // Clean up. for (size_t i = 0; i < num_creds_md; ++i) { @@ -220,51 +211,42 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, grpc_slice_unref_internal(creds_md[i].value); } gpr_free((void*)error_details); - gpr_free(pending_request); + gpr_free(request); } return retval; } -static void plugin_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); - gpr_mu_lock(&c->mu); - for (grpc_plugin_credentials_pending_request* pending_request = - c->pending_requests; +void grpc_plugin_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); + for (pending_request* pending_request = pending_requests_; pending_request != nullptr; pending_request = pending_request->next) { if (pending_request->md_array == md_array) { if (grpc_plugin_credentials_trace.enabled()) { - gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", c, + gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this, pending_request); } pending_request->cancelled = true; GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, GRPC_ERROR_REF(error)); - pending_request_remove_locked(c, pending_request); + pending_request_remove_locked(pending_request); break; } } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable plugin_vtable = { - plugin_destruct, plugin_get_request_metadata, - plugin_cancel_get_request_metadata}; +grpc_plugin_credentials::grpc_plugin_credentials( + grpc_metadata_credentials_plugin plugin) + : grpc_call_credentials(plugin.type), plugin_(plugin) { + gpr_mu_init(&mu_); +} grpc_call_credentials* grpc_metadata_credentials_create_from_plugin( grpc_metadata_credentials_plugin plugin, void* reserved) { - grpc_plugin_credentials* c = - static_cast<grpc_plugin_credentials*>(gpr_zalloc(sizeof(*c))); GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = plugin.type; - c->base.vtable = &plugin_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->plugin = plugin; - gpr_mu_init(&c->mu); - return &c->base; + return grpc_core::New<grpc_plugin_credentials>(plugin); } diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h index caf990efa1..77a957e513 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.h +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h @@ -25,22 +25,45 @@ extern grpc_core::TraceFlag grpc_plugin_credentials_trace; -struct grpc_plugin_credentials; - -typedef struct grpc_plugin_credentials_pending_request { - bool cancelled; - struct grpc_plugin_credentials* creds; - grpc_credentials_mdelem_array* md_array; - grpc_closure* on_request_metadata; - struct grpc_plugin_credentials_pending_request* prev; - struct grpc_plugin_credentials_pending_request* next; -} grpc_plugin_credentials_pending_request; - -typedef struct grpc_plugin_credentials { - grpc_call_credentials base; - grpc_metadata_credentials_plugin plugin; - gpr_mu mu; - grpc_plugin_credentials_pending_request* pending_requests; -} grpc_plugin_credentials; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_plugin_credentials final : public grpc_call_credentials { + public: + struct pending_request { + bool cancelled; + struct grpc_plugin_credentials* creds; + grpc_credentials_mdelem_array* md_array; + grpc_closure* on_request_metadata; + struct pending_request* prev; + struct pending_request* next; + }; + + explicit grpc_plugin_credentials(grpc_metadata_credentials_plugin plugin); + ~grpc_plugin_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + // Checks if the request has been cancelled. + // If not, removes it from the pending list, so that it cannot be + // cancelled out from under us. + // When this returns, r->cancelled indicates whether the request was + // cancelled before completion. + void pending_request_complete(pending_request* r); + + private: + void pending_request_remove_locked(pending_request* pending_request); + + grpc_metadata_credentials_plugin plugin_; + gpr_mu mu_; + pending_request* pending_requests_ = nullptr; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_PLUGIN_PLUGIN_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index 3d6f2f200a..83db86f1ea 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -44,22 +44,27 @@ void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp, gpr_free(kp); } -static void ssl_destruct(grpc_channel_credentials* creds) { - grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds); - gpr_free(c->config.pem_root_certs); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pair, 1); - if (c->config.verify_options.verify_peer_destruct != nullptr) { - c->config.verify_options.verify_peer_destruct( - c->config.verify_options.verify_peer_callback_userdata); +grpc_ssl_credentials::grpc_ssl_credentials( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + build_config(pem_root_certs, pem_key_cert_pair, verify_options); +} + +grpc_ssl_credentials::~grpc_ssl_credentials() { + gpr_free(config_.pem_root_certs); + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1); + if (config_.verify_options.verify_peer_destruct != nullptr) { + config_.verify_options.verify_peer_destruct( + config_.verify_options.verify_peer_callback_userdata); } } -static grpc_security_status ssl_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds); - grpc_security_status status = GRPC_SECURITY_OK; + grpc_channel_args** new_args) { const char* overridden_target_name = nullptr; tsi_ssl_session_cache* ssl_session_cache = nullptr; for (size_t i = 0; args && i < args->num_args; i++) { @@ -74,52 +79,47 @@ static grpc_security_status ssl_create_security_connector( static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p); } } - status = grpc_ssl_channel_security_connector_create( - creds, call_creds, &c->config, target, overridden_target_name, - ssl_session_cache, sc); - if (status != GRPC_SECURITY_OK) { - return status; + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name, ssl_session_cache); + if (sc == nullptr) { + return sc; } grpc_arg new_arg = grpc_channel_arg_string_create( (char*)GRPC_ARG_HTTP2_SCHEME, (char*)"https"); *new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1); - return status; + return sc; } -static grpc_channel_credentials_vtable ssl_vtable = { - ssl_destruct, ssl_create_security_connector, nullptr}; - -static void ssl_build_config(const char* pem_root_certs, - grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, - const verify_peer_options* verify_options, - grpc_ssl_config* config) { - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } +void grpc_ssl_credentials::build_config( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) { + config_.pem_root_certs = gpr_strdup(pem_root_certs); if (pem_key_cert_pair != nullptr) { GPR_ASSERT(pem_key_cert_pair->private_key != nullptr); GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr); - config->pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>( + config_.pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>( gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair))); - config->pem_key_cert_pair->cert_chain = + config_.pem_key_cert_pair->cert_chain = gpr_strdup(pem_key_cert_pair->cert_chain); - config->pem_key_cert_pair->private_key = + config_.pem_key_cert_pair->private_key = gpr_strdup(pem_key_cert_pair->private_key); + } else { + config_.pem_key_cert_pair = nullptr; } if (verify_options != nullptr) { - memcpy(&config->verify_options, verify_options, + memcpy(&config_.verify_options, verify_options, sizeof(verify_peer_options)); } else { // Otherwise set all options to default values - memset(&config->verify_options, 0, sizeof(verify_peer_options)); + memset(&config_.verify_options, 0, sizeof(verify_peer_options)); } } grpc_channel_credentials* grpc_ssl_credentials_create( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const verify_peer_options* verify_options, void* reserved) { - grpc_ssl_credentials* c = static_cast<grpc_ssl_credentials*>( - gpr_zalloc(sizeof(grpc_ssl_credentials))); GRPC_API_TRACE( "grpc_ssl_credentials_create(pem_root_certs=%s, " "pem_key_cert_pair=%p, " @@ -127,12 +127,9 @@ grpc_channel_credentials* grpc_ssl_credentials_create( "reserved=%p)", 4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - c->base.vtable = &ssl_vtable; - gpr_ref_init(&c->base.refcount, 1); - ssl_build_config(pem_root_certs, pem_key_cert_pair, verify_options, - &c->config); - return &c->base; + + return grpc_core::New<grpc_ssl_credentials>(pem_root_certs, pem_key_cert_pair, + verify_options); } // @@ -145,21 +142,29 @@ struct grpc_ssl_server_credentials_options { grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher; }; -static void ssl_server_destruct(grpc_server_credentials* creds) { - grpc_ssl_server_credentials* c = - reinterpret_cast<grpc_ssl_server_credentials*>(creds); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pairs, - c->config.num_key_cert_pairs); - gpr_free(c->config.pem_root_certs); +grpc_ssl_server_credentials::grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options) + : grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + if (options.certificate_config_fetcher != nullptr) { + config_.client_certificate_request = options.client_certificate_request; + certificate_config_fetcher_ = *options.certificate_config_fetcher; + } else { + build_config(options.certificate_config->pem_root_certs, + options.certificate_config->pem_key_cert_pairs, + options.certificate_config->num_key_cert_pairs, + options.client_certificate_request); + } } -static grpc_security_status ssl_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_ssl_server_security_connector_create(creds, sc); +grpc_ssl_server_credentials::~grpc_ssl_server_credentials() { + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs, + config_.num_key_cert_pairs); + gpr_free(config_.pem_root_certs); +} +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_credentials::create_security_connector() { + return grpc_ssl_server_security_connector_create(this->Ref()); } - -static grpc_server_credentials_vtable ssl_server_vtable = { - ssl_server_destruct, ssl_server_create_security_connector}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, @@ -179,18 +184,15 @@ tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( return tsi_pairs; } -static void ssl_build_server_config( +void grpc_ssl_server_credentials::build_config( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, - grpc_ssl_client_certificate_request_type client_certificate_request, - grpc_ssl_server_config* config) { - config->client_certificate_request = client_certificate_request; - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } - config->pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + grpc_ssl_client_certificate_request_type client_certificate_request) { + config_.client_certificate_request = client_certificate_request; + config_.pem_root_certs = gpr_strdup(pem_root_certs); + config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( pem_key_cert_pairs, num_key_cert_pairs); - config->num_key_cert_pairs = num_key_cert_pairs; + config_.num_key_cert_pairs = num_key_cert_pairs; } grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( @@ -200,9 +202,7 @@ grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( grpc_ssl_server_certificate_config* config = static_cast<grpc_ssl_server_certificate_config*>( gpr_zalloc(sizeof(grpc_ssl_server_certificate_config))); - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } + config->pem_root_certs = gpr_strdup(pem_root_certs); if (num_key_cert_pairs > 0) { GPR_ASSERT(pem_key_cert_pairs != nullptr); config->pem_key_cert_pairs = static_cast<grpc_ssl_pem_key_cert_pair*>( @@ -311,7 +311,6 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_ex( grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( grpc_ssl_server_credentials_options* options) { grpc_server_credentials* retval = nullptr; - grpc_ssl_server_credentials* c = nullptr; if (options == nullptr) { gpr_log(GPR_ERROR, @@ -331,23 +330,7 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( goto done; } - c = static_cast<grpc_ssl_server_credentials*>( - gpr_zalloc(sizeof(grpc_ssl_server_credentials))); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &ssl_server_vtable; - - if (options->certificate_config_fetcher != nullptr) { - c->config.client_certificate_request = options->client_certificate_request; - c->certificate_config_fetcher = *options->certificate_config_fetcher; - } else { - ssl_build_server_config(options->certificate_config->pem_root_certs, - options->certificate_config->pem_key_cert_pairs, - options->certificate_config->num_key_cert_pairs, - options->client_certificate_request, &c->config); - } - - retval = &c->base; + retval = grpc_core::New<grpc_ssl_server_credentials>(*options); done: grpc_ssl_server_credentials_options_destroy(options); diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 0fba413876..e1174327b3 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -24,27 +24,70 @@ #include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h" -typedef struct { - grpc_channel_credentials base; - grpc_ssl_config config; -} grpc_ssl_credentials; +class grpc_ssl_credentials : public grpc_channel_credentials { + public: + grpc_ssl_credentials(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + ~grpc_ssl_credentials() override; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + private: + void build_config(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + grpc_ssl_config config_; +}; struct grpc_ssl_server_certificate_config { - grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; }; -typedef struct { - grpc_ssl_server_certificate_config_callback cb; +struct grpc_ssl_server_certificate_config_fetcher { + grpc_ssl_server_certificate_config_callback cb = nullptr; void* user_data; -} grpc_ssl_server_certificate_config_fetcher; +}; + +class grpc_ssl_server_credentials final : public grpc_server_credentials { + public: + grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options); + ~grpc_ssl_server_credentials() override; -typedef struct { - grpc_server_credentials base; - grpc_ssl_server_config config; - grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher; -} grpc_ssl_server_credentials; + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + bool has_cert_config_fetcher() const { + return certificate_config_fetcher_.cb != nullptr; + } + + grpc_ssl_certificate_config_reload_status FetchCertConfig( + grpc_ssl_server_certificate_config** config) { + GPR_DEBUG_ASSERT(has_cert_config_fetcher()); + return certificate_config_fetcher_.cb(certificate_config_fetcher_.user_data, + config); + } + + const grpc_ssl_server_config& config() const { return config_; } + + private: + void build_config( + const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, + grpc_ssl_client_certificate_request_type client_certificate_request); + + grpc_ssl_server_config config_; + grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher_; +}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index dd71c8bc60..3ad0cc353c 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -28,6 +28,7 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/alts/alts_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" @@ -35,64 +36,9 @@ #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_alts_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_alts_server_security_connector; - -static void alts_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void alts_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void alts_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - grpc_alts_credentials* creds = - reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, c->target_name, creds->handshaker_service_url, - true, interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc); - grpc_alts_server_credentials* creds = - reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, nullptr, creds->handshaker_service_url, false, - interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_set_rpc_protocol_versions( +void alts_set_rpc_protocol_versions( grpc_gcp_rpc_protocol_versions* rpc_versions) { grpc_gcp_rpc_protocol_versions_set_max(rpc_versions, GRPC_PROTOCOL_VERSION_MAX_MAJOR, @@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions( GRPC_PROTOCOL_VERSION_MIN_MINOR); } +void alts_check_peer(tsi_peer peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = + grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer); + tsi_peer_destruct(&peer); + grpc_error* error = + *auth_context != nullptr + ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not get ALTS auth context from TSI peer"); + GRPC_CLOSURE_SCHED(on_peer_checked, error); +} + +class grpc_alts_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_alts_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(/*url_scheme=*/nullptr, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) { + grpc_alts_credentials* creds = + static_cast<grpc_alts_credentials*>(mutable_channel_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + + ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_credentials* creds = + static_cast<const grpc_alts_credentials*>(channel_creds()); + GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_, + creds->handshaker_service_url(), true, + interested_parties, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_alts_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ALTS call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + private: + char* target_name_; +}; + +class grpc_alts_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_alts_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(/*url_scheme=*/nullptr, + std::move(server_creds)) { + grpc_alts_server_credentials* creds = + reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + ~grpc_alts_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_server_credentials* creds = + static_cast<const grpc_alts_server_credentials*>(server_creds()); + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), nullptr, creds->handshaker_service_url(), + false, interested_parties, &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } +}; +} // namespace + namespace grpc_core { namespace internal { - -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx) { - if (peer == nullptr || ctx == nullptr) { +grpc_core::RefCountedPtr<grpc_auth_context> +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { + if (peer == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to grpc_alts_auth_context_from_tsi_peer()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - *ctx = nullptr; /* Validate certificate type. */ const tsi_peer_property* cert_type_prop = tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); @@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE, cert_type_prop->value.length) != 0) { gpr_log(GPR_ERROR, "Invalid or missing certificate type property."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Validate RPC protocol versions. */ const tsi_peer_property* rpc_versions_prop = tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS); if (rpc_versions_prop == nullptr) { gpr_log(GPR_ERROR, "Missing rpc protocol versions property."); - return GRPC_SECURITY_ERROR; + return nullptr; } grpc_gcp_rpc_protocol_versions local_versions, peer_versions; alts_set_rpc_protocol_versions(&local_versions); @@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( grpc_slice_unref_internal(slice); if (!decode_result) { gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* TODO: Pass highest common rpc protocol version to grpc caller. */ bool check_result = grpc_gcp_rpc_protocol_versions_check( &local_versions, &peer_versions, nullptr); if (!check_result) { gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + auto ctx = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_ALTS_TRANSPORT_SECURITY_TYPE); size_t i = 0; for (i = 0; i < peer->property_count; i++) { @@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( /* Add service account to auth context. */ if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) { grpc_auth_context_add_property( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data, - tsi_prop->value.length); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + tsi_prop->value.data, tsi_prop->value.length); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); } } - if (!grpc_auth_context_peer_is_authenticated(*ctx)) { + if (!grpc_auth_context_peer_is_authenticated(ctx.get())) { gpr_log(GPR_ERROR, "Invalid unauthenticated peer."); - GRPC_AUTH_CONTEXT_UNREF(*ctx, "test"); - *ctx = nullptr; - return GRPC_SECURITY_ERROR; + ctx.reset(DEBUG_LOCATION, "test"); + return nullptr; } - return GRPC_SECURITY_OK; + return ctx; } } // namespace internal } // namespace grpc_core -static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; - status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer( - &peer, auth_context); - tsi_peer_destruct(&peer); - grpc_error* error = - status == GRPC_SECURITY_OK - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not get ALTS auth context from TSI peer"); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} - -static int alts_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_channel_security_connector* c1 = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc1); - grpc_alts_channel_security_connector* c2 = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int alts_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_server_security_connector* c1 = - reinterpret_cast<grpc_alts_server_security_connector*>(sc1); - grpc_alts_server_security_connector* c2 = - reinterpret_cast<grpc_alts_server_security_connector*>(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_connector_vtable alts_channel_vtable = { - alts_channel_destroy, alts_check_peer, alts_channel_cmp}; - -static grpc_security_connector_vtable alts_server_vtable = { - alts_server_destroy, alts_check_peer, alts_server_cmp}; - -static bool alts_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_alts_channel_security_connector* alts_sc = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - if (host == nullptr || alts_sc == nullptr || - strcmp(host, alts_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "ALTS call host does not match target name"); - } - return true; -} - -static void alts_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_alts_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_alts_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_channel_vtable; - c->base.add_handshakers = alts_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = alts_check_call_host; - c->base.cancel_check_call_host = alts_cancel_check_call_host; - grpc_alts_credentials* creds = - reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_alts_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), target_name); } -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_alts_server_security_connector*>( - gpr_zalloc(sizeof(grpc_alts_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.add_handshakers = alts_server_add_handshakers; - grpc_alts_server_credentials* creds = - reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_alts_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.h b/src/core/lib/security/security_connector/alts/alts_security_connector.h index d2e057a76a..b96dc36b30 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.h +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.h @@ -36,12 +36,13 @@ * - sc: address of ALTS channel security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name); /** * This method creates an ALTS server security connector. @@ -50,17 +51,18 @@ grpc_security_status grpc_alts_channel_security_connector_create( * - sc: address of ALTS server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); namespace grpc_core { namespace internal { /* Exposed only for testing. */ -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx); +grpc_core::RefCountedPtr<grpc_auth_context> +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer); } // namespace internal } // namespace grpc_core diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc index 5c0c89b88f..e3b8affb36 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -31,6 +31,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/fake/fake_credentials.h" @@ -38,91 +39,183 @@ #include "src/core/lib/security/transport/target_authority_table.h" #include "src/core/tsi/fake_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target; - char* expected_targets; - bool is_lb_channel; - char* target_name_override; -} grpc_fake_channel_security_connector; +namespace { +class grpc_fake_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_fake_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args) + : grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_(gpr_strdup(target)), + expected_targets_( + gpr_strdup(grpc_fake_transport_get_expected_targets(args))), + is_lb_channel_(grpc_core::FindTargetAuthorityTableInArgs(args) != + nullptr) { + const grpc_arg* target_name_override_arg = + grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); + if (target_name_override_arg != nullptr) { + target_name_override_ = + gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + } else { + target_name_override_ = nullptr; + } + } -static void fake_channel_destroy(grpc_security_connector* sc) { - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - gpr_free(c->target); - gpr_free(c->expected_targets); - gpr_free(c->target_name_override); - gpr_free(c); -} + ~grpc_fake_channel_security_connector() override { + gpr_free(target_); + gpr_free(expected_targets_); + if (target_name_override_ != nullptr) gpr_free(target_name_override_); + } -static void fake_server_destroy(grpc_security_connector* sc) { gpr_free(sc); } + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override; -static bool fake_check_target(const char* target_type, const char* target, - const char* set_str) { - GPR_ASSERT(target_type != nullptr); - GPR_ASSERT(target != nullptr); - char** set = nullptr; - size_t set_size = 0; - gpr_string_split(set_str, ",", &set, &set_size); - bool found = false; - for (size_t i = 0; i < set_size; ++i) { - if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_fake_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_, other->target_); + if (c != 0) return c; + if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) { + c = GPR_ICMP(expected_targets_, other->expected_targets_); + } else { + c = strcmp(expected_targets_, other->expected_targets_); + } + if (c != 0) return c; + return GPR_ICMP(is_lb_channel_, other->is_lb_channel_); } - for (size_t i = 0; i < set_size; ++i) { - gpr_free(set[i]); + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*is_client=*/true), this)); } - gpr_free(set); - return found; -} -static void fake_secure_name_check(const char* target, - const char* expected_targets, - bool is_lb_channel) { - if (expected_targets == nullptr) return; - char** lbs_and_backends = nullptr; - size_t lbs_and_backends_size = 0; - bool success = false; - gpr_string_split(expected_targets, ";", &lbs_and_backends, - &lbs_and_backends_size); - if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { - gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", - expected_targets); - goto done; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + char* authority_hostname = nullptr; + char* authority_ignored_port = nullptr; + char* target_hostname = nullptr; + char* target_ignored_port = nullptr; + gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); + gpr_split_host_port(target_, &target_hostname, &target_ignored_port); + if (target_name_override_ != nullptr) { + char* fake_security_target_name_override_hostname = nullptr; + char* fake_security_target_name_override_ignored_port = nullptr; + gpr_split_host_port(target_name_override_, + &fake_security_target_name_override_hostname, + &fake_security_target_name_override_ignored_port); + if (strcmp(authority_hostname, + fake_security_target_name_override_hostname) != 0) { + gpr_log(GPR_ERROR, + "Authority (host) '%s' != Fake Security Target override '%s'", + host, fake_security_target_name_override_hostname); + abort(); + } + gpr_free(fake_security_target_name_override_hostname); + gpr_free(fake_security_target_name_override_ignored_port); + } else if (strcmp(authority_hostname, target_hostname) != 0) { + gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", + authority_hostname, target_hostname); + abort(); + } + gpr_free(authority_hostname); + gpr_free(authority_ignored_port); + gpr_free(target_hostname); + gpr_free(target_ignored_port); + return true; } - if (is_lb_channel) { - if (lbs_and_backends_size != 2) { - gpr_log(GPR_ERROR, - "Invalid expected targets arg value: '%s'. Expectations for LB " - "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", - expected_targets); - goto done; + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + char* target() const { return target_; } + char* expected_targets() const { return expected_targets_; } + bool is_lb_channel() const { return is_lb_channel_; } + char* target_name_override() const { return target_name_override_; } + + private: + bool fake_check_target(const char* target_type, const char* target, + const char* set_str) const { + GPR_ASSERT(target_type != nullptr); + GPR_ASSERT(target != nullptr); + char** set = nullptr; + size_t set_size = 0; + gpr_string_split(set_str, ",", &set, &set_size); + bool found = false; + for (size_t i = 0; i < set_size; ++i) { + if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; } - if (!fake_check_target("LB", target, lbs_and_backends[1])) { - gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", - target, lbs_and_backends[1]); - goto done; + for (size_t i = 0; i < set_size; ++i) { + gpr_free(set[i]); } - success = true; - } else { - if (!fake_check_target("Backend", target, lbs_and_backends[0])) { - gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", - target, lbs_and_backends[0]); + gpr_free(set); + return found; + } + + void fake_secure_name_check() const { + if (expected_targets_ == nullptr) return; + char** lbs_and_backends = nullptr; + size_t lbs_and_backends_size = 0; + bool success = false; + gpr_string_split(expected_targets_, ";", &lbs_and_backends, + &lbs_and_backends_size); + if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { + gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", + expected_targets_); goto done; } - success = true; - } -done: - for (size_t i = 0; i < lbs_and_backends_size; ++i) { - gpr_free(lbs_and_backends[i]); + if (is_lb_channel_) { + if (lbs_and_backends_size != 2) { + gpr_log(GPR_ERROR, + "Invalid expected targets arg value: '%s'. Expectations for LB " + "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", + expected_targets_); + goto done; + } + if (!fake_check_target("LB", target_, lbs_and_backends[1])) { + gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", + target_, lbs_and_backends[1]); + goto done; + } + success = true; + } else { + if (!fake_check_target("Backend", target_, lbs_and_backends[0])) { + gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", + target_, lbs_and_backends[0]); + goto done; + } + success = true; + } + done: + for (size_t i = 0; i < lbs_and_backends_size; ++i) { + gpr_free(lbs_and_backends[i]); + } + gpr_free(lbs_and_backends); + if (!success) abort(); } - gpr_free(lbs_and_backends); - if (!success) abort(); -} -static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { + char* target_; + char* expected_targets_; + bool is_lb_channel_; + char* target_name_override_; +}; + +static void fake_check_peer( + grpc_security_connector* sc, tsi_peer peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { const char* prop_name; grpc_error* error = GRPC_ERROR_NONE; *auth_context = nullptr; @@ -147,164 +240,66 @@ static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, "Invalid value for cert type property."); goto end; } - *auth_context = grpc_auth_context_create(nullptr); + *auth_context = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_FAKE_TRANSPORT_SECURITY_TYPE); end: GRPC_CLOSURE_SCHED(on_peer_checked, error); tsi_peer_destruct(&peer); } -static void fake_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - fake_secure_name_check(c->target, c->expected_targets, c->is_lb_channel); +void grpc_fake_channel_security_connector::check_peer( + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { + fake_check_peer(this, peer, auth_context, on_peer_checked); + fake_secure_name_check(); } -static void fake_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); -} +class grpc_fake_server_security_connector + : public grpc_server_security_connector { + public: + grpc_fake_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(server_creds)) {} + ~grpc_fake_server_security_connector() override = default; -static int fake_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_fake_channel_security_connector* c1 = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc1); - grpc_fake_channel_security_connector* c2 = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target, c2->target); - if (c != 0) return c; - if (c1->expected_targets == nullptr || c2->expected_targets == nullptr) { - c = GPR_ICMP(c1->expected_targets, c2->expected_targets); - } else { - c = strcmp(c1->expected_targets, c2->expected_targets); + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + fake_check_peer(this, peer, auth_context, on_peer_checked); } - if (c != 0) return c; - return GPR_ICMP(c1->is_lb_channel, c2->is_lb_channel); -} -static int fake_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast<grpc_server_security_connector*>(sc1), - reinterpret_cast<grpc_server_security_connector*>(sc2)); -} - -static bool fake_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - char* authority_hostname = nullptr; - char* authority_ignored_port = nullptr; - char* target_hostname = nullptr; - char* target_ignored_port = nullptr; - gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); - gpr_split_host_port(c->target, &target_hostname, &target_ignored_port); - if (c->target_name_override != nullptr) { - char* fake_security_target_name_override_hostname = nullptr; - char* fake_security_target_name_override_ignored_port = nullptr; - gpr_split_host_port(c->target_name_override, - &fake_security_target_name_override_hostname, - &fake_security_target_name_override_ignored_port); - if (strcmp(authority_hostname, - fake_security_target_name_override_hostname) != 0) { - gpr_log(GPR_ERROR, - "Authority (host) '%s' != Fake Security Target override '%s'", - host, fake_security_target_name_override_hostname); - abort(); - } - gpr_free(fake_security_target_name_override_hostname); - gpr_free(fake_security_target_name_override_ignored_port); - } else if (strcmp(authority_hostname, target_hostname) != 0) { - gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", - authority_hostname, target_hostname); - abort(); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*=is_client*/ false), this)); } - gpr_free(authority_hostname); - gpr_free(authority_ignored_port); - gpr_free(target_hostname); - gpr_free(target_ignored_port); - return true; -} -static void fake_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -static void fake_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(true /* is_client */), &sc->base)); -} - -static void fake_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(false /* is_client */), &sc->base)); -} - -static grpc_security_connector_vtable fake_channel_vtable = { - fake_channel_destroy, fake_channel_check_peer, fake_channel_cmp}; - -static grpc_security_connector_vtable fake_server_vtable = { - fake_server_destroy, fake_server_check_peer, fake_server_cmp}; - -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args) { - grpc_fake_channel_security_connector* c = - static_cast<grpc_fake_channel_security_connector*>( - gpr_zalloc(sizeof(*c))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->base.base.vtable = &fake_channel_vtable; - c->base.channel_creds = channel_creds; - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = fake_channel_check_call_host; - c->base.cancel_check_call_host = fake_channel_cancel_check_call_host; - c->base.add_handshakers = fake_channel_add_handshakers; - c->target = gpr_strdup(target); - const char* expected_targets = grpc_fake_transport_get_expected_targets(args); - c->expected_targets = gpr_strdup(expected_targets); - c->is_lb_channel = grpc_core::FindTargetAuthorityTableInArgs(args) != nullptr; - const grpc_arg* target_name_override_arg = - grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); - if (target_name_override_arg != nullptr) { - c->target_name_override = - gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); } - return &c->base; +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args) { + return grpc_core::MakeRefCounted<grpc_fake_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), target, + args); } -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds) { - grpc_server_security_connector* c = - static_cast<grpc_server_security_connector*>( - gpr_zalloc(sizeof(grpc_server_security_connector))); - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &fake_server_vtable; - c->base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->server_creds = server_creds; - c->add_handshakers = fake_server_add_handshakers; - return c; +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + return grpc_core::MakeRefCounted<grpc_fake_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.h b/src/core/lib/security/security_connector/fake/fake_security_connector.h index fdfe048c6e..344a2349a4 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.h +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.h @@ -24,19 +24,22 @@ #include <grpc/grpc_security.h> #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/security_connector/security_connector.h" #define GRPC_FAKE_SECURITY_URL_SCHEME "http+fake_security" /* Creates a fake connector that emulates real channel security. */ -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args); /* Creates a fake connector that emulates real server security. */ -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_FAKE_FAKE_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc index 008a98df28..7cc482c16c 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.cc +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -30,217 +30,224 @@ #include "src/core/ext/filters/client_channel/client_channel.h" #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr_utils.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" #include "src/core/lib/security/credentials/local/local_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/tsi/local_transport_security.h" #define GRPC_UDS_URI_PATTERN "unix:" -#define GRPC_UDS_URL_SCHEME "unix" #define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_local_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_local_server_security_connector; - -static void local_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_local_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void local_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_local_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void local_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void local_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static int local_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_channel_security_connector* c1 = - reinterpret_cast<grpc_local_channel_security_connector*>(sc1); - grpc_local_channel_security_connector* c2 = - reinterpret_cast<grpc_local_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int local_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_server_security_connector* c1 = - reinterpret_cast<grpc_local_server_security_connector*>(sc1); - grpc_local_server_security_connector* c2 = - reinterpret_cast<grpc_local_server_security_connector*>(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_status local_auth_context_create(grpc_auth_context** ctx) { - if (ctx == nullptr) { - gpr_log(GPR_ERROR, "Invalid arguments to local_auth_context_create()"); - return GRPC_SECURITY_ERROR; - } +grpc_core::RefCountedPtr<grpc_auth_context> local_auth_context_create() { /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr<grpc_auth_context> ctx = + grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_LOCAL_TRANSPORT_SECURITY_TYPE); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); - return GRPC_SECURITY_OK; + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); + return ctx; } -static void local_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; +void local_check_peer(grpc_security_connector* sc, tsi_peer peer, + grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked, + grpc_local_connect_type type) { + int fd = grpc_endpoint_get_fd(ep); + grpc_resolved_address resolved_addr; + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = GRPC_MAX_SOCKADDR_SIZE; + bool is_endpoint_local = false; + if (getsockname(fd, reinterpret_cast<grpc_sockaddr*>(resolved_addr.addr), + &resolved_addr.len) == 0) { + grpc_resolved_address addr_normalized; + grpc_resolved_address* addr = + grpc_sockaddr_is_v4mapped(&resolved_addr, &addr_normalized) + ? &addr_normalized + : &resolved_addr; + grpc_sockaddr* sock_addr = reinterpret_cast<grpc_sockaddr*>(&addr->addr); + // UDS + if (type == UDS && grpc_is_unix_socket(addr)) { + is_endpoint_local = true; + // IPV4 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast<const grpc_sockaddr_in*>(sock_addr); + if (grpc_htonl(addr4->sin_addr.s_addr) == INADDR_LOOPBACK) { + is_endpoint_local = true; + } + // IPv6 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast<const grpc_sockaddr_in6*>(addr); + if (memcmp(&addr6->sin6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)) == 0) { + is_endpoint_local = true; + } + } + } + grpc_error* error = GRPC_ERROR_NONE; + if (!is_endpoint_local) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Endpoint is neither UDS or TCP loopback address."); + GRPC_CLOSURE_SCHED(on_peer_checked, error); + return; + } /* Create an auth context which is necessary to pass the santiy check in * {client, server}_auth_filter that verifies if the peer's auth context is * obtained during handshakes. The auth context is only checked for its * existence and not actually used. */ - status = local_auth_context_create(auth_context); - grpc_error* error = status == GRPC_SECURITY_OK - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not create local auth context"); + *auth_context = local_auth_context_create(); + error = *auth_context != nullptr ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not create local auth context"); GRPC_CLOSURE_SCHED(on_peer_checked, error); } -static grpc_security_connector_vtable local_channel_vtable = { - local_channel_destroy, local_check_peer, local_channel_cmp}; - -static grpc_security_connector_vtable local_server_vtable = { - local_server_destroy, local_check_peer, local_server_cmp}; - -static bool local_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_local_channel_security_connector* local_sc = - reinterpret_cast<grpc_local_channel_security_connector*>(sc); - if (host == nullptr || local_sc == nullptr || - strcmp(host, local_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "local call host does not match target name"); +class grpc_local_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_local_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(nullptr, std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) {} + + ~grpc_local_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == + TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - return true; -} -static void local_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_local_channel_security_connector*>( + other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_credentials* creds = + reinterpret_cast<grpc_local_credentials*>(mutable_channel_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "local call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { + const char* target_name() const { return target_name_; } + + private: + char* target_name_; +}; + +class grpc_local_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_local_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(nullptr, std::move(server_creds)) {} + ~grpc_local_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_server_credentials* creds = + static_cast<grpc_local_server_credentials*>(mutable_server_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_channel_args* args, const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - // Check if local_connect_type is UDS. Only UDS is supported for now. + // Perform sanity check on UDS address. For TCP local connection, the check + // will be done during check_peer procedure. grpc_local_credentials* creds = - reinterpret_cast<grpc_local_credentials*>(channel_creds); - if (creds->connect_type != UDS) { - gpr_log(GPR_ERROR, - "Invalid local channel type to " - "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; - } - // Check if target_name is a valid UDS address. + static_cast<grpc_local_credentials*>(channel_creds.get()); const grpc_arg* server_uri_arg = grpc_channel_args_find(args, GRPC_ARG_SERVER_URI); const char* server_uri_str = grpc_channel_arg_get_string(server_uri_arg); - if (strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, + if (creds->connect_type() == UDS && + strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, strlen(GRPC_UDS_URI_PATTERN)) != 0) { gpr_log(GPR_ERROR, - "Invalid target_name to " + "Invalid UDS target name to " "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_local_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_local_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_channel_vtable; - c->base.add_handshakers = local_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = local_check_call_host; - c->base.cancel_check_call_host = local_cancel_check_call_host; - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_local_channel_security_connector>( + channel_creds, request_metadata_creds, target_name); } -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; - } - // Check if local_connect_type is UDS. Only UDS is supported for now. - grpc_local_server_credentials* creds = - reinterpret_cast<grpc_local_server_credentials*>(server_creds); - if (creds->connect_type != UDS) { - gpr_log(GPR_ERROR, - "Invalid local server type to " - "grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_local_server_security_connector*>( - gpr_zalloc(sizeof(grpc_local_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->base.add_handshakers = local_server_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_local_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/local/local_security_connector.h b/src/core/lib/security/security_connector/local/local_security_connector.h index 5369a2127a..6eee0ca9a6 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.h +++ b/src/core/lib/security/security_connector/local/local_security_connector.h @@ -34,13 +34,13 @@ * - sc: address of local channel security connector instance to be returned * from the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_channel_args* args, const char* target_name); /** * This method creates a local server security connector. @@ -49,10 +49,11 @@ grpc_security_status grpc_local_channel_security_connector_create( * - sc: address of local server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_LOCAL_LOCAL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index 02cecb0eb1..96a1960546 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -35,150 +35,67 @@ #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/security/transport/security_handshaker.h" grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount( false, "security_connector_refcount"); -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - if (sc == nullptr) { - GRPC_CLOSURE_SCHED(on_peer_checked, - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check peer -- no security connector")); - tsi_peer_destruct(&peer); - } else { - sc->vtable->check_peer(sc, peer, auth_context, on_peer_checked); - } -} - -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other) { +grpc_server_security_connector::grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_security_connector(url_scheme), + server_creds_(std::move(server_creds)) {} + +grpc_channel_security_connector::grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds) + : grpc_security_connector(url_scheme), + channel_creds_(std::move(channel_creds)), + request_metadata_creds_(std::move(request_metadata_creds)) {} +grpc_channel_security_connector::~grpc_channel_security_connector() {} + +int grpc_security_connector_cmp(const grpc_security_connector* sc, + const grpc_security_connector* other) { if (sc == nullptr || other == nullptr) return GPR_ICMP(sc, other); - int c = GPR_ICMP(sc->vtable, other->vtable); - if (c != 0) return c; - return sc->vtable->cmp(sc, other); + return sc->cmp(other); } -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2) { - GPR_ASSERT(sc1->channel_creds != nullptr); - GPR_ASSERT(sc2->channel_creds != nullptr); - int c = GPR_ICMP(sc1->channel_creds, sc2->channel_creds); - if (c != 0) return c; - c = GPR_ICMP(sc1->request_metadata_creds, sc2->request_metadata_creds); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->check_call_host, (void*)sc2->check_call_host); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->cancel_check_call_host, - (void*)sc2->cancel_check_call_host); +int grpc_channel_security_connector::channel_security_connector_cmp( + const grpc_channel_security_connector* other) const { + const grpc_channel_security_connector* other_sc = + static_cast<const grpc_channel_security_connector*>(other); + GPR_ASSERT(channel_creds() != nullptr); + GPR_ASSERT(other_sc->channel_creds() != nullptr); + int c = GPR_ICMP(channel_creds(), other_sc->channel_creds()); if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); + return GPR_ICMP(request_metadata_creds(), other_sc->request_metadata_creds()); } -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2) { - GPR_ASSERT(sc1->server_creds != nullptr); - GPR_ASSERT(sc2->server_creds != nullptr); - int c = GPR_ICMP(sc1->server_creds, sc2->server_creds); - if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); -} - -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error) { - if (sc == nullptr || sc->check_call_host == nullptr) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check call host -- no security connector"); - return true; - } - return sc->check_call_host(sc, host, auth_context, on_call_host_checked, - error); -} - -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - if (sc == nullptr || sc->cancel_check_call_host == nullptr) { - GRPC_ERROR_UNREF(error); - return; - } - sc->cancel_check_call_host(sc, on_call_host_checked, error); -} - -#ifndef NDEBUG -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc, const char* file, int line, - const char* reason) { - if (sc == nullptr) return nullptr; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val + 1, reason); - } -#else -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc) { - if (sc == nullptr) return nullptr; -#endif - gpr_ref(&sc->refcount); - return sc; -} - -#ifndef NDEBUG -void grpc_security_connector_unref(grpc_security_connector* sc, - const char* file, int line, - const char* reason) { - if (sc == nullptr) return; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val - 1, reason); - } -#else -void grpc_security_connector_unref(grpc_security_connector* sc) { - if (sc == nullptr) return; -#endif - if (gpr_unref(&sc->refcount)) sc->vtable->destroy(sc); +int grpc_server_security_connector::server_security_connector_cmp( + const grpc_server_security_connector* other) const { + const grpc_server_security_connector* other_sc = + static_cast<const grpc_server_security_connector*>(other); + GPR_ASSERT(server_creds() != nullptr); + GPR_ASSERT(other_sc->server_creds() != nullptr); + return GPR_ICMP(server_creds(), other_sc->server_creds()); } static void connector_arg_destroy(void* p) { - GRPC_SECURITY_CONNECTOR_UNREF((grpc_security_connector*)p, - "connector_arg_destroy"); + static_cast<grpc_security_connector*>(p)->Unref(DEBUG_LOCATION, + "connector_arg_destroy"); } static void* connector_arg_copy(void* p) { - return GRPC_SECURITY_CONNECTOR_REF((grpc_security_connector*)p, - "connector_arg_copy"); + return static_cast<grpc_security_connector*>(p) + ->Ref(DEBUG_LOCATION, "connector_arg_copy") + .release(); } static int connector_cmp(void* a, void* b) { - return grpc_security_connector_cmp(static_cast<grpc_security_connector*>(a), - static_cast<grpc_security_connector*>(b)); + return static_cast<grpc_security_connector*>(a)->cmp( + static_cast<grpc_security_connector*>(b)); } static const grpc_arg_pointer_vtable connector_arg_vtable = { diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 4c921a8793..74b0ef21a6 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -26,6 +26,7 @@ #include <grpc/grpc_security.h> #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/iomgr/tcp_server.h" @@ -34,8 +35,6 @@ extern grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount; -/* --- status enum. --- */ - typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; /* --- security_connector object. --- @@ -43,54 +42,34 @@ typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; A security connector object represents away to configure the underlying transport security mechanism and check the resulting trusted peer. */ -typedef struct grpc_security_connector grpc_security_connector; - #define GRPC_ARG_SECURITY_CONNECTOR "grpc.security_connector" -typedef struct { - void (*destroy)(grpc_security_connector* sc); - void (*check_peer)(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - int (*cmp)(grpc_security_connector* sc, grpc_security_connector* other); -} grpc_security_connector_vtable; - -struct grpc_security_connector { - const grpc_security_connector_vtable* vtable; - gpr_refcount refcount; - const char* url_scheme; -}; +class grpc_security_connector + : public grpc_core::RefCounted<grpc_security_connector> { + public: + explicit grpc_security_connector(const char* url_scheme) + : grpc_core::RefCounted<grpc_security_connector>( + &grpc_trace_security_connector_refcount), + url_scheme_(url_scheme) {} + virtual ~grpc_security_connector() = default; + + /* Check the peer. Callee takes ownership of the peer object. + When done, sets *auth_context and invokes on_peer_checked. */ + virtual void check_peer( + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) GRPC_ABSTRACT; + + /* Compares two security connectors. */ + virtual int cmp(const grpc_security_connector* other) const GRPC_ABSTRACT; + + const char* url_scheme() const { return url_scheme_; } -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_SECURITY_CONNECTOR_REF(p, r) \ - grpc_security_connector_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) \ - grpc_security_connector_unref((p), __FILE__, __LINE__, (r)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy, const char* file, int line, - const char* reason); -void grpc_security_connector_unref(grpc_security_connector* policy, - const char* file, int line, - const char* reason); -#else -#define GRPC_SECURITY_CONNECTOR_REF(p, r) grpc_security_connector_ref((p)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) grpc_security_connector_unref((p)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy); -void grpc_security_connector_unref(grpc_security_connector* policy); -#endif - -/* Check the peer. Callee takes ownership of the peer object. - When done, sets *auth_context and invokes on_peer_checked. */ -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - -/* Compares two security connectors. */ -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other); + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* url_scheme_; +}; /* Util to encapsulate the connector in a channel arg. */ grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc); @@ -107,71 +86,89 @@ grpc_security_connector* grpc_security_connector_find_in_args( A channel security connector object represents a way to configure the underlying transport security mechanism on the client side. */ -typedef struct grpc_channel_security_connector grpc_channel_security_connector; - -struct grpc_channel_security_connector { - grpc_security_connector base; - grpc_channel_credentials* channel_creds; - grpc_call_credentials* request_metadata_creds; - bool (*check_call_host)(grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error); - void (*cancel_check_call_host)(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error); - void (*add_handshakers)(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_channel_security_connector : public grpc_security_connector { + public: + grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds); + ~grpc_channel_security_connector() override; + + /// Checks that the host that will be set for a call is acceptable. + /// Returns true if completed synchronously, in which case \a error will + /// be set to indicate the result. Otherwise, \a on_call_host_checked + /// will be invoked when complete. + virtual bool check_call_host(const char* host, + grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) GRPC_ABSTRACT; + /// Cancels a pending asychronous call to + /// grpc_channel_security_connector_check_call_host() with + /// \a on_call_host_checked as its callback. + virtual void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) GRPC_ABSTRACT; + /// Registers handshakers with \a handshake_mgr. + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_channel_credentials* channel_creds() const { + return channel_creds_.get(); + } + grpc_channel_credentials* mutable_channel_creds() { + return channel_creds_.get(); + } + const grpc_call_credentials* request_metadata_creds() const { + return request_metadata_creds_.get(); + } + grpc_call_credentials* mutable_request_metadata_creds() { + return request_metadata_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int channel_security_connector_cmp( + const grpc_channel_security_connector* other) const; + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_; + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2); - -/// Checks that the host that will be set for a call is acceptable. -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_call_host_checked -/// will be invoked when complete. -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error); - -/// Cancels a pending asychronous call to -/// grpc_channel_security_connector_check_call_host() with -/// \a on_call_host_checked as its callback. -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error); - -/* Registers handshakers with \a handshake_mgr. */ -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - /* --- server_security_connector object. --- A server security connector object represents a way to configure the underlying transport security mechanism on the server side. */ -typedef struct grpc_server_security_connector grpc_server_security_connector; - -struct grpc_server_security_connector { - grpc_security_connector base; - grpc_server_credentials* server_creds; - void (*add_handshakers)(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_server_security_connector : public grpc_security_connector { + public: + grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); + ~grpc_server_security_connector() override = default; + + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_server_credentials* server_creds() const { + return server_creds_.get(); + } + grpc_server_credentials* mutable_server_creds() { + return server_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int server_security_connector_cmp( + const grpc_server_security_connector* other) const; + + private: + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2); - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */ diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 20a9533dd1..7414ab1a37 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/ssl/ssl_credentials.h" @@ -39,172 +40,10 @@ #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* client_handshaker_factory; - char* target_name; - char* overridden_target_name; - const verify_peer_options* verify_options; -} grpc_ssl_channel_security_connector; - -typedef struct { - grpc_server_security_connector base; - tsi_ssl_server_handshaker_factory* server_handshaker_factory; -} grpc_ssl_server_security_connector; - -static bool server_connector_has_cert_config_fetcher( - grpc_ssl_server_security_connector* c) { - GPR_ASSERT(c != nullptr); - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(c->base.server_creds); - GPR_ASSERT(server_creds != nullptr); - return server_creds->certificate_config_fetcher.cb != nullptr; -} - -static void ssl_channel_destroy(grpc_security_connector* sc) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - grpc_channel_credentials_unref(c->base.channel_creds); - grpc_call_credentials_unref(c->base.request_metadata_creds); - tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory); - c->client_handshaker_factory = nullptr; - if (c->target_name != nullptr) gpr_free(c->target_name); - if (c->overridden_target_name != nullptr) gpr_free(c->overridden_target_name); - gpr_free(sc); -} - -static void ssl_server_destroy(grpc_security_connector* sc) { - grpc_ssl_server_security_connector* c = - reinterpret_cast<grpc_ssl_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory); - c->server_handshaker_factory = nullptr; - gpr_free(sc); -} - -static void ssl_channel_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - // Instantiate TSI handshaker. - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->client_handshaker_factory, - c->overridden_target_name != nullptr ? c->overridden_target_name - : c->target_name, - &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -/* Attempts to replace the server_handshaker_factory with a new factory using - * the provided grpc_ssl_server_certificate_config. Should new factory creation - * fail, the existing factory will not be replaced. Returns true on success (new - * factory created). */ -static bool try_replace_server_handshaker_factory( - grpc_ssl_server_security_connector* sc, - const grpc_ssl_server_certificate_config* config) { - if (config == nullptr) { - gpr_log(GPR_ERROR, - "Server certificate config callback returned invalid (NULL) " - "config."); - return false; - } - gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); - - size_t num_alpn_protocols = 0; - const char** alpn_protocol_strings = - grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( - config->pem_key_cert_pairs, config->num_key_cert_pairs); - tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds); - tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( - cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, - grpc_get_tsi_client_certificate_request_type( - server_creds->config.client_certificate_request), - grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory); - gpr_free(cert_pairs); - gpr_free((void*)alpn_protocol_strings); - - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - return false; - } - tsi_ssl_server_handshaker_factory_unref(sc->server_handshaker_factory); - sc->server_handshaker_factory = new_handshaker_factory; - return true; -} - -/* Attempts to fetch the server certificate config if a callback is available. - * Current certificate config will continue to be used if the callback returns - * an error. Returns true if new credentials were sucessfully loaded. */ -static bool try_fetch_ssl_server_credentials( - grpc_ssl_server_security_connector* sc) { - grpc_ssl_server_certificate_config* certificate_config = nullptr; - bool status; - - GPR_ASSERT(sc != nullptr); - if (!server_connector_has_cert_config_fetcher(sc)) return false; - - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds); - grpc_ssl_certificate_config_reload_status cb_result = - server_creds->certificate_config_fetcher.cb( - server_creds->certificate_config_fetcher.user_data, - &certificate_config); - if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { - gpr_log(GPR_DEBUG, "No change in SSL server credentials."); - status = false; - } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { - status = try_replace_server_handshaker_factory(sc, certificate_config); - } else { - // Log error, continue using previously-loaded credentials. - gpr_log(GPR_ERROR, - "Failed fetching new server credentials, continuing to " - "use previously-loaded credentials."); - status = false; - } - - if (certificate_config != nullptr) { - grpc_ssl_server_certificate_config_destroy(certificate_config); - } - return status; -} - -static void ssl_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_server_security_connector* c = - reinterpret_cast<grpc_ssl_server_security_connector*>(sc); - // Instantiate TSI handshaker. - try_fetch_ssl_server_credentials(c); - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( - c->server_handshaker_factory, &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -static grpc_error* ssl_check_peer(grpc_security_connector* sc, - const char* peer_name, const tsi_peer* peer, - grpc_auth_context** auth_context) { +namespace { +grpc_error* ssl_check_peer( + const char* peer_name, const tsi_peer* peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context) { #if TSI_OPENSSL_ALPN_SUPPORT /* Check the ALPN if ALPN is supported. */ const tsi_peer_property* p = @@ -230,245 +69,384 @@ static grpc_error* ssl_check_peer(grpc_security_connector* sc, return GRPC_ERROR_NONE; } -static void ssl_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - const char* target_name = c->overridden_target_name != nullptr - ? c->overridden_target_name - : c->target_name; - grpc_error* error = ssl_check_peer(sc, target_name, &peer, auth_context); - if (error == GRPC_ERROR_NONE && - c->verify_options->verify_peer_callback != nullptr) { - const tsi_peer_property* p = - tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); - if (p == nullptr) { - error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Cannot check peer: missing pem cert property."); - } else { - char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1)); - memcpy(peer_pem, p->value.data, p->value.length); - peer_pem[p->value.length] = '\0'; - int callback_status = c->verify_options->verify_peer_callback( - target_name, peer_pem, - c->verify_options->verify_peer_callback_userdata); - gpr_free(peer_pem); - if (callback_status) { - char* msg; - gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", - callback_status); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); - } - } +class grpc_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_ssl_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name) + : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + overridden_target_name_(overridden_target_name == nullptr + ? nullptr + : gpr_strdup(overridden_target_name)), + verify_options_(&config->verify_options) { + char* port; + gpr_split_host_port(target_name, &target_name_, &port); + gpr_free(port); } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static void ssl_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_error* error = ssl_check_peer(sc, nullptr, &peer, auth_context); - tsi_peer_destruct(&peer); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} + ~grpc_ssl_channel_security_connector() override { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + if (target_name_ != nullptr) gpr_free(target_name_); + if (overridden_target_name_ != nullptr) gpr_free(overridden_target_name_); + } -static int ssl_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_ssl_channel_security_connector* c1 = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc1); - grpc_ssl_channel_security_connector* c2 = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target_name, c2->target_name); - if (c != 0) return c; - return (c1->overridden_target_name == nullptr || - c2->overridden_target_name == nullptr) - ? GPR_ICMP(c1->overridden_target_name, c2->overridden_target_name) - : strcmp(c1->overridden_target_name, c2->overridden_target_name); -} + grpc_security_status InitializeHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache) { + bool has_key_cert_pair = + config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + GPR_DEBUG_ASSERT(pem_root_certs != nullptr); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory_); + gpr_free((void*)options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; + } -static int ssl_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast<grpc_server_security_connector*>(sc1), - reinterpret_cast<grpc_server_security_connector*>(sc2)); -} + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory_, + overridden_target_name_ != nullptr ? overridden_target_name_ + : target_name_, + &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); + } -static bool ssl_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - grpc_security_status status = GRPC_SECURITY_ERROR; - tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); - if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; - /* If the target name was overridden, then the original target_name was - 'checked' transitively during the previous peer check at the end of the - handshake. */ - if (c->overridden_target_name != nullptr && - strcmp(host, c->target_name) == 0) { - status = GRPC_SECURITY_OK; + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + const char* target_name = overridden_target_name_ != nullptr + ? overridden_target_name_ + : target_name_; + grpc_error* error = ssl_check_peer(target_name, &peer, auth_context); + if (error == GRPC_ERROR_NONE && + verify_options_->verify_peer_callback != nullptr) { + const tsi_peer_property* p = + tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + if (p == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing pem cert property."); + } else { + char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1)); + memcpy(peer_pem, p->value.data, p->value.length); + peer_pem[p->value.length] = '\0'; + int callback_status = verify_options_->verify_peer_callback( + target_name, peer_pem, + verify_options_->verify_peer_callback_userdata); + gpr_free(peer_pem); + if (callback_status) { + char* msg; + gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", + callback_status); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + } + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); } - if (status != GRPC_SECURITY_OK) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "call host does not match SSL server name"); + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_ssl_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_name_, other->target_name_); + if (c != 0) return c; + return (overridden_target_name_ == nullptr || + other->overridden_target_name_ == nullptr) + ? GPR_ICMP(overridden_target_name_, + other->overridden_target_name_) + : strcmp(overridden_target_name_, + other->overridden_target_name_); } - grpc_shallow_peer_destruct(&peer); - return true; -} -static void ssl_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + grpc_security_status status = GRPC_SECURITY_ERROR; + tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); + if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; + /* If the target name was overridden, then the original target_name was + 'checked' transitively during the previous peer check at the end of the + handshake. */ + if (overridden_target_name_ != nullptr && strcmp(host, target_name_) == 0) { + status = GRPC_SECURITY_OK; + } + if (status != GRPC_SECURITY_OK) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "call host does not match SSL server name"); + } + grpc_shallow_peer_destruct(&peer); + return true; + } -static grpc_security_connector_vtable ssl_channel_vtable = { - ssl_channel_destroy, ssl_channel_check_peer, ssl_channel_cmp}; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } -static grpc_security_connector_vtable ssl_server_vtable = { - ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp}; + private: + tsi_ssl_client_handshaker_factory* client_handshaker_factory_; + char* target_name_; + char* overridden_target_name_; + const verify_peer_options* verify_options_; +}; + +class grpc_ssl_server_security_connector + : public grpc_server_security_connector { + public: + grpc_ssl_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(GRPC_SSL_URL_SCHEME, + std::move(server_creds)) {} + + ~grpc_ssl_server_security_connector() override { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_channel_security_connector* c; - char* port; - bool has_key_cert_pair; - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.alpn_protocols = - grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + bool has_cert_config_fetcher() const { + return static_cast<const grpc_ssl_server_credentials*>(server_creds()) + ->has_cert_config_fetcher(); + } - if (config == nullptr || target_name == nullptr) { - gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); - goto error; + const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const { + return server_handshaker_factory_; } - if (config->pem_root_certs == nullptr) { - // Use default root certificates. - options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); - options.root_store = grpc_core::DefaultSslRootStore::GetRootStore(); - if (options.pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, "Could not get default pem root certs."); - goto error; + + grpc_security_status InitializeHandshakerFactory() { + if (has_cert_config_fetcher()) { + // Load initial credentials from certificate_config_fetcher: + if (!try_fetch_ssl_server_credentials()) { + gpr_log(GPR_ERROR, + "Failed loading SSL server credentials from fetcher."); + return GRPC_SECURITY_ERROR; + } + } else { + auto* server_credentials = + static_cast<const grpc_ssl_server_credentials*>(server_creds()); + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + const tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + server_credentials->config().pem_key_cert_pairs, + server_credentials->config().num_key_cert_pairs, + server_credentials->config().pem_root_certs, + grpc_get_tsi_client_certificate_request_type( + server_credentials->config().client_certificate_request), + grpc_get_ssl_cipher_suites(), alpn_protocol_strings, + static_cast<uint16_t>(num_alpn_protocols), + &server_handshaker_factory_); + gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } } - } else { - options.pem_root_certs = config->pem_root_certs; - } - c = static_cast<grpc_ssl_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_ssl_channel_security_connector))); - - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &ssl_channel_vtable; - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = ssl_channel_check_call_host; - c->base.cancel_check_call_host = ssl_channel_cancel_check_call_host; - c->base.add_handshakers = ssl_channel_add_handshakers; - gpr_split_host_port(target_name, &c->target_name, &port); - gpr_free(port); - if (overridden_target_name != nullptr) { - c->overridden_target_name = gpr_strdup(overridden_target_name); + return GRPC_SECURITY_OK; } - c->verify_options = &config->verify_options; - has_key_cert_pair = config->pem_key_cert_pair != nullptr && - config->pem_key_cert_pair->private_key != nullptr && - config->pem_key_cert_pair->cert_chain != nullptr; - if (has_key_cert_pair) { - options.pem_key_cert_pair = config->pem_key_cert_pair; + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + try_fetch_ssl_server_credentials(); + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory_, &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); } - options.cipher_suites = grpc_get_ssl_cipher_suites(); - options.session_cache = ssl_session_cache; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->client_handshaker_factory); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - ssl_channel_destroy(&c->base.base); - *sc = nullptr; - goto error; + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context); + tsi_peer_destruct(&peer); + GRPC_CLOSURE_SCHED(on_peer_checked, error); } - *sc = &c->base; - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_OK; -error: - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_ERROR; -} + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } -static grpc_ssl_server_security_connector* -grpc_ssl_server_security_connector_initialize( - grpc_server_credentials* server_creds) { - grpc_ssl_server_security_connector* c = - static_cast<grpc_ssl_server_security_connector*>( - gpr_zalloc(sizeof(grpc_ssl_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.base.vtable = &ssl_server_vtable; - c->base.add_handshakers = ssl_server_add_handshakers; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - return c; -} + private: + /* Attempts to fetch the server certificate config if a callback is available. + * Current certificate config will continue to be used if the callback returns + * an error. Returns true if new credentials were sucessfully loaded. */ + bool try_fetch_ssl_server_credentials() { + grpc_ssl_server_certificate_config* certificate_config = nullptr; + bool status; + + if (!has_cert_config_fetcher()) return false; + + grpc_ssl_server_credentials* server_creds = + static_cast<grpc_ssl_server_credentials*>(this->mutable_server_creds()); + grpc_ssl_certificate_config_reload_status cb_result = + server_creds->FetchCertConfig(&certificate_config); + if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { + gpr_log(GPR_DEBUG, "No change in SSL server credentials."); + status = false; + } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { + status = try_replace_server_handshaker_factory(certificate_config); + } else { + // Log error, continue using previously-loaded credentials. + gpr_log(GPR_ERROR, + "Failed fetching new server credentials, continuing to " + "use previously-loaded credentials."); + status = false; + } -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* gsc, grpc_server_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_server_credentials* server_credentials = - reinterpret_cast<grpc_ssl_server_credentials*>(gsc); - grpc_security_status retval = GRPC_SECURITY_OK; + if (certificate_config != nullptr) { + grpc_ssl_server_certificate_config_destroy(certificate_config); + } + return status; + } - GPR_ASSERT(server_credentials != nullptr); - GPR_ASSERT(sc != nullptr); - - grpc_ssl_server_security_connector* c = - grpc_ssl_server_security_connector_initialize(gsc); - if (server_connector_has_cert_config_fetcher(c)) { - // Load initial credentials from certificate_config_fetcher: - if (!try_fetch_ssl_server_credentials(c)) { - gpr_log(GPR_ERROR, "Failed loading SSL server credentials from fetcher."); - retval = GRPC_SECURITY_ERROR; + /* Attempts to replace the server_handshaker_factory with a new factory using + * the provided grpc_ssl_server_certificate_config. Should new factory + * creation fail, the existing factory will not be replaced. Returns true on + * success (new factory created). */ + bool try_replace_server_handshaker_factory( + const grpc_ssl_server_certificate_config* config) { + if (config == nullptr) { + gpr_log(GPR_ERROR, + "Server certificate config callback returned invalid (NULL) " + "config."); + return false; } - } else { + gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); + size_t num_alpn_protocols = 0; const char** alpn_protocol_strings = grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - result = tsi_create_ssl_server_handshaker_factory_ex( - server_credentials->config.pem_key_cert_pairs, - server_credentials->config.num_key_cert_pairs, - server_credentials->config.pem_root_certs, + tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + config->pem_key_cert_pairs, config->num_key_cert_pairs); + tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; + const grpc_ssl_server_credentials* server_creds = + static_cast<const grpc_ssl_server_credentials*>(this->server_creds()); + GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr); + tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, grpc_get_tsi_client_certificate_request_type( - server_credentials->config.client_certificate_request), + server_creds->config().client_certificate_request), grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast<uint16_t>(num_alpn_protocols), - &c->server_handshaker_factory); + static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory); + gpr_free(cert_pairs); gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - retval = GRPC_SECURITY_ERROR; + return false; } + set_server_handshaker_factory(new_handshaker_factory); + return true; + } + + void set_server_handshaker_factory( + tsi_ssl_server_handshaker_factory* new_factory) { + if (server_handshaker_factory_) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + server_handshaker_factory_ = new_factory; + } + + tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr; +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) { + if (config == nullptr || target_name == nullptr) { + gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); + return nullptr; } - if (retval == GRPC_SECURITY_OK) { - *sc = &c->base; + const char* pem_root_certs; + const tsi_ssl_root_certs_store* root_store; + if (config->pem_root_certs == nullptr) { + // Use default root certificates. + pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return nullptr; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); } else { - if (c != nullptr) ssl_server_destroy(&c->base.base); - if (sc != nullptr) *sc = nullptr; + pem_root_certs = config->pem_root_certs; + root_store = nullptr; + } + + grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c = + grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), config, + target_name, overridden_target_name); + const grpc_security_status result = c->InitializeHandshakerFactory( + config, pem_root_certs, root_store, ssl_session_cache); + if (result != GRPC_SECURITY_OK) { + return nullptr; } - return retval; + return c; +} + +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials) { + GPR_ASSERT(server_credentials != nullptr); + grpc_core::RefCountedPtr<grpc_ssl_server_security_connector> c = + grpc_core::MakeRefCounted<grpc_ssl_server_security_connector>( + std::move(server_credentials)); + const grpc_security_status retval = c->InitializeHandshakerFactory(); + if (retval != GRPC_SECURITY_OK) { + return nullptr; + } + return c; } diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h index 9b80590606..70e26e338a 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h @@ -25,6 +25,7 @@ #include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,20 +48,21 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc); + tsi_ssl_session_cache* ssl_session_cache); /* Config for ssl servers. */ typedef struct { - tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; - grpc_ssl_client_certificate_request_type client_certificate_request; + tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; + grpc_ssl_client_certificate_request_type client_certificate_request = + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE; } grpc_ssl_server_config; /* Creates an SSL server_security_connector. @@ -69,9 +71,9 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* server_credentials, - grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SSL_SSL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/ssl_utils.cc b/src/core/lib/security/security_connector/ssl_utils.cc index fbf41cfbc7..29030f07ad 100644 --- a/src/core/lib/security/security_connector/ssl_utils.cc +++ b/src/core/lib/security/security_connector/ssl_utils.cc @@ -30,6 +30,7 @@ #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/load_file.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/security_connector/load_system_roots.h" @@ -141,16 +142,17 @@ int grpc_ssl_host_matches_name(const tsi_peer* peer, const char* peer_name) { return r; } -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { +grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context( + const tsi_peer* peer) { size_t i; - grpc_auth_context* ctx = nullptr; const char* peer_identity_property_name = nullptr; /* The caller has checked the certificate type property. */ GPR_ASSERT(peer->property_count >= 1); - ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr<grpc_auth_context> ctx = + grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_SSL_TRANSPORT_SECURITY_TYPE); for (i = 0; i < peer->property_count; i++) { const tsi_peer_property* prop = &peer->properties[i]; @@ -160,24 +162,26 @@ grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { if (peer_identity_property_name == nullptr) { peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME; } - grpc_auth_context_add_property(ctx, GRPC_X509_CN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME; - grpc_auth_context_add_property(ctx, GRPC_X509_SAN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), + GRPC_X509_PEM_CERT_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY, + grpc_auth_context_add_property(ctx.get(), + GRPC_SSL_SESSION_REUSED_PROPERTY, prop->value.data, prop->value.length); } } if (peer_identity_property_name != nullptr) { GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, peer_identity_property_name) == 1); + ctx.get(), peer_identity_property_name) == 1); } return ctx; } diff --git a/src/core/lib/security/security_connector/ssl_utils.h b/src/core/lib/security/security_connector/ssl_utils.h index 6f6d473311..c9cd1a1d9c 100644 --- a/src/core/lib/security/security_connector/ssl_utils.h +++ b/src/core/lib/security/security_connector/ssl_utils.h @@ -26,6 +26,7 @@ #include <grpc/grpc_security.h> #include <grpc/slice_buffer.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,7 +48,8 @@ grpc_get_tsi_client_certificate_request_type( const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols); /* Exposed for testing only. */ -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer); +grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context( + const tsi_peer* peer); tsi_peer grpc_shallow_peer_from_ssl_auth_context( const grpc_auth_context* auth_context); void grpc_shallow_peer_destruct(tsi_peer* peer); diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 6955e8698e..66f86b8bc5 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -55,7 +55,7 @@ struct call_data { // that the memory is not initialized. void destroy() { grpc_credentials_mdelem_array_destroy(&md_array); - grpc_call_credentials_unref(creds); + creds.reset(); grpc_slice_unref_internal(host); grpc_slice_unref_internal(method); grpc_auth_metadata_context_reset(&auth_md_context); @@ -64,7 +64,7 @@ struct call_data { gpr_arena* arena; grpc_call_stack* owning_call; grpc_call_combiner* call_combiner; - grpc_call_credentials* creds = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> creds; grpc_slice host = grpc_empty_slice(); grpc_slice method = grpc_empty_slice(); /* pollset{_set} bound to this call; if we need to make external @@ -83,8 +83,18 @@ struct call_data { /* We can have a per-channel credentials. */ struct channel_data { - grpc_channel_security_connector* security_connector; - grpc_auth_context* auth_context; + channel_data(grpc_channel_security_connector* security_connector, + grpc_auth_context* auth_context) + : security_connector( + security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")), + auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {} + ~channel_data() { + security_connector.reset(DEBUG_LOCATION, "client_auth_filter"); + auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; }; } // namespace @@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset( gpr_free(const_cast<char*>(auth_md_context->method_name)); auth_md_context->method_name = nullptr; } - GRPC_AUTH_CONTEXT_UNREF( - (grpc_auth_context*)auth_md_context->channel_auth_context, - "grpc_auth_metadata_context"); - auth_md_context->channel_auth_context = nullptr; + if (auth_md_context->channel_auth_context != nullptr) { + const_cast<grpc_auth_context*>(auth_md_context->channel_auth_context) + ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context"); + auth_md_context->channel_auth_context = nullptr; + } } static void add_error(grpc_error** combined, grpc_error* error) { @@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build( auth_md_context->service_url = service_url; auth_md_context->method_name = method_name; auth_md_context->channel_auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context"); + auth_context == nullptr + ? nullptr + : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context") + .release(); gpr_free(service); gpr_free(host_and_port); } @@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); if (error != GRPC_ERROR_NONE) { - grpc_call_credentials_cancel_get_request_metadata( - calld->creds, &calld->md_array, GRPC_ERROR_REF(error)); + calld->creds->cancel_get_request_metadata(&calld->md_array, + GRPC_ERROR_REF(error)); } } @@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem, static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); grpc_call_credentials* channel_call_creds = - chand->security_connector->request_metadata_creds; + chand->security_connector->mutable_request_metadata_creds(); int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); if (channel_call_creds == nullptr && !call_creds_has_md) { @@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem, } if (channel_call_creds != nullptr && call_creds_has_md) { - calld->creds = grpc_composite_call_credentials_create(channel_call_creds, - ctx->creds, nullptr); + calld->creds = grpc_core::RefCountedPtr<grpc_call_credentials>( + grpc_composite_call_credentials_create(channel_call_creds, + ctx->creds.get(), nullptr)); if (calld->creds == nullptr) { grpc_transport_stream_op_batch_finish_with_failure( batch, @@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem, return; } } else { - calld->creds = grpc_call_credentials_ref( - call_creds_has_md ? ctx->creds : channel_call_creds); + calld->creds = + call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref(); } grpc_auth_metadata_context_build( - chand->security_connector->base.url_scheme, calld->host, calld->method, - chand->auth_context, &calld->auth_md_context); + chand->security_connector->url_scheme(), calld->host, calld->method, + chand->auth_context.get(), &calld->auth_md_context); GPR_ASSERT(calld->pollent != nullptr); GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata"); GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata, batch, grpc_schedule_on_exec_ctx); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - calld->creds, calld->pollent, calld->auth_md_context, - &calld->md_array, &calld->async_result_closure, &error)) { + if (calld->creds->get_request_metadata( + calld->pollent, calld->auth_md_context, &calld->md_array, + &calld->async_result_closure, &error)) { // Synchronous return; invoke on_credentials_metadata() directly. on_credentials_metadata(batch, error); GRPC_ERROR_UNREF(error); @@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) { call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); if (error != GRPC_ERROR_NONE) { - grpc_channel_security_connector_cancel_check_call_host( - chand->security_connector, &calld->async_result_closure, - GRPC_ERROR_REF(error)); + chand->security_connector->cancel_check_call_host( + &calld->async_result_closure, GRPC_ERROR_REF(error)); } } @@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch( GPR_ASSERT(batch->payload->context != nullptr); if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) { batch->payload->context[GRPC_CONTEXT_SECURITY].value = - grpc_client_security_context_create(calld->arena); + grpc_client_security_context_create(calld->arena, /*creds=*/nullptr); batch->payload->context[GRPC_CONTEXT_SECURITY].destroy = grpc_client_security_context_destroy; } grpc_client_security_context* sec_ctx = static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); - GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter"); + sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); sec_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter"); } if (batch->send_initial_metadata) { @@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch( grpc_schedule_on_exec_ctx); char* call_host = grpc_slice_to_c_string(calld->host); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_channel_security_connector_check_call_host( - chand->security_connector, call_host, chand->auth_context, + if (chand->security_connector->check_call_host( + call_host, chand->auth_context.get(), &calld->async_result_closure, &error)) { // Synchronous return; invoke on_host_checked() directly. on_host_checked(batch, error); @@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem, /* Constructor for channel_data */ static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { + /* The first and the last filters tend to be implemented differently to + handle the case that there's no 'next' filter to call on the up or down + path */ + GPR_ASSERT(!args->is_last); grpc_security_connector* sc = grpc_security_connector_find_in_args(args->channel_args); if (sc == nullptr) { @@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Auth context missing from client auth filter args"); } - - /* grab pointers to our data from the channel element */ - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - - /* The first and the last filters tend to be implemented differently to - handle the case that there's no 'next' filter to call on the up or down - path */ - GPR_ASSERT(!args->is_last); - - /* initialize members */ - chand->security_connector = - reinterpret_cast<grpc_channel_security_connector*>( - GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter")); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter"); + new (elem->channel_data) channel_data( + static_cast<grpc_channel_security_connector*>(sc), auth_context); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { - /* grab pointers to our data from the channel element */ channel_data* chand = static_cast<channel_data*>(elem->channel_data); - grpc_channel_security_connector* sc = chand->security_connector; - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter"); - } - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter"); + chand->~channel_data(); } const grpc_channel_filter grpc_client_auth_filter = { diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 854a1c4af9..01831dab10 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/transport/secure_endpoint.h" #include "src/core/lib/security/transport/tsi_error.h" @@ -38,34 +39,62 @@ #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 -typedef struct { +namespace { +struct security_handshaker { + security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector); + ~security_handshaker() { + gpr_mu_destroy(&mu); + tsi_handshaker_destroy(handshaker); + tsi_handshaker_result_destroy(handshaker_result); + if (endpoint_to_destroy != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy); + } + if (read_buffer_to_destroy != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy); + gpr_free(read_buffer_to_destroy); + } + gpr_free(handshake_buffer); + grpc_slice_buffer_destroy_internal(&outgoing); + auth_context.reset(DEBUG_LOCATION, "handshake"); + connector.reset(DEBUG_LOCATION, "handshake"); + } + + void Ref() { refs.Ref(); } + void Unref() { + if (refs.Unref()) { + grpc_core::Delete(this); + } + } + grpc_handshaker base; // State set at creation time. tsi_handshaker* handshaker; - grpc_security_connector* connector; + grpc_core::RefCountedPtr<grpc_security_connector> connector; gpr_mu mu; - gpr_refcount refs; + grpc_core::RefCount refs; - bool shutdown; + bool shutdown = false; // Endpoint and read buffer to destroy after a shutdown. - grpc_endpoint* endpoint_to_destroy; - grpc_slice_buffer* read_buffer_to_destroy; + grpc_endpoint* endpoint_to_destroy = nullptr; + grpc_slice_buffer* read_buffer_to_destroy = nullptr; // State saved while performing the handshake. - grpc_handshaker_args* args; - grpc_closure* on_handshake_done; + grpc_handshaker_args* args = nullptr; + grpc_closure* on_handshake_done = nullptr; - unsigned char* handshake_buffer; size_t handshake_buffer_size; + unsigned char* handshake_buffer; grpc_slice_buffer outgoing; grpc_closure on_handshake_data_sent_to_peer; grpc_closure on_handshake_data_received_from_peer; grpc_closure on_peer_checked; - grpc_auth_context* auth_context; - tsi_handshaker_result* handshaker_result; -} security_handshaker; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; + tsi_handshaker_result* handshaker_result = nullptr; +}; +} // namespace static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { size_t bytes_in_read_buffer = h->args->read_buffer->length; @@ -85,26 +114,6 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { return bytes_in_read_buffer; } -static void security_handshaker_unref(security_handshaker* h) { - if (gpr_unref(&h->refs)) { - gpr_mu_destroy(&h->mu); - tsi_handshaker_destroy(h->handshaker); - tsi_handshaker_result_destroy(h->handshaker_result); - if (h->endpoint_to_destroy != nullptr) { - grpc_endpoint_destroy(h->endpoint_to_destroy); - } - if (h->read_buffer_to_destroy != nullptr) { - grpc_slice_buffer_destroy_internal(h->read_buffer_to_destroy); - gpr_free(h->read_buffer_to_destroy); - } - gpr_free(h->handshake_buffer); - grpc_slice_buffer_destroy_internal(&h->outgoing); - GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake"); - GRPC_SECURITY_CONNECTOR_UNREF(h->connector, "handshake"); - gpr_free(h); - } -} - // Set args fields to NULL, saving the endpoint and read buffer for // later destruction. static void cleanup_args_for_failure_locked(security_handshaker* h) { @@ -194,7 +203,7 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { tsi_handshaker_result_destroy(h->handshaker_result); h->handshaker_result = nullptr; // Add auth context to channel args. - grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context); + grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get()); grpc_channel_args* tmp_args = h->args->args; h->args->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1); @@ -211,7 +220,7 @@ static void on_peer_checked(void* arg, grpc_error* error) { gpr_mu_lock(&h->mu); on_peer_checked_inner(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } static grpc_error* check_peer_locked(security_handshaker* h) { @@ -222,8 +231,8 @@ static grpc_error* check_peer_locked(security_handshaker* h) { return grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); } - grpc_security_connector_check_peer(h->connector, peer, &h->auth_context, - &h->on_peer_checked); + h->connector->check_peer(peer, h->args->endpoint, &h->auth_context, + &h->on_peer_checked); return GRPC_ERROR_NONE; } @@ -281,7 +290,7 @@ static void on_handshake_next_done_grpc_wrapper( if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -317,7 +326,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake read failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // Copy all slices received. @@ -329,7 +338,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -343,7 +352,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake write failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // We may be done. @@ -355,7 +364,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } } @@ -368,7 +377,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { static void security_handshaker_destroy(grpc_handshaker* handshaker) { security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker); - security_handshaker_unref(h); + h->Unref(); } static void security_handshaker_shutdown(grpc_handshaker* handshaker, @@ -393,14 +402,14 @@ static void security_handshaker_do_handshake(grpc_handshaker* handshaker, gpr_mu_lock(&h->mu); h->args = args; h->on_handshake_done = on_handshake_done; - gpr_ref(&h->refs); + h->Ref(); size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); grpc_error* error = do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } gpr_mu_unlock(&h->mu); @@ -410,27 +419,32 @@ static const grpc_handshaker_vtable security_handshaker_vtable = { security_handshaker_destroy, security_handshaker_shutdown, security_handshaker_do_handshake, "security"}; -static grpc_handshaker* security_handshaker_create( - tsi_handshaker* handshaker, grpc_security_connector* connector) { - security_handshaker* h = static_cast<security_handshaker*>( - gpr_zalloc(sizeof(security_handshaker))); - grpc_handshaker_init(&security_handshaker_vtable, &h->base); - h->handshaker = handshaker; - h->connector = GRPC_SECURITY_CONNECTOR_REF(connector, "handshake"); - gpr_mu_init(&h->mu); - gpr_ref_init(&h->refs, 1); - h->handshake_buffer_size = GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE; - h->handshake_buffer = - static_cast<uint8_t*>(gpr_malloc(h->handshake_buffer_size)); - GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer, - on_handshake_data_sent_to_peer, h, +namespace { +security_handshaker::security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector) + : handshaker(handshaker), + connector(connector->Ref(DEBUG_LOCATION, "handshake")), + handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), + handshake_buffer( + static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size))) { + grpc_handshaker_init(&security_handshaker_vtable, &base); + gpr_mu_init(&mu); + grpc_slice_buffer_init(&outgoing); + GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer, + ::on_handshake_data_sent_to_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer, - on_handshake_data_received_from_peer, h, + GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer, + ::on_handshake_data_received_from_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_peer_checked, on_peer_checked, h, + GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this, grpc_schedule_on_exec_ctx); - grpc_slice_buffer_init(&h->outgoing); +} +} // namespace + +static grpc_handshaker* security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector) { + security_handshaker* h = + grpc_core::New<security_handshaker>(handshaker, connector); return &h->base; } @@ -477,8 +491,9 @@ static void client_handshaker_factory_add_handshakers( grpc_channel_security_connector* security_connector = reinterpret_cast<grpc_channel_security_connector*>( grpc_security_connector_find_in_args(args)); - grpc_channel_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void server_handshaker_factory_add_handshakers( @@ -488,8 +503,9 @@ static void server_handshaker_factory_add_handshakers( grpc_server_security_connector* security_connector = reinterpret_cast<grpc_server_security_connector*>( grpc_security_connector_find_in_args(args)); - grpc_server_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void handshaker_factory_destroy( diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 362f49a584..f93eb4275e 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -39,8 +39,12 @@ enum async_state { }; struct channel_data { - grpc_auth_context* auth_context; - grpc_server_credentials* creds; + channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds) + : auth_context(auth_context->Ref()), creds(creds->Ref()) {} + ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); } + + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; + grpc_core::RefCountedPtr<grpc_server_credentials> creds; }; struct call_data { @@ -58,7 +62,7 @@ struct call_data { grpc_server_security_context_create(args.arena); channel_data* chand = static_cast<channel_data*>(elem->channel_data); server_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter"); if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) { args.context[GRPC_CONTEXT_SECURITY].destroy( args.context[GRPC_CONTEXT_SECURITY].value); @@ -208,7 +212,8 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { call_data* calld = static_cast<call_data*>(elem->call_data); grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; if (error == GRPC_ERROR_NONE) { - if (chand->creds != nullptr && chand->creds->processor.process != nullptr) { + if (chand->creds != nullptr && + chand->creds->auth_metadata_processor().process != nullptr) { // We're calling out to the application, so we need to make sure // to drop the call combiner early if we get cancelled. GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem, @@ -218,9 +223,10 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata"); calld->md = metadata_batch_to_md_array( batch->payload->recv_initial_metadata.recv_initial_metadata); - chand->creds->processor.process( - chand->creds->processor.state, chand->auth_context, - calld->md.metadata, calld->md.count, on_md_processing_done, elem); + chand->creds->auth_metadata_processor().process( + chand->creds->auth_metadata_processor().state, + chand->auth_context.get(), calld->md.metadata, calld->md.count, + on_md_processing_done, elem); return; } } @@ -290,23 +296,19 @@ static void destroy_call_elem(grpc_call_element* elem, static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { GPR_ASSERT(!args->is_last); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args->channel_args); GPR_ASSERT(auth_context != nullptr); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter"); grpc_server_credentials* creds = grpc_find_server_credentials_in_args(args->channel_args); - chand->creds = grpc_server_credentials_ref(creds); + new (elem->channel_data) channel_data(auth_context, creds); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter"); - grpc_server_credentials_unref(chand->creds); + chand->~channel_data(); } const grpc_channel_filter grpc_server_auth_filter = { diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc index 67b38e6f0c..7ae6e51a5f 100644 --- a/src/core/lib/surface/server.cc +++ b/src/core/lib/surface/server.cc @@ -194,13 +194,10 @@ struct call_data { }; struct request_matcher { - request_matcher(grpc_server* server); - ~request_matcher(); - grpc_server* server; - std::atomic<call_data*> pending_head{nullptr}; - call_data* pending_tail = nullptr; - gpr_locked_mpscq* requests_per_cq = nullptr; + call_data* pending_head; + call_data* pending_tail; + gpr_locked_mpscq* requests_per_cq; }; struct registered_method { @@ -349,30 +346,22 @@ static void channel_broadcaster_shutdown(channel_broadcaster* cb, * request_matcher */ -namespace { -request_matcher::request_matcher(grpc_server* server) : server(server) { - requests_per_cq = static_cast<gpr_locked_mpscq*>( - gpr_malloc(sizeof(*requests_per_cq) * server->cq_count)); - for (size_t i = 0; i < server->cq_count; i++) { - gpr_locked_mpscq_init(&requests_per_cq[i]); - } -} - -request_matcher::~request_matcher() { +static void request_matcher_init(request_matcher* rm, grpc_server* server) { + memset(rm, 0, sizeof(*rm)); + rm->server = server; + rm->requests_per_cq = static_cast<gpr_locked_mpscq*>( + gpr_malloc(sizeof(*rm->requests_per_cq) * server->cq_count)); for (size_t i = 0; i < server->cq_count; i++) { - GPR_ASSERT(gpr_locked_mpscq_pop(&requests_per_cq[i]) == nullptr); - gpr_locked_mpscq_destroy(&requests_per_cq[i]); + gpr_locked_mpscq_init(&rm->requests_per_cq[i]); } - gpr_free(requests_per_cq); -} -} // namespace - -static void request_matcher_init(request_matcher* rm, grpc_server* server) { - new (rm) request_matcher(server); } static void request_matcher_destroy(request_matcher* rm) { - rm->~request_matcher(); + for (size_t i = 0; i < rm->server->cq_count; i++) { + GPR_ASSERT(gpr_locked_mpscq_pop(&rm->requests_per_cq[i]) == nullptr); + gpr_locked_mpscq_destroy(&rm->requests_per_cq[i]); + } + gpr_free(rm->requests_per_cq); } static void kill_zombie(void* elem, grpc_error* error) { @@ -381,10 +370,9 @@ static void kill_zombie(void* elem, grpc_error* error) { } static void request_matcher_zombify_all_pending_calls(request_matcher* rm) { - call_data* calld; - while ((calld = rm->pending_head.load(std::memory_order_relaxed)) != - nullptr) { - rm->pending_head.store(calld->pending_next, std::memory_order_relaxed); + while (rm->pending_head) { + call_data* calld = rm->pending_head; + rm->pending_head = calld->pending_next; gpr_atm_no_barrier_store(&calld->state, ZOMBIED); GRPC_CLOSURE_INIT( &calld->kill_zombie_closure, kill_zombie, @@ -582,9 +570,8 @@ static void publish_new_rpc(void* arg, grpc_error* error) { } gpr_atm_no_barrier_store(&calld->state, PENDING); - if (rm->pending_head.load(std::memory_order_relaxed) == nullptr) { - rm->pending_head.store(calld, std::memory_order_relaxed); - rm->pending_tail = calld; + if (rm->pending_head == nullptr) { + rm->pending_tail = rm->pending_head = calld; } else { rm->pending_tail->pending_next = calld; rm->pending_tail = calld; @@ -1448,39 +1435,30 @@ static grpc_call_error queue_call_request(grpc_server* server, size_t cq_idx, rm = &rc->data.registered.method->matcher; break; } - - // Fast path: if there is no pending request to be processed, immediately - // return. - if (!gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link) || - // Note: We are reading the pending_head without holding the server's call - // mutex. Even if we read a non-null value here due to reordering, - // we will check it below again after grabbing the lock. - rm->pending_head.load(std::memory_order_relaxed) == nullptr) { - return GRPC_CALL_OK; - } - // Slow path: This was the first queued request and there are pendings: - // We need to lock and start matching calls. - gpr_mu_lock(&server->mu_call); - while ((calld = rm->pending_head.load(std::memory_order_relaxed)) != - nullptr) { - rc = reinterpret_cast<requested_call*>( - gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx])); - if (rc == nullptr) break; - rm->pending_head.store(calld->pending_next, std::memory_order_relaxed); - gpr_mu_unlock(&server->mu_call); - if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) { - // Zombied Call - GRPC_CLOSURE_INIT( - &calld->kill_zombie_closure, kill_zombie, - grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0), - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE); - } else { - publish_call(server, calld, cq_idx, rc); - } + if (gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link)) { + /* this was the first queued request: we need to lock and start + matching calls */ gpr_mu_lock(&server->mu_call); + while ((calld = rm->pending_head) != nullptr) { + rc = reinterpret_cast<requested_call*>( + gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx])); + if (rc == nullptr) break; + rm->pending_head = calld->pending_next; + gpr_mu_unlock(&server->mu_call); + if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) { + // Zombied Call + GRPC_CLOSURE_INIT( + &calld->kill_zombie_closure, kill_zombie, + grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0), + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE); + } else { + publish_call(server, calld, cq_idx, rc); + } + gpr_mu_lock(&server->mu_call); + } + gpr_mu_unlock(&server->mu_call); } - gpr_mu_unlock(&server->mu_call); return GRPC_CALL_OK; } diff --git a/src/core/lib/surface/version.cc b/src/core/lib/surface/version.cc index 4829cc80a5..70d7580bec 100644 --- a/src/core/lib/surface/version.cc +++ b/src/core/lib/surface/version.cc @@ -25,4 +25,4 @@ const char* grpc_version_string(void) { return "7.0.0-dev"; } -const char* grpc_g_stands_for(void) { return "goose"; } +const char* grpc_g_stands_for(void) { return "gold"; } diff --git a/src/core/lib/transport/metadata.cc b/src/core/lib/transport/metadata.cc index 60af22393e..30482a1b3b 100644 --- a/src/core/lib/transport/metadata.cc +++ b/src/core/lib/transport/metadata.cc @@ -187,6 +187,7 @@ static void gc_mdtab(mdtab_shard* shard) { ((destroy_user_data_func)gpr_atm_no_barrier_load( &md->destroy_user_data))(user_data); } + gpr_mu_destroy(&md->mu_user_data); gpr_free(md); *prev_next = next; num_freed++; diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index d6a72ada0d..fb6ea19210 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -156,9 +156,13 @@ static unsigned long openssl_thread_id_cb(void) { #endif static void init_openssl(void) { +#if OPENSSL_API_COMPAT >= 0x10100000L + OPENSSL_init_ssl(0, NULL); +#else SSL_library_init(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); +#endif #if OPENSSL_VERSION_NUMBER < 0x10100000 if (!CRYPTO_get_locking_callback()) { int num_locks = CRYPTO_num_locks(); @@ -1649,7 +1653,11 @@ tsi_result tsi_create_ssl_client_handshaker_factory_with_options( return TSI_INVALID_ARGUMENT; } +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + ssl_context = SSL_CTX_new(TLS_method()); +#else ssl_context = SSL_CTX_new(TLSv1_2_method()); +#endif if (ssl_context == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); return TSI_INVALID_ARGUMENT; @@ -1806,7 +1814,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options( for (i = 0; i < options->num_key_cert_pairs; i++) { do { +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + impl->ssl_contexts[i] = SSL_CTX_new(TLS_method()); +#else impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method()); +#endif if (impl->ssl_contexts[i] == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); result = TSI_OUT_OF_RESOURCES; @@ -1850,31 +1862,30 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options( break; } SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names); - switch (options->client_certificate_request) { - case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); - break; - case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, - NullVerifyCallback); - break; - case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); - break; - case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: - SSL_CTX_set_verify( - impl->ssl_contexts[i], - SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, - NullVerifyCallback); - break; - case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: - SSL_CTX_set_verify( - impl->ssl_contexts[i], - SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - break; - } - /* TODO(jboeuf): Add revocation verification. */ } + switch (options->client_certificate_request) { + case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, + NullVerifyCallback); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + NullVerifyCallback); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + break; + } + /* TODO(jboeuf): Add revocation verification. */ result = extract_x509_subject_names_from_pem_cert( options->pem_key_cert_pairs[i].cert_chain, diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index d0abe441a6..4d0ed355ab 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -261,10 +261,10 @@ void MetadataCredentialsPluginWrapper::InvokePlugin( grpc_status_code* status_code, const char** error_details) { std::multimap<grpc::string, grpc::string> metadata; - // const_cast is safe since the SecureAuthContext does not take owndership and - // the object is passed as a const ref to plugin_->GetMetadata. + // const_cast is safe since the SecureAuthContext only inc/dec the refcount + // and the object is passed as a const ref to plugin_->GetMetadata. SecureAuthContext cpp_channel_auth_context( - const_cast<grpc_auth_context*>(context.channel_auth_context), false); + const_cast<grpc_auth_context*>(context.channel_auth_context)); Status status = plugin_->GetMetadata(context.service_url, context.method_name, cpp_channel_auth_context, &metadata); diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index 613f1d6dc2..4918bd5a4d 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -24,6 +24,7 @@ #include <grpcpp/security/credentials.h> #include <grpcpp/support/config.h> +#include "src/core/lib/security/credentials/credentials.h" #include "src/cpp/server/thread_pool_interface.h" namespace grpc { @@ -31,7 +32,9 @@ namespace grpc { class SecureChannelCredentials final : public ChannelCredentials { public: explicit SecureChannelCredentials(grpc_channel_credentials* c_creds); - ~SecureChannelCredentials() { grpc_channel_credentials_release(c_creds_); } + ~SecureChannelCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_channel_credentials* GetRawCreds() { return c_creds_; } std::shared_ptr<grpc::Channel> CreateChannel( @@ -51,7 +54,9 @@ class SecureChannelCredentials final : public ChannelCredentials { class SecureCallCredentials final : public CallCredentials { public: explicit SecureCallCredentials(grpc_call_credentials* c_creds); - ~SecureCallCredentials() { grpc_call_credentials_release(c_creds_); } + ~SecureCallCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_call_credentials* GetRawCreds() { return c_creds_; } bool ApplyToCall(grpc_call* call) override; diff --git a/src/cpp/common/alarm.cc b/src/cpp/common/alarm.cc index 5819a4210b..148f0b9bc9 100644 --- a/src/cpp/common/alarm.cc +++ b/src/cpp/common/alarm.cc @@ -31,10 +31,10 @@ #include <grpc/support/log.h> #include "src/core/lib/debug/trace.h" -namespace grpc { +namespace grpc_impl { namespace internal { -class AlarmImpl : public CompletionQueueTag { +class AlarmImpl : public ::grpc::internal::CompletionQueueTag { public: AlarmImpl() : cq_(nullptr), tag_(nullptr) { gpr_ref_init(&refs_, 1); @@ -51,7 +51,7 @@ class AlarmImpl : public CompletionQueueTag { Unref(); return true; } - void Set(CompletionQueue* cq, gpr_timespec deadline, void* tag) { + void Set(::grpc::CompletionQueue* cq, gpr_timespec deadline, void* tag) { grpc_core::ExecCtx exec_ctx; GRPC_CQ_INTERNAL_REF(cq->cq(), "alarm"); cq_ = cq->cq(); @@ -114,13 +114,14 @@ class AlarmImpl : public CompletionQueueTag { }; } // namespace internal -static internal::GrpcLibraryInitializer g_gli_initializer; +static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer; Alarm::Alarm() : alarm_(new internal::AlarmImpl()) { g_gli_initializer.summon(); } -void Alarm::SetInternal(CompletionQueue* cq, gpr_timespec deadline, void* tag) { +void Alarm::SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline, + void* tag) { // Note that we know that alarm_ is actually an internal::AlarmImpl // but we declared it as the base pointer to avoid a forward declaration // or exposing core data structures in the C++ public headers. @@ -145,4 +146,4 @@ Alarm::~Alarm() { } void Alarm::Cancel() { static_cast<internal::AlarmImpl*>(alarm_)->Cancel(); } -} // namespace grpc +} // namespace grpc_impl diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc index 50ee9d871f..214d72f853 100644 --- a/src/cpp/common/channel_arguments.cc +++ b/src/cpp/common/channel_arguments.cc @@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) { } if (!replaced) { + strings_.push_back(grpc::string(mutator_arg.key)); args_.push_back(mutator_arg); + args_.back().key = const_cast<char*>(strings_.back().c_str()); } } diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc index 1d66dd3d1f..7a2b5afed6 100644 --- a/src/cpp/common/secure_auth_context.cc +++ b/src/cpp/common/secure_auth_context.cc @@ -22,19 +22,12 @@ namespace grpc { -SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx, - bool take_ownership) - : ctx_(ctx), take_ownership_(take_ownership) {} - -SecureAuthContext::~SecureAuthContext() { - if (take_ownership_) grpc_auth_context_release(ctx_); -} - std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector<grpc::string_ref>(); } - grpc_auth_property_iterator iter = grpc_auth_context_peer_identity(ctx_); + grpc_auth_property_iterator iter = + grpc_auth_context_peer_identity(ctx_.get()); std::vector<grpc::string_ref> identity; const grpc_auth_property* property = nullptr; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -45,20 +38,20 @@ std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { } grpc::string SecureAuthContext::GetPeerIdentityPropertyName() const { - if (!ctx_) { + if (ctx_ == nullptr) { return ""; } - const char* name = grpc_auth_context_peer_identity_property_name(ctx_); + const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get()); return name == nullptr ? "" : name; } std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues( const grpc::string& name) const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector<grpc::string_ref>(); } grpc_auth_property_iterator iter = - grpc_auth_context_find_properties_by_name(ctx_, name.c_str()); + grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str()); const grpc_auth_property* property = nullptr; std::vector<grpc::string_ref> values; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -68,9 +61,9 @@ std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues( } AuthPropertyIterator SecureAuthContext::begin() const { - if (ctx_) { + if (ctx_ != nullptr) { grpc_auth_property_iterator iter = - grpc_auth_context_property_iterator(ctx_); + grpc_auth_context_property_iterator(ctx_.get()); const grpc_auth_property* property = grpc_auth_property_iterator_next(&iter); return AuthPropertyIterator(property, &iter); @@ -85,19 +78,20 @@ AuthPropertyIterator SecureAuthContext::end() const { void SecureAuthContext::AddProperty(const grpc::string& key, const grpc::string_ref& value) { - if (!ctx_) return; - grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size()); + if (ctx_ == nullptr) return; + grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(), + value.size()); } bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) { - if (!ctx_) return false; - return grpc_auth_context_set_peer_identity_property_name(ctx_, + if (ctx_ == nullptr) return false; + return grpc_auth_context_set_peer_identity_property_name(ctx_.get(), name.c_str()) != 0; } bool SecureAuthContext::IsPeerAuthenticated() const { - if (!ctx_) return false; - return grpc_auth_context_peer_is_authenticated(ctx_) != 0; + if (ctx_ == nullptr) return false; + return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0; } } // namespace grpc diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h index 142617959c..2e8f793721 100644 --- a/src/cpp/common/secure_auth_context.h +++ b/src/cpp/common/secure_auth_context.h @@ -21,15 +21,17 @@ #include <grpcpp/security/auth_context.h> -struct grpc_auth_context; +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" namespace grpc { class SecureAuthContext final : public AuthContext { public: - SecureAuthContext(grpc_auth_context* ctx, bool take_ownership); + explicit SecureAuthContext(grpc_auth_context* ctx) + : ctx_(ctx != nullptr ? ctx->Ref() : nullptr) {} - ~SecureAuthContext() override; + ~SecureAuthContext() override = default; bool IsPeerAuthenticated() const override; @@ -50,8 +52,7 @@ class SecureAuthContext final : public AuthContext { virtual bool SetPeerIdentityPropertyName(const grpc::string& name) override; private: - grpc_auth_context* ctx_; - bool take_ownership_; + grpc_core::RefCountedPtr<grpc_auth_context> ctx_; }; } // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc index bc1387c8d7..908c46629e 100644 --- a/src/cpp/common/secure_create_auth_context.cc +++ b/src/cpp/common/secure_create_auth_context.cc @@ -20,6 +20,7 @@ #include <grpc/grpc.h> #include <grpc/grpc_security.h> #include <grpcpp/security/auth_context.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/cpp/common/secure_auth_context.h" namespace grpc { @@ -28,8 +29,8 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) { if (call == nullptr) { return std::shared_ptr<const AuthContext>(); } - return std::shared_ptr<const AuthContext>( - new SecureAuthContext(grpc_call_auth_context(call), true)); + grpc_core::RefCountedPtr<grpc_auth_context> ctx(grpc_call_auth_context(call)); + return std::make_shared<SecureAuthContext>(ctx.get()); } } // namespace grpc diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc index 55da89e6c8..358131c7c4 100644 --- a/src/cpp/common/version_cc.cc +++ b/src/cpp/common/version_cc.cc @@ -22,5 +22,5 @@ #include <grpcpp/grpcpp.h> namespace grpc { -grpc::string Version() { return "1.18.0-dev"; } +grpc::string Version() { return "1.19.0-dev"; } } // namespace grpc diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc index 78fc69a805..160590353a 100644 --- a/src/cpp/ext/filters/census/context.cc +++ b/src/cpp/ext/filters/census/context.cc @@ -28,6 +28,9 @@ using ::opencensus::trace::SpanContext; void GenerateServerContext(absl::string_view tracing, absl::string_view stats, absl::string_view primary_role, absl::string_view method, CensusContext* context) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + context->~CensusContext(); GrpcTraceContext trace_ctxt; if (TraceContextEncoding::Decode(tracing, &trace_ctxt) != TraceContextEncoding::kEncodeDecodeFailure) { @@ -42,6 +45,9 @@ void GenerateServerContext(absl::string_view tracing, absl::string_view stats, void GenerateClientContext(absl::string_view method, CensusContext* ctxt, CensusContext* parent_ctxt) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + ctxt->~CensusContext(); if (parent_ctxt != nullptr) { SpanContext span_ctxt = parent_ctxt->Context(); Span span = parent_ctxt->Span(); diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index ebb17def32..453e76eb25 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -61,7 +61,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor( metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key), StringRefFromSlice(&md[i].value))); } - SecureAuthContext context(ctx, false); + SecureAuthContext context(ctx); AuthMetadataProcessor::OutputMetadata consumed_metadata; AuthMetadataProcessor::OutputMetadata response_metadata; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 1e3c57446f..13741ce7aa 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -278,7 +278,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { request_payload_ = nullptr; interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - interceptor_methods_.SetRecvMessage(request_); + interceptor_methods_.SetRecvMessage(request_, nullptr); } if (interceptor_methods_.RunInterceptors( @@ -446,7 +446,7 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { req_->request_payload_ = nullptr; req_->interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - req_->interceptor_methods_.SetRecvMessage(req_->request_); + req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr); } if (req_->interceptor_methods_.RunInterceptors( diff --git a/src/csharp/Grpc.Core/Version.csproj.include b/src/csharp/Grpc.Core/Version.csproj.include index 4fffe4f644..52ab2215eb 100755 --- a/src/csharp/Grpc.Core/Version.csproj.include +++ b/src/csharp/Grpc.Core/Version.csproj.include @@ -1,7 +1,7 @@ <!-- This file is generated --> <Project> <PropertyGroup> - <GrpcCsharpVersion>1.18.0-dev</GrpcCsharpVersion> + <GrpcCsharpVersion>1.19.0-dev</GrpcCsharpVersion> <GoogleProtobufVersion>3.6.1</GoogleProtobufVersion> </PropertyGroup> </Project> diff --git a/src/csharp/Grpc.Core/VersionInfo.cs b/src/csharp/Grpc.Core/VersionInfo.cs index 633880189c..8f3be310ee 100644 --- a/src/csharp/Grpc.Core/VersionInfo.cs +++ b/src/csharp/Grpc.Core/VersionInfo.cs @@ -33,11 +33,11 @@ namespace Grpc.Core /// <summary> /// Current <c>AssemblyFileVersion</c> of gRPC C# assemblies /// </summary> - public const string CurrentAssemblyFileVersion = "1.18.0.0"; + public const string CurrentAssemblyFileVersion = "1.19.0.0"; /// <summary> /// Current version of gRPC C# /// </summary> - public const string CurrentVersion = "1.18.0-dev"; + public const string CurrentVersion = "1.19.0-dev"; } } diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs index e83a8a7274..4750353082 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs @@ -45,7 +45,7 @@ namespace Grpc.IntegrationTesting [Option("server_host", Default = "localhost")] public string ServerHost { get; set; } - [Option("server_host_override", Default = TestCredentials.DefaultHostOverride)] + [Option("server_host_override")] public string ServerHostOverride { get; set; } [Option("server_port", Required = true)] diff --git a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets index 5f76c03ce5..3fe1ccc918 100644 --- a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets @@ -22,9 +22,8 @@ <Target Name="gRPC_ResolvePluginFullPath" AfterTargets="Protobuf_ResolvePlatform"> <PropertyGroup> <!-- TODO(kkm): Do not use Protobuf_PackagedToolsPath, roll gRPC's own. --> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/$(gRPC_PluginFileName)</gRPC_PluginFullPath> </PropertyGroup> diff --git a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets index 1d233d23a8..26f9efb5a8 100644 --- a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets @@ -74,9 +74,8 @@ <!-- Next try OS and CPU resolved by ProtoToolsPlatform. --> <Protobuf_ToolsOs Condition=" '$(Protobuf_ToolsOs)' == '' ">$(_Protobuf_ToolsOs)</Protobuf_ToolsOs> <Protobuf_ToolsCpu Condition=" '$(Protobuf_ToolsCpu)' == '' ">$(_Protobuf_ToolsCpu)</Protobuf_ToolsCpu> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\protoc.exe</Protobuf_ProtocFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\protoc.exe</Protobuf_ProtocFullPath> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/protoc</Protobuf_ProtocFullPath> </PropertyGroup> diff --git a/src/csharp/build_packages_dotnetcli.bat b/src/csharp/build_packages_dotnetcli.bat index 76d4f14390..fef1a43bb8 100755 --- a/src/csharp/build_packages_dotnetcli.bat +++ b/src/csharp/build_packages_dotnetcli.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.18.0-dev +set VERSION=1.19.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/csharp/build_unitypackage.bat b/src/csharp/build_unitypackage.bat index 3334d24c11..6b66b941a8 100644 --- a/src/csharp/build_unitypackage.bat +++ b/src/csharp/build_unitypackage.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.18.0-dev +set VERSION=1.19.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec index 55ca6048bc..659cfebbdc 100644 --- a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec +++ b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec @@ -42,7 +42,7 @@ Pod::Spec.new do |s| # exclamation mark ensures that other "regular" pods will be able to find it as it'll be installed # before them. s.name = '!ProtoCompiler-gRPCPlugin' - v = '1.18.0-dev' + v = '1.19.0-dev' s.version = v s.summary = 'The gRPC ProtoC plugin generates Objective-C files from .proto services.' s.description = <<-DESC diff --git a/src/objective-c/GRPCClient/private/version.h b/src/objective-c/GRPCClient/private/version.h index 0be0e3c9a0..5e089fde31 100644 --- a/src/objective-c/GRPCClient/private/version.h +++ b/src/objective-c/GRPCClient/private/version.h @@ -22,4 +22,4 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.19.0-dev" diff --git a/src/objective-c/README.md b/src/objective-c/README.md index 32e3956a1e..83775f86e1 100644 --- a/src/objective-c/README.md +++ b/src/objective-c/README.md @@ -242,3 +242,12 @@ pod `gRPC-Core`, :podspec => "." # assuming gRPC-Core.podspec is in the same dir These steps should allow gRPC to use OpenSSL and drop BoringSSL dependency. If you see any issue, file an issue to us. + +## Upgrade issue with BoringSSL +If you were using an old version of gRPC (<= v1.14) which depended on pod `BoringSSL` rather than +`BoringSSL-GRPC` and meet issue with the library like: +``` +ld: framework not found openssl +``` +updating `-framework openssl` in Other Linker Flags to `-framework openssl_grpc` in your project +may resolve this issue (see [#16821](https://github.com/grpc/grpc/issues/16821)). diff --git a/src/objective-c/tests/version.h b/src/objective-c/tests/version.h index f2fd692070..54f95ad16a 100644 --- a/src/objective-c/tests/version.h +++ b/src/objective-c/tests/version.h @@ -22,5 +22,5 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.19.0-dev" #define GRPC_C_VERSION_STRING @"7.0.0-dev" diff --git a/src/php/composer.json b/src/php/composer.json index 9c298c0e85..75fab483f1 100644 --- a/src/php/composer.json +++ b/src/php/composer.json @@ -2,7 +2,7 @@ "name": "grpc/grpc-dev", "description": "gRPC library for PHP - for Developement use only", "license": "Apache-2.0", - "version": "1.18.0", + "version": "1.19.0", "require": { "php": ">=5.5.0", "google/protobuf": "^v3.3.0" diff --git a/src/php/ext/grpc/version.h b/src/php/ext/grpc/version.h index 1ddf90a667..c85ee4d315 100644 --- a/src/php/ext/grpc/version.h +++ b/src/php/ext/grpc/version.h @@ -20,6 +20,6 @@ #ifndef VERSION_H #define VERSION_H -#define PHP_GRPC_VERSION "1.18.0dev" +#define PHP_GRPC_VERSION "1.19.0dev" #endif /* VERSION_H */ diff --git a/src/php/tests/interop/interop_client.php b/src/php/tests/interop/interop_client.php index c865678f70..19cbf21bc2 100755 --- a/src/php/tests/interop/interop_client.php +++ b/src/php/tests/interop/interop_client.php @@ -530,7 +530,7 @@ function _makeStub($args) throw new Exception('Missing argument: --test_case is required'); } - if ($args['server_port'] === 443) { + if ($args['server_port'] === '443') { $server_address = $args['server_host']; } else { $server_address = $args['server_host'].':'.$args['server_port']; @@ -538,7 +538,7 @@ function _makeStub($args) $test_case = $args['test_case']; - $host_override = 'foo.test.google.fr'; + $host_override = ''; if (array_key_exists('server_host_override', $args)) { $host_override = $args['server_host_override']; } @@ -565,7 +565,9 @@ function _makeStub($args) $ssl_credentials = Grpc\ChannelCredentials::createSsl(); } $opts['credentials'] = $ssl_credentials; - $opts['grpc.ssl_target_name_override'] = $host_override; + if (!empty($host_override)) { + $opts['grpc.ssl_target_name_override'] = $host_override; + } } else { $opts['credentials'] = Grpc\ChannelCredentials::createInsecure(); } diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 6022fc3ef2..70d7618e05 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -23,6 +23,11 @@ from grpc._cython import cygrpc as _cygrpc logging.getLogger(__name__).addHandler(logging.NullHandler()) +try: + from ._grpcio_metadata import __version__ +except ImportError: + __version__ = "dev0" + ############################## Future Interface ############################### @@ -266,6 +271,22 @@ class StatusCode(enum.Enum): UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated') +############################# gRPC Status ################################ + + +class Status(six.with_metaclass(abc.ABCMeta)): + """Describes the status of an RPC. + + This is an EXPERIMENTAL API. + + Attributes: + code: A StatusCode object to be sent to the client. + details: An ASCII-encodable string to be sent to the client upon + termination of the RPC. + trailing_metadata: The trailing :term:`metadata` in the RPC. + """ + + ############################# gRPC Exceptions ################################ @@ -1119,6 +1140,25 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): raise NotImplementedError() @abc.abstractmethod + def abort_with_status(self, status): + """Raises an exception to terminate the RPC with a non-OK status. + + The status passed as argument will supercede any existing status code, + status message and trailing metadata. + + This is an EXPERIMENTAL API. + + Args: + status: A grpc.Status object. The status code in it must not be + StatusCode.OK. + + Raises: + Exception: An exception is always raised to signal the abortion the + RPC to the gRPC runtime. + """ + raise NotImplementedError() + + @abc.abstractmethod def set_code(self, code): """Sets the value to be used as status code upon RPC completion. @@ -1747,6 +1787,7 @@ __all__ = ( 'Future', 'ChannelConnectivity', 'StatusCode', + 'Status', 'RpcError', 'RpcContext', 'Call', diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index c17824563d..9b990f490d 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -46,7 +46,7 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin): # Hack to determine if these are JWT creds and we need to pass # additional_claims when getting a token - self._is_jwt = 'additional_claims' in inspect.getargspec( + self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method credentials.get_access_token).args def __call__(self, context, callback): diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 35fa82d56b..8051fb306c 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -499,6 +499,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def _prepare(self, request, timeout, metadata, wait_for_ready): deadline, serialized_request, rendezvous = _start_unary_request( @@ -525,17 +526,18 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready) if state is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: call = self._channel.segregated_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, (( operations, None, - ),)) + ),), self._context) event = call.next_event() _handle_event(event, state, self._response_deserializer) - return state, call, + return state, call def __call__(self, request, @@ -566,13 +568,14 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready) if state is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - (operations,), event_handler) + (operations,), event_handler, self._context) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -587,6 +590,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def __call__(self, request, @@ -599,7 +603,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) if serialized_request is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) operationses = ( @@ -615,9 +619,10 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): ) event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - operationses, event_handler) + operationses, event_handler, self._context) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -632,6 +637,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def _blocking(self, request_iterator, timeout, metadata, credentials, wait_for_ready): @@ -640,10 +646,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) call = self._channel.segregated_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses_and_tags( - metadata, initial_metadata_flags)) + metadata, initial_metadata_flags), self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, None) while True: @@ -653,7 +660,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): state.condition.notify_all() if not state.due: break - return state, call, + return state, call def __call__(self, request_iterator, @@ -687,10 +694,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses( - metadata, initial_metadata_flags), event_handler) + metadata, initial_metadata_flags), event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -706,6 +714,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def __call__(self, request_iterator, @@ -727,9 +736,10 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): ) event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, operationses, - event_handler) + event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -745,10 +755,10 @@ class _InitialMetadataFlags(int): def with_wait_for_ready(self, wait_for_ready): if wait_for_ready is not None: if wait_for_ready: - self = self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ + return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) elif not wait_for_ready: - self = self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ + return self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) return self @@ -789,7 +799,7 @@ def _channel_managed_call_management(state): # pylint: disable=too-many-arguments def create(flags, method, host, deadline, metadata, credentials, - operationses, event_handler): + operationses, event_handler, context): """Creates a cygrpc.IntegratedCall. Args: @@ -804,7 +814,7 @@ def _channel_managed_call_management(state): started on the call. event_handler: A behavior to call to handle the events resultant from the operations on the call. - + context: Context object for distributed tracing. Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ @@ -815,7 +825,7 @@ def _channel_managed_call_management(state): with state.lock: call = state.channel.integrated_call(flags, method, host, deadline, metadata, credentials, - operationses_and_tags) + operationses_and_tags, context) if state.managed_calls == 0: state.managed_calls = 1 _run_channel_spin_thread(state) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi index e0e068e452..01b8237484 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi @@ -28,19 +28,22 @@ cdef tuple _wrap_grpc_arg(grpc_arg arg) cdef grpc_arg _unwrap_grpc_arg(tuple wrapped_arg) -cdef class _ArgumentProcessor: +cdef class _ChannelArg: cdef grpc_arg c_argument cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references) except * -cdef class _ArgumentsProcessor: +cdef class _ChannelArgs: cdef readonly tuple _arguments - cdef list _argument_processors + cdef list _channel_args cdef readonly list _references cdef grpc_channel_args _c_arguments - cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable) except * - cdef un_c(self) + cdef void _c(self, grpc_arg_pointer_vtable *vtable) except * + cdef grpc_channel_args *c_args(self) except * + + @staticmethod + cdef _ChannelArgs from_args(object arguments, grpc_arg_pointer_vtable * vtable) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi index b7a4277ff6..bf12871015 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi @@ -50,7 +50,7 @@ cdef grpc_arg _unwrap_grpc_arg(tuple wrapped_arg): return wrapped.arg -cdef class _ArgumentProcessor: +cdef class _ChannelArg: cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references) except *: key, value = argument @@ -82,27 +82,34 @@ cdef class _ArgumentProcessor: 'Expected int, bytes, or behavior, got {}'.format(type(value))) -cdef class _ArgumentsProcessor: +cdef class _ChannelArgs: def __cinit__(self, arguments): self._arguments = () if arguments is None else tuple(arguments) - self._argument_processors = [] + self._channel_args = [] self._references = [] + self._c_arguments.arguments = NULL - cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable) except *: + cdef void _c(self, grpc_arg_pointer_vtable *vtable) except *: self._c_arguments.arguments_length = len(self._arguments) - if self._c_arguments.arguments_length == 0: - return NULL - else: + if self._c_arguments.arguments_length != 0: self._c_arguments.arguments = <grpc_arg *>gpr_malloc( self._c_arguments.arguments_length * sizeof(grpc_arg)) for index, argument in enumerate(self._arguments): - argument_processor = _ArgumentProcessor() - argument_processor.c(argument, vtable, self._references) - self._c_arguments.arguments[index] = argument_processor.c_argument - self._argument_processors.append(argument_processor) - return &self._c_arguments - - cdef un_c(self): - if self._arguments: + channel_arg = _ChannelArg() + channel_arg.c(argument, vtable, self._references) + self._c_arguments.arguments[index] = channel_arg.c_argument + self._channel_args.append(channel_arg) + + cdef grpc_channel_args *c_args(self) except *: + return &self._c_arguments + + def __dealloc__(self): + if self._c_arguments.arguments != NULL: gpr_free(self._c_arguments.arguments) + + @staticmethod + cdef _ChannelArgs from_args(object arguments, grpc_arg_pointer_vtable * vtable): + cdef _ChannelArgs channel_args = _ChannelArgs(arguments) + channel_args._c(vtable) + return channel_args diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index 135d224095..70d4abb730 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -423,16 +423,15 @@ cdef class Channel: self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer self._vtable.cmp = &_compare_pointer - cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( - arguments) - cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) + cdef _ChannelArgs channel_args = _ChannelArgs.from_args( + arguments, &self._vtable) if channel_credentials is None: self._state.c_channel = grpc_insecure_channel_create( - <char *>target, c_arguments, NULL) + <char *>target, channel_args.c_args(), NULL) else: c_channel_credentials = channel_credentials.c() self._state.c_channel = grpc_secure_channel_create( - c_channel_credentials, <char *>target, c_arguments, NULL) + c_channel_credentials, <char *>target, channel_args.c_args(), NULL) grpc_channel_credentials_release(c_channel_credentials) self._state.c_call_completion_queue = ( grpc_completion_queue_create_for_next(NULL)) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 141116df5d..3c33b46dbb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -49,7 +49,7 @@ cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline): cdef _interpret_event(grpc_event c_event): cdef _Tag tag if c_event.type == GRPC_QUEUE_TIMEOUT: - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + # TODO(ericgribkoff) Do not coopt ConnectivityEvent here. return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) elif c_event.type == GRPC_QUEUE_SHUTDOWN: # NOTE(nathaniel): For now we coopt ConnectivityEvent here. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi index 52cfccb677..4a6fbe0f96 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi @@ -16,7 +16,6 @@ cdef class Server: cdef grpc_arg_pointer_vtable _vtable - cdef readonly _ArgumentsProcessor _arguments_processor cdef grpc_server *c_server cdef bint is_started # start has been called cdef bint is_shutting_down # shutdown has been called diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index ce701724fd..d72648a35d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -29,11 +29,9 @@ cdef class Server: self._vtable.copy = &_copy_pointer self._vtable.destroy = &_destroy_pointer self._vtable.cmp = &_compare_pointer - cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor( - arguments) - cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable) - self.c_server = grpc_server_create(c_arguments, NULL) - arguments_processor.un_c() + cdef _ChannelArgs channel_args = _ChannelArgs.from_args( + arguments, &self._vtable) + self.c_server = grpc_server_create(channel_args.c_args(), NULL) self.references.append(arguments) self.is_started = False self.is_shutting_down = False @@ -128,7 +126,10 @@ cdef class Server: with nogil: grpc_server_cancel_all_calls(self.c_server) - def __dealloc__(self): + # TODO(https://github.com/grpc/grpc/issues/17515) Determine what, if any, + # portion of this is safe to call from __dealloc__, and potentially remove + # backup_shutdown_queue. + def destroy(self): if self.c_server != NULL: if not self.is_started: pass @@ -146,4 +147,8 @@ cdef class Server: while not self.is_shutdown: time.sleep(0) grpc_server_destroy(self.c_server) - grpc_shutdown() + self.c_server = NULL + + def __dealloc(self): + if self.c_server == NULL: + grpc_shutdown() diff --git a/src/python/grpcio/grpc/_grpcio_metadata.py b/src/python/grpcio/grpc/_grpcio_metadata.py index 7a9f173947..dd9d436c3f 100644 --- a/src/python/grpcio/grpc/_grpcio_metadata.py +++ b/src/python/grpcio/grpc/_grpcio_metadata.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc/_grpcio_metadata.py.template`!!! -__version__ = """1.18.0.dev0""" +__version__ = """1.19.0.dev0""" diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index e939f615df..eb750ef1a8 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -48,7 +48,7 @@ _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 -_UNEXPECTED_EXIT_SERVER_GRACE = 1.0 +_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 def _serialized_request(request_event): @@ -291,6 +291,10 @@ class _Context(grpc.ServicerContext): self._state.abortion = Exception() raise self._state.abortion + def abort_with_status(self, status): + self._state.trailing_metadata = status.trailing_metadata + self.abort(status.code, status.details) + def set_code(self, code): with self._state.condition: self._state.code = code @@ -672,6 +676,9 @@ class _ServerState(object): self.rpc_states = set() self.due = set() + # A "volatile" flag to interrupt the daemon serving thread + self.server_deallocated = False + def _add_generic_handlers(state, generic_handlers): with state.lock: @@ -698,6 +705,7 @@ def _request_call(state): # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. def _stop_serving(state): if not state.rpc_states and not state.due: + state.server.destroy() for shutdown_event in state.shutdown_events: shutdown_event.set() state.stage = _ServerStage.STOPPED @@ -711,49 +719,69 @@ def _on_call_completed(state): state.active_rpc_count -= 1 -def _serve(state): - while True: - event = state.completion_queue.poll() - if event.tag is _SHUTDOWN_TAG: +def _process_event_and_continue(state, event): + should_continue = True + if event.tag is _SHUTDOWN_TAG: + with state.lock: + state.due.remove(_SHUTDOWN_TAG) + if _stop_serving(state): + should_continue = False + elif event.tag is _REQUEST_CALL_TAG: + with state.lock: + state.due.remove(_REQUEST_CALL_TAG) + concurrency_exceeded = ( + state.maximum_concurrent_rpcs is not None and + state.active_rpc_count >= state.maximum_concurrent_rpcs) + rpc_state, rpc_future = _handle_call( + event, state.generic_handlers, state.interceptor_pipeline, + state.thread_pool, concurrency_exceeded) + if rpc_state is not None: + state.rpc_states.add(rpc_state) + if rpc_future is not None: + state.active_rpc_count += 1 + rpc_future.add_done_callback( + lambda unused_future: _on_call_completed(state)) + if state.stage is _ServerStage.STARTED: + _request_call(state) + elif _stop_serving(state): + should_continue = False + else: + rpc_state, callbacks = event.tag(event) + for callback in callbacks: + callable_util.call_logging_exceptions(callback, + 'Exception calling callback!') + if rpc_state is not None: with state.lock: - state.due.remove(_SHUTDOWN_TAG) + state.rpc_states.remove(rpc_state) if _stop_serving(state): - return - elif event.tag is _REQUEST_CALL_TAG: - with state.lock: - state.due.remove(_REQUEST_CALL_TAG) - concurrency_exceeded = ( - state.maximum_concurrent_rpcs is not None and - state.active_rpc_count >= state.maximum_concurrent_rpcs) - rpc_state, rpc_future = _handle_call( - event, state.generic_handlers, state.interceptor_pipeline, - state.thread_pool, concurrency_exceeded) - if rpc_state is not None: - state.rpc_states.add(rpc_state) - if rpc_future is not None: - state.active_rpc_count += 1 - rpc_future.add_done_callback( - lambda unused_future: _on_call_completed(state)) - if state.stage is _ServerStage.STARTED: - _request_call(state) - elif _stop_serving(state): - return - else: - rpc_state, callbacks = event.tag(event) - for callback in callbacks: - callable_util.call_logging_exceptions( - callback, 'Exception calling callback!') - if rpc_state is not None: - with state.lock: - state.rpc_states.remove(rpc_state) - if _stop_serving(state): - return + should_continue = False + return should_continue + + +def _serve(state): + while True: + timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S + event = state.completion_queue.poll(timeout) + if state.server_deallocated: + _begin_shutdown_once(state) + if event.completion_type != cygrpc.CompletionType.queue_timeout: + if not _process_event_and_continue(state, event): + return # We want to force the deletion of the previous event # ~before~ we poll again; if the event has a reference # to a shutdown Call object, this can induce spinlock. event = None +def _begin_shutdown_once(state): + with state.lock: + if state.stage is _ServerStage.STARTED: + state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) + state.stage = _ServerStage.GRACE + state.shutdown_events = [] + state.due.add(_SHUTDOWN_TAG) + + def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: @@ -761,11 +789,7 @@ def _stop(state, grace): shutdown_event.set() return shutdown_event else: - if state.stage is _ServerStage.STARTED: - state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) - state.stage = _ServerStage.GRACE - state.shutdown_events = [] - state.due.add(_SHUTDOWN_TAG) + _begin_shutdown_once(state) shutdown_event = threading.Event() state.shutdown_events.append(shutdown_event) if grace is None: @@ -836,7 +860,9 @@ class _Server(grpc.Server): return _stop(self._state, grace) def __del__(self): - _stop(self._state, None) + # We can not grab a lock in __del__(), so set a flag to signal the + # serving daemon thread (if it exists) to initiate shutdown. + self._state.server_deallocated = True def create_server(thread_pool, generic_rpc_handlers, interceptors, options, diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index d90b34bcbd..2938a38b44 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -132,15 +132,12 @@ class _ChannelReadyFuture(grpc.Future): def result(self, timeout=None): self._block(timeout) - return None def exception(self, timeout=None): self._block(timeout) - return None def traceback(self, timeout=None): self._block(timeout) - return None def add_done_callback(self, fn): with self._condition: diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index c6ca970bee..6a1fd676ca 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -326,6 +326,7 @@ CORE_SOURCE_FILES = [ 'src/core/ext/filters/client_channel/parse_address.cc', 'src/core/ext/filters/client_channel/proxy_mapper.cc', 'src/core/ext/filters/client_channel/proxy_mapper_registry.cc', + 'src/core/ext/filters/client_channel/request_routing.cc', 'src/core/ext/filters/client_channel/resolver.cc', 'src/core/ext/filters/client_channel/resolver_registry.cc', 'src/core/ext/filters/client_channel/resolver_result_parsing.cc', diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py index 2e91818d2c..8e2f4d30bb 100644 --- a/src/python/grpcio/grpc_version.py +++ b/src/python/grpcio/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_channelz/grpc_version.py b/src/python/grpcio_channelz/grpc_version.py index 16356ea402..5f3a894a2a 100644 --- a/src/python/grpcio_channelz/grpc_version.py +++ b/src/python/grpcio_channelz/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_channelz/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py index 0583659428..0a5bbb5504 100644 --- a/src/python/grpcio_health_checking/grpc_health/v1/health.py +++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py @@ -23,15 +23,61 @@ from grpc_health.v1 import health_pb2_grpc as _health_pb2_grpc SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name['Health'].full_name +class _Watcher(): + + def __init__(self): + self._condition = threading.Condition() + self._responses = list() + self._open = True + + def __iter__(self): + return self + + def _next(self): + with self._condition: + while not self._responses and self._open: + self._condition.wait() + if self._responses: + return self._responses.pop(0) + else: + raise StopIteration() + + def next(self): + return self._next() + + def __next__(self): + return self._next() + + def add(self, response): + with self._condition: + self._responses.append(response) + self._condition.notify() + + def close(self): + with self._condition: + self._open = False + self._condition.notify() + + class HealthServicer(_health_pb2_grpc.HealthServicer): """Servicer handling RPCs for service statuses.""" def __init__(self): - self._server_status_lock = threading.Lock() + self._lock = threading.RLock() self._server_status = {} + self._watchers = {} + + def _on_close_callback(self, watcher, service): + + def callback(): + with self._lock: + self._watchers[service].remove(watcher) + watcher.close() + + return callback def Check(self, request, context): - with self._server_status_lock: + with self._lock: status = self._server_status.get(request.service) if status is None: context.set_code(grpc.StatusCode.NOT_FOUND) @@ -39,14 +85,30 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): else: return _health_pb2.HealthCheckResponse(status=status) + def Watch(self, request, context): + service = request.service + with self._lock: + status = self._server_status.get(service) + if status is None: + status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member + watcher = _Watcher() + watcher.add(_health_pb2.HealthCheckResponse(status=status)) + if service not in self._watchers: + self._watchers[service] = set() + self._watchers[service].add(watcher) + context.add_callback(self._on_close_callback(watcher, service)) + return watcher + def set(self, service, status): """Sets the status of a service. - Args: - service: string, the name of the service. - NOTE, '' must be set. - status: HealthCheckResponse.status enum value indicating - the status of the service - """ - with self._server_status_lock: + Args: + service: string, the name of the service. NOTE, '' must be set. + status: HealthCheckResponse.status enum value indicating the status of + the service + """ + with self._lock: self._server_status[service] = status + if service in self._watchers: + for watcher in self._watchers[service]: + watcher.add(_health_pb2.HealthCheckResponse(status=status)) diff --git a/src/python/grpcio_health_checking/grpc_version.py b/src/python/grpcio_health_checking/grpc_version.py index 85fa762f7e..4c2d434066 100644 --- a/src/python/grpcio_health_checking/grpc_version.py +++ b/src/python/grpcio_health_checking/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_health_checking/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_reflection/grpc_version.py b/src/python/grpcio_reflection/grpc_version.py index e62ab169a2..6b88b2dfc5 100644 --- a/src/python/grpcio_reflection/grpc_version.py +++ b/src/python/grpcio_reflection/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_reflection/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_status/.gitignore b/src/python/grpcio_status/.gitignore new file mode 100644 index 0000000000..19d1523efd --- /dev/null +++ b/src/python/grpcio_status/.gitignore @@ -0,0 +1,3 @@ +build/ +grpcio_status.egg-info/ +dist/ diff --git a/src/python/grpcio_status/MANIFEST.in b/src/python/grpcio_status/MANIFEST.in new file mode 100644 index 0000000000..09b8ea721e --- /dev/null +++ b/src/python/grpcio_status/MANIFEST.in @@ -0,0 +1,4 @@ +include grpc_version.py +recursive-include grpc_status *.py +global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_status/README.rst b/src/python/grpcio_status/README.rst new file mode 100644 index 0000000000..dc2f7b1dab --- /dev/null +++ b/src/python/grpcio_status/README.rst @@ -0,0 +1,9 @@ +gRPC Python Status Proto +=========================== + +Reference package for GRPC Python status proto mapping. + +Dependencies +------------ + +Depends on the `grpcio` package, available from PyPI via `pip install grpcio`. diff --git a/src/python/grpcio_status/grpc_status/BUILD.bazel b/src/python/grpcio_status/grpc_status/BUILD.bazel new file mode 100644 index 0000000000..223a077c3f --- /dev/null +++ b/src/python/grpcio_status/grpc_status/BUILD.bazel @@ -0,0 +1,14 @@ +load("@grpc_python_dependencies//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "grpc_status", + srcs = ["rpc_status.py",], + deps = [ + "//src/python/grpcio/grpc:grpcio", + requirement('protobuf'), + requirement('googleapis-common-protos'), + ], + imports=["../",], +) diff --git a/src/python/grpcio_status/grpc_status/__init__.py b/src/python/grpcio_status/grpc_status/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_status/grpc_status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_status/grpc_status/rpc_status.py b/src/python/grpcio_status/grpc_status/rpc_status.py new file mode 100644 index 0000000000..87618fa541 --- /dev/null +++ b/src/python/grpcio_status/grpc_status/rpc_status.py @@ -0,0 +1,92 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Reference implementation for status mapping in gRPC Python.""" + +import collections + +import grpc + +# TODO(https://github.com/bazelbuild/bazel/issues/6844) +# Due to Bazel issue, the namespace packages won't resolve correctly. +# Adding this unused-import as a workaround to avoid module-not-found error +# under Bazel builds. +import google.protobuf # pylint: disable=unused-import +from google.rpc import status_pb2 + +_CODE_TO_GRPC_CODE_MAPPING = {x.value[0]: x for x in grpc.StatusCode} + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + + +class _Status( + collections.namedtuple( + '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status): + pass + + +def _code_to_grpc_status_code(code): + try: + return _CODE_TO_GRPC_CODE_MAPPING[code] + except KeyError: + raise ValueError('Invalid status code %s' % code) + + +def from_call(call): + """Returns a google.rpc.status.Status message corresponding to a given grpc.Call. + + This is an EXPERIMENTAL API. + + Args: + call: A grpc.Call instance. + + Returns: + A google.rpc.status.Status message representing the status of the RPC. + + Raises: + ValueError: If the gRPC call's code or details are inconsistent with the + status code and message inside of the google.rpc.status.Status. + """ + for key, value in call.trailing_metadata(): + if key == _GRPC_DETAILS_METADATA_KEY: + rich_status = status_pb2.Status.FromString(value) + if call.code().value[0] != rich_status.code: + raise ValueError( + 'Code in Status proto (%s) doesn\'t match status code (%s)' + % (_code_to_grpc_status_code(rich_status.code), + call.code())) + if call.details() != rich_status.message: + raise ValueError( + 'Message in Status proto (%s) doesn\'t match status details (%s)' + % (rich_status.message, call.details())) + return rich_status + return None + + +def to_status(status): + """Convert a google.rpc.status.Status message to grpc.Status. + + This is an EXPERIMENTAL API. + + Args: + status: a google.rpc.status.Status message representing the non-OK status + to terminate the RPC with and communicate it to the client. + + Returns: + A grpc.Status instance representing the input google.rpc.status.Status message. + """ + return _Status( + code=_code_to_grpc_status_code(status.code), + details=status.message, + trailing_metadata=((_GRPC_DETAILS_METADATA_KEY, + status.SerializeToString()),)) diff --git a/src/python/grpcio_status/grpc_version.py b/src/python/grpcio_status/grpc_version.py new file mode 100644 index 0000000000..2e58eb3b26 --- /dev/null +++ b/src/python/grpcio_status/grpc_version.py @@ -0,0 +1,17 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_status/grpc_version.py.template`!!! + +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_status/setup.py b/src/python/grpcio_status/setup.py new file mode 100644 index 0000000000..983d3ea430 --- /dev/null +++ b/src/python/grpcio_status/setup.py @@ -0,0 +1,93 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup module for the GRPC Python package's status mapping.""" + +import os + +import setuptools + +# Ensure we're in the proper directory whether or not we're being used by pip. +os.chdir(os.path.dirname(os.path.abspath(__file__))) + +# Break import-style to ensure we can actually find our local modules. +import grpc_version + + +class _NoOpCommand(setuptools.Command): + """No-op command.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + pass + + +CLASSIFIERS = [ + 'Development Status :: 5 - Production/Stable', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'License :: OSI Approved :: Apache Software License', +] + +PACKAGE_DIRECTORIES = { + '': '.', +} + +INSTALL_REQUIRES = ( + 'protobuf>=3.6.0', + 'grpcio>={version}'.format(version=grpc_version.VERSION), + 'googleapis-common-protos>=1.5.5', +) + +try: + import status_commands as _status_commands + # we are in the build environment, otherwise the above import fails + COMMAND_CLASS = { + # Run preprocess from the repository *before* doing any packaging! + 'preprocess': _status_commands.Preprocess, + 'build_package_protos': _NoOpCommand, + } +except ImportError: + COMMAND_CLASS = { + # wire up commands to no-op not to break the external dependencies + 'preprocess': _NoOpCommand, + 'build_package_protos': _NoOpCommand, + } + +setuptools.setup( + name='grpcio-status', + version=grpc_version.VERSION, + description='Status proto mapping for gRPC', + author='The gRPC Authors', + author_email='grpc-io@googlegroups.com', + url='https://grpc.io', + license='Apache License 2.0', + classifiers=CLASSIFIERS, + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages('.'), + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS) diff --git a/src/python/grpcio_status/status_commands.py b/src/python/grpcio_status/status_commands.py new file mode 100644 index 0000000000..78cd497f62 --- /dev/null +++ b/src/python/grpcio_status/status_commands.py @@ -0,0 +1,39 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides distutils command classes for the GRPC Python setup process.""" + +import os +import shutil + +import setuptools + +ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + + +class Preprocess(setuptools.Command): + """Command to copy LICENSE from root directory.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_handler.py b/src/python/grpcio_testing/grpc_testing/_server/_handler.py index 0e3404b0d0..100d8195f6 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_handler.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_handler.py @@ -185,7 +185,7 @@ class _Handler(Handler): elif self._code is None: self._condition.wait() else: - return self._trailing_metadata, self._code, self._details, + return self._trailing_metadata, self._code, self._details def expire(self): with self._condition: diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py index 90eeb130d3..5b1dfeacdf 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py @@ -70,6 +70,9 @@ class ServicerContext(grpc.ServicerContext): def abort(self, code, details): raise NotImplementedError() + def abort_with_status(self, status): + raise NotImplementedError() + def set_code(self, code): self._rpc.set_code(code) diff --git a/src/python/grpcio_testing/grpc_version.py b/src/python/grpcio_testing/grpc_version.py index 7b4c1695fa..d4c5d94ecb 100644 --- a/src/python/grpcio_testing/grpc_version.py +++ b/src/python/grpcio_testing/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_testing/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index 65e9a99950..582ce898de 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -22,7 +22,6 @@ import re import shutil import subprocess import sys -import traceback import setuptools from setuptools.command import build_ext @@ -133,6 +132,16 @@ class TestGevent(setuptools.Command): # TODO(https://github.com/grpc/grpc/issues/15411) unpin gevent version # This test will stuck while running higher version of gevent 'unit._auth_context_test.AuthContextTest.testSessionResumption', + # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests + 'unit._metadata_flags_test', + 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call', + 'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list', # TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels', 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets', diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py index 2fcd1ad617..e1645ab1b8 100644 --- a/src/python/grpcio_tests/grpc_version.py +++ b/src/python/grpcio_tests/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! -VERSION = '1.18.0.dev0' +VERSION = '1.19.0.dev0' diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py index f56425ac6d..800b865da6 100644 --- a/src/python/grpcio_tests/setup.py +++ b/src/python/grpcio_tests/setup.py @@ -37,12 +37,19 @@ PACKAGE_DIRECTORIES = { } INSTALL_REQUIRES = ( - 'coverage>=4.0', 'enum34>=1.0.4', + 'coverage>=4.0', + 'enum34>=1.0.4', 'grpcio>={version}'.format(version=grpc_version.VERSION), - 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), + # TODO(https://github.com/pypa/warehouse/issues/5196) + # Re-enable it once we got the name back + # 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), + 'grpcio-status>={version}'.format(version=grpc_version.VERSION), 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), 'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION), - 'oauth2client>=1.4.7', 'protobuf>=3.6.0', 'six>=1.10', 'google-auth>=1.0.0', + 'oauth2client>=1.4.7', + 'protobuf>=3.6.0', + 'six>=1.10', + 'google-auth>=1.0.0', 'requests>=2.14.2') if not PY3: diff --git a/src/python/grpcio_tests/tests/_runner.py b/src/python/grpcio_tests/tests/_runner.py index eaaa027e61..9ef0f17684 100644 --- a/src/python/grpcio_tests/tests/_runner.py +++ b/src/python/grpcio_tests/tests/_runner.py @@ -203,7 +203,7 @@ class Runner(object): check_kill_self() time.sleep(0) case_thread.join() - except: + except: # pylint: disable=try-except-raise # re-raise the exception after forcing the with-block to end raise result.set_output(augmented_case.case, stdout_pipe.output(), diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py index 8ca5189522..c63ff5cd84 100644 --- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -88,11 +88,10 @@ def _generate_channel_server_pairs(n): def _close_channel_server_pairs(pairs): for pair in pairs: pair.server.stop(None) - # TODO(ericgribkoff) This del should not be required - del pair.server pair.channel.close() +@unittest.skip('https://github.com/pypa/warehouse/issues/5196') class ChannelzServicerTest(unittest.TestCase): def _send_successful_unary_unary(self, idx): diff --git a/src/python/grpcio_tests/tests/health_check/BUILD.bazel b/src/python/grpcio_tests/tests/health_check/BUILD.bazel index 19e1e1b2e1..77bc61aa30 100644 --- a/src/python/grpcio_tests/tests/health_check/BUILD.bazel +++ b/src/python/grpcio_tests/tests/health_check/BUILD.bazel @@ -9,6 +9,7 @@ py_test( "//src/python/grpcio/grpc:grpcio", "//src/python/grpcio_health_checking/grpc_health/v1:grpc_health", "//src/python/grpcio_tests/tests/unit:test_common", + "//src/python/grpcio_tests/tests/unit/framework/common:common", ], imports = ["../../",], ) diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 350b5eebe5..35794987bc 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests of grpc_health.v1.health.""" +import threading +import time import unittest import grpc @@ -21,58 +23,199 @@ from grpc_health.v1 import health_pb2 from grpc_health.v1 import health_pb2_grpc from tests.unit import test_common +from tests.unit.framework.common import test_constants + +from six.moves import queue + +_SERVING_SERVICE = 'grpc.test.TestServiceServing' +_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown' +_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' +_WATCH_SERVICE = 'grpc.test.WatchService' + + +def _consume_responses(response_iterator, response_queue): + for response in response_iterator: + response_queue.put(response) class HealthServicerTest(unittest.TestCase): def setUp(self): - servicer = health.HealthServicer() - servicer.set('', health_pb2.HealthCheckResponse.SERVING) - servicer.set('grpc.test.TestServiceServing', - health_pb2.HealthCheckResponse.SERVING) - servicer.set('grpc.test.TestServiceUnknown', - health_pb2.HealthCheckResponse.UNKNOWN) - servicer.set('grpc.test.TestServiceNotServing', - health_pb2.HealthCheckResponse.NOT_SERVING) + self._servicer = health.HealthServicer() + self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) + self._servicer.set(_SERVING_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + self._servicer.set(_UNKNOWN_SERVICE, + health_pb2.HealthCheckResponse.UNKNOWN) + self._servicer.set(_NOT_SERVING_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) self._server = test_common.test_server() port = self._server.add_insecure_port('[::]:0') - health_pb2_grpc.add_HealthServicer_to_server(servicer, self._server) + health_pb2_grpc.add_HealthServicer_to_server(self._servicer, + self._server) self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = health_pb2_grpc.HealthStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) - def test_empty_service(self): + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_check_empty_service(self): request = health_pb2.HealthCheckRequest() resp = self._stub.Check(request) self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) - def test_serving_service(self): - request = health_pb2.HealthCheckRequest( - service='grpc.test.TestServiceServing') + def test_check_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE) resp = self._stub.Check(request) self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status) - def test_unknown_serivce(self): - request = health_pb2.HealthCheckRequest( - service='grpc.test.TestServiceUnknown') + def test_check_unknown_serivce(self): + request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE) resp = self._stub.Check(request) self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status) - def test_not_serving_service(self): - request = health_pb2.HealthCheckRequest( - service='grpc.test.TestServiceNotServing') + def test_check_not_serving_service(self): + request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) resp = self._stub.Check(request) self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, resp.status) - def test_not_found_service(self): + def test_check_not_found_service(self): request = health_pb2.HealthCheckRequest(service='not-found') with self.assertRaises(grpc.RpcError) as context: resp = self._stub.Check(request) self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) + def test_watch_empty_service(self): + request = health_pb2.HealthCheckRequest(service='') + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_watch_new_service(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.NOT_SERVING) + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, + response.status) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_watch_service_isolation(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + self._servicer.set('some-other-service', + health_pb2.HealthCheckResponse.SERVING) + with self.assertRaises(queue.Empty): + response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + + rendezvous.cancel() + thread.join() + self.assertTrue(response_queue.empty()) + + def test_two_watchers(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue1 = queue.Queue() + response_queue2 = queue.Queue() + rendezvous1 = self._stub.Watch(request) + rendezvous2 = self._stub.Watch(request) + thread1 = threading.Thread( + target=_consume_responses, args=(rendezvous1, response_queue1)) + thread2 = threading.Thread( + target=_consume_responses, args=(rendezvous2, response_queue2)) + thread1.start() + thread2.start() + + response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response2.status) + + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT) + response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response1.status) + self.assertEqual(health_pb2.HealthCheckResponse.SERVING, + response2.status) + + rendezvous1.cancel() + rendezvous2.cancel() + thread1.join() + thread2.join() + self.assertTrue(response_queue1.empty()) + self.assertTrue(response_queue2.empty()) + + def test_cancelled_watch_removed_from_watch_list(self): + request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) + response_queue = queue.Queue() + rendezvous = self._stub.Watch(request) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue)) + thread.start() + + response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) + self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + response.status) + + rendezvous.cancel() + self._servicer.set(_WATCH_SERVICE, + health_pb2.HealthCheckResponse.SERVING) + thread.join() + + # Wait, if necessary, for serving thread to process client cancellation + timeout = time.time() + test_constants.SHORT_TIMEOUT + while time.time() < timeout and self._servicer._watchers[_WATCH_SERVICE]: + time.sleep(1) + self.assertFalse(self._servicer._watchers[_WATCH_SERVICE], + 'watch set should be empty') + self.assertTrue(response_queue.empty()) + def test_health_service_name(self): self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py index 698c37017f..56cb29477c 100644 --- a/src/python/grpcio_tests/tests/interop/client.py +++ b/src/python/grpcio_tests/tests/interop/client.py @@ -54,7 +54,6 @@ def _args(): help='replace platform root CAs with ca.pem') parser.add_argument( '--server_host_override', - default="foo.test.google.fr", type=str, help='the server host to which to claim to connect') parser.add_argument( @@ -100,10 +99,13 @@ def _stub(args): channel_credentials = grpc.composite_channel_credentials( channel_credentials, call_credentials) - channel = grpc.secure_channel(target, channel_credentials, (( - 'grpc.ssl_target_name_override', - args.server_host_override, - ),)) + channel_opts = None + if args.server_host_override: + channel_opts = (( + 'grpc.ssl_target_name_override', + args.server_host_override, + ),) + channel = grpc.secure_channel(target, channel_credentials, channel_opts) else: channel = grpc.insecure_channel(target) if args.test_case == "unimplemented_service": diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py index e21ea0010a..2b735526cb 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py @@ -144,7 +144,7 @@ class _ProtoBeforeGrpcProtocStyle(object): absolute_proto_file_names) pb2_grpc_protoc_exit_code = _protoc( proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names) - return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code, + return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code class _GrpcBeforeProtoProtocStyle(object): @@ -160,7 +160,7 @@ class _GrpcBeforeProtoProtocStyle(object): proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names) pb2_protoc_exit_code = _protoc(proto_path, python_out, None, None, absolute_proto_file_names) - return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code, + return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code _PROTOC_STYLES = ( @@ -243,9 +243,9 @@ class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)): def _services_modules(self): if self.PROTOC_STYLE.grpc_in_pb2_expected(): - return self._services_pb2, self._services_pb2_grpc, + return self._services_pb2, self._services_pb2_grpc else: - return self._services_pb2_grpc, + return (self._services_pb2_grpc,) def test_imported_attributes(self): self._protoc() diff --git a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py index b46e53315e..43c90af6a7 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py @@ -223,7 +223,7 @@ def _CreateService(payload_pb2, responses_pb2, service_pb2): server.start() channel = implementations.insecure_channel('localhost', port) stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel) - yield servicer_methods, stub, + yield servicer_methods, stub server.stop(0) diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index 0488450740..fac0e44e5a 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -180,7 +180,7 @@ class StreamingSyncBenchmarkClient(BenchmarkClient): self._streams = [ _SyncStream(self._stub, self._generic, self._request, self._handle_response) - for _ in xrange(config.outstanding_rpcs_per_channel) + for _ in range(config.outstanding_rpcs_per_channel) ] self._curr_stream = 0 diff --git a/src/python/grpcio_tests/tests/qps/client_runner.py b/src/python/grpcio_tests/tests/qps/client_runner.py index e79abab3c7..a57524c74e 100644 --- a/src/python/grpcio_tests/tests/qps/client_runner.py +++ b/src/python/grpcio_tests/tests/qps/client_runner.py @@ -77,7 +77,7 @@ class ClosedLoopClientRunner(ClientRunner): def start(self): self._is_running = True self._client.start() - for _ in xrange(self._request_count): + for _ in range(self._request_count): self._client.send_request() def stop(self): diff --git a/src/python/grpcio_tests/tests/qps/worker_server.py b/src/python/grpcio_tests/tests/qps/worker_server.py index 337a94b546..a03367ec63 100644 --- a/src/python/grpcio_tests/tests/qps/worker_server.py +++ b/src/python/grpcio_tests/tests/qps/worker_server.py @@ -109,7 +109,7 @@ class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer): start_time = time.time() # Create a client for each channel - for i in xrange(config.client_channels): + for i in range(config.client_channels): server = config.server_targets[i % len(config.server_targets)] runner = self._create_client_runner(server, config, qps_data) client_runners.append(runner) diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py index bcd9e14a38..560f6d3ddb 100644 --- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -56,8 +56,12 @@ class ReflectionServicerTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = reflection_pb2_grpc.ServerReflectionStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() def testFileByName(self): requests = ( diff --git a/src/python/grpcio_tests/tests/status/BUILD.bazel b/src/python/grpcio_tests/tests/status/BUILD.bazel new file mode 100644 index 0000000000..937e50498e --- /dev/null +++ b/src/python/grpcio_tests/tests/status/BUILD.bazel @@ -0,0 +1,19 @@ +load("@grpc_python_dependencies//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "grpc_status_test", + srcs = ["_grpc_status_test.py"], + main = "_grpc_status_test.py", + size = "small", + deps = [ + "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_status/grpc_status:grpc_status", + "//src/python/grpcio_tests/tests/unit:test_common", + "//src/python/grpcio_tests/tests/unit/framework/common:common", + requirement('protobuf'), + requirement('googleapis-common-protos'), + ], + imports = ["../../",], +) diff --git a/src/python/grpcio_tests/tests/status/__init__.py b/src/python/grpcio_tests/tests/status/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_tests/tests/status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py new file mode 100644 index 0000000000..519c372a96 --- /dev/null +++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -0,0 +1,173 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_status.""" + +import unittest + +import logging +import traceback + +import grpc +from grpc_status import rpc_status + +from tests.unit import test_common + +from google.protobuf import any_pb2 +from google.rpc import code_pb2, status_pb2, error_details_pb2 + +_STATUS_OK = '/test/StatusOK' +_STATUS_NOT_OK = '/test/StatusNotOk' +_ERROR_DETAILS = '/test/ErrorDetails' +_INCONSISTENT = '/test/Inconsistent' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + +_STATUS_DETAILS = 'This is an error detail' +_STATUS_DETAILS_ANOTHER = 'This is another error detail' + + +def _ok_unary_unary(request, servicer_context): + return _RESPONSE + + +def _not_ok_unary_unary(request, servicer_context): + servicer_context.abort(grpc.StatusCode.INTERNAL, _STATUS_DETAILS) + + +def _error_details_unary_unary(request, servicer_context): + details = any_pb2.Any() + details.Pack( + error_details_pb2.DebugInfo( + stack_entries=traceback.format_stack(), + detail='Intentionally invoked')) + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + details=[details], + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +def _inconsistent_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + ) + servicer_context.set_code(grpc.StatusCode.NOT_FOUND) + servicer_context.set_details(_STATUS_DETAILS_ANOTHER) + # User put inconsistent status information in trailing metadata + servicer_context.set_trailing_metadata(((_GRPC_DETAILS_METADATA_KEY, + rich_status.SerializeToString()),)) + + +def _invalid_code_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=42, + message='Invalid code', + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _STATUS_OK: + return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) + elif handler_call_details.method == _STATUS_NOT_OK: + return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) + elif handler_call_details.method == _ERROR_DETAILS: + return grpc.unary_unary_rpc_method_handler( + _error_details_unary_unary) + elif handler_call_details.method == _INCONSISTENT: + return grpc.unary_unary_rpc_method_handler( + _inconsistent_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.unary_unary_rpc_method_handler( + _invalid_code_unary_unary) + else: + return None + + +class StatusTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_status_ok(self): + _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) + + # Succeed RPC doesn't have status + status = rpc_status.from_call(call) + self.assertIs(status, None) + + def test_status_not_ok(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + # Failed RPC doesn't automatically generate status + status = rpc_status.from_call(rpc_error) + self.assertIs(status, None) + + def test_error_details(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) + rpc_error = exception_context.exception + + status = rpc_status.from_call(rpc_error) + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + + # Check if the underlying proto message is intact + self.assertEqual(status.details[0].Is( + error_details_pb2.DebugInfo.DESCRIPTOR), True) + info = error_details_pb2.DebugInfo() + status.details[0].Unpack(info) + self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + + def test_code_message_validation(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) + + # Code/Message validation failed + self.assertRaises(ValueError, rpc_status.from_call, rpc_error) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + # Invalid status code exception raised during coversion + self.assertIn('Invalid status code', rpc_error.details()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py index 41f2e1b6c2..4c35b05044 100644 --- a/src/python/grpcio_tests/tests/stress/client.py +++ b/src/python/grpcio_tests/tests/stress/client.py @@ -71,7 +71,6 @@ def _args(): '--use_tls', help='Whether to use TLS', default=False, type=bool) parser.add_argument( '--server_host_override', - default="foo.test.google.fr", help='the server host to which to claim to connect', type=str) return parser.parse_args() @@ -132,9 +131,9 @@ def run_test(args): server.start() for test_server_target in test_server_targets: - for _ in xrange(args.num_channels_per_server): + for _ in range(args.num_channels_per_server): channel = _get_channel(test_server_target, args) - for _ in xrange(args.num_stubs_per_channel): + for _ in range(args.num_stubs_per_channel): stub = test_pb2_grpc.TestServiceStub(channel) runner = test_runner.TestRunner(stub, test_cases, hist, exception_queue, stop_event) diff --git a/src/python/grpcio_tests/tests/testing/_client_application.py b/src/python/grpcio_tests/tests/testing/_client_application.py index 3ddeba2373..4d42df0389 100644 --- a/src/python/grpcio_tests/tests/testing/_client_application.py +++ b/src/python/grpcio_tests/tests/testing/_client_application.py @@ -130,9 +130,9 @@ def _run_stream_stream(stub): request_pipe = _Pipe() response_iterator = stub.StreStre(iter(request_pipe)) request_pipe.add(_application_common.STREAM_STREAM_REQUEST) - first_responses = next(response_iterator), next(response_iterator), + first_responses = next(response_iterator), next(response_iterator) request_pipe.add(_application_common.STREAM_STREAM_REQUEST) - second_responses = next(response_iterator), next(response_iterator), + second_responses = next(response_iterator), next(response_iterator) request_pipe.close() try: next(response_iterator) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 9cffd3df19..de4c2c1fdd 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -15,10 +15,12 @@ "protoc_plugin._split_definitions_test.SplitProtoSingleProtocExecutionProtocStyleTest", "protoc_plugin.beta_python_plugin_test.PythonPluginTest", "reflection._reflection_servicer_test.ReflectionServicerTest", + "status._grpc_status_test.StatusTest", "testing._client_test.ClientTest", "testing._server_test.FirstServiceServicerTest", "testing._time_test.StrictFakeTimeTest", "testing._time_test.StrictRealTimeTest", + "unit._abort_test.AbortTest", "unit._api_test.AllTest", "unit._api_test.ChannelConnectivityTest", "unit._api_test.ChannelTest", @@ -55,12 +57,14 @@ "unit._reconnect_test.ReconnectTest", "unit._resource_exhausted_test.ResourceExhaustedTest", "unit._rpc_test.RPCTest", + "unit._server_shutdown_test.ServerShutdown", "unit._server_ssl_cert_config_test.ServerSSLCertConfigFetcherParamsChecks", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithoutClientAuth", "unit._server_test.ServerTest", "unit._session_cache_test.SSLSessionCacheTest", + "unit._version_test.VersionTest", "unit.beta._beta_features_test.BetaFeaturesTest", "unit.beta._beta_features_test.ContextManagementAndLifecycleTest", "unit.beta._connectivity_channel_test.ConnectivityStatesTest", diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index de33b81e32..a9bcd9f304 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -3,9 +3,11 @@ load("@grpc_python_dependencies//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) GRPCIO_TESTS_UNIT = [ + "_abort_test.py", "_api_test.py", "_auth_context_test.py", "_auth_test.py", + "_version_test.py", "_channel_args_test.py", "_channel_close_test.py", "_channel_connectivity_test.py", @@ -27,6 +29,7 @@ GRPCIO_TESTS_UNIT = [ # TODO(ghostwriternr): To be added later. # "_server_ssl_cert_config_test.py", "_server_test.py", + "_server_shutdown_test.py", "_session_cache_test.py", ] @@ -49,6 +52,11 @@ py_library( ) py_library( + name = "_server_shutdown_scenarios", + srcs = ["_server_shutdown_scenarios.py"], +) + +py_library( name = "_thread_pool", srcs = ["_thread_pool.py"], ) @@ -69,6 +77,7 @@ py_library( ":resources", ":test_common", ":_exit_scenarios", + ":_server_shutdown_scenarios", ":_thread_pool", ":_from_grpc_import_star", "//src/python/grpcio_tests/tests/unit/framework/common", diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py new file mode 100644 index 0000000000..6438f6897a --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -0,0 +1,124 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server context abort mechanism""" + +import unittest +import collections +import logging + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_ABORT = '/test/abort' +_ABORT_WITH_STATUS = '/test/AbortWithStatus' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_ABORT_DETAILS = 'Abandon ship!' +_ABORT_METADATA = (('a-trailing-metadata', '42'),) + + +class _Status( + collections.namedtuple( + '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status): + pass + + +def abort_unary_unary(request, servicer_context): + servicer_context.abort( + grpc.StatusCode.INTERNAL, + _ABORT_DETAILS, + ) + raise Exception('This line should not be executed!') + + +def abort_with_status_unary_unary(request, servicer_context): + servicer_context.abort_with_status( + _Status( + code=grpc.StatusCode.INTERNAL, + details=_ABORT_DETAILS, + trailing_metadata=_ABORT_METADATA, + )) + raise Exception('This line should not be executed!') + + +def invalid_code_unary_unary(request, servicer_context): + servicer_context.abort( + 42, + _ABORT_DETAILS, + ) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _ABORT: + return grpc.unary_unary_rpc_method_handler(abort_unary_unary) + elif handler_call_details.method == _ABORT_WITH_STATUS: + return grpc.unary_unary_rpc_method_handler( + abort_with_status_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.stream_stream_rpc_method_handler( + invalid_code_unary_unary) + else: + return None + + +class AbortTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._channel.close() + self._server.stop(0) + + def test_abort(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + def test_abort_with_status(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + self.assertEqual(rpc_error.trailing_metadata(), _ABORT_METADATA) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index 38072861a4..0dc6a8718c 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -32,6 +32,7 @@ class AllTest(unittest.TestCase): 'Future', 'ChannelConnectivity', 'StatusCode', + 'Status', 'RpcError', 'RpcContext', 'Call', @@ -100,6 +101,7 @@ class ChannelTest(unittest.TestCase): def test_secure_channel(self): channel_credentials = grpc.ssl_channel_credentials() channel = grpc.secure_channel('google.com:443', channel_credentials) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index b1b5bbdcab..96c4e9ec76 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -71,8 +71,8 @@ class AuthContextTest(unittest.TestCase): port = server.add_insecure_port('[::]:0') server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + with grpc.insecure_channel('localhost:%d' % port) as channel: + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) server.stop(None) auth_data = pickle.loads(response) @@ -98,6 +98,7 @@ class AuthContextTest(unittest.TestCase): channel_creds, options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) @@ -132,6 +133,7 @@ class AuthContextTest(unittest.TestCase): options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py index 727fb7d65f..565bd39b3a 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -75,6 +75,8 @@ class ChannelConnectivityTest(unittest.TestCase): channel.unsubscribe(callback.update) fifth_connectivities = callback.connectivities() + channel.close() + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities) @@ -108,7 +110,8 @@ class ChannelConnectivityTest(unittest.TestCase): _ready_in_connectivities) second_callback.block_until_connectivities_satisfy( _ready_in_connectivities) - del channel + channel.close() + server.stop(None) self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) @@ -139,6 +142,7 @@ class ChannelConnectivityTest(unittest.TestCase): callback.block_until_connectivities_satisfy( _last_connectivity_is_not_ready) channel.unsubscribe(callback.update) + channel.close() self.assertFalse(thread_pool.was_used()) diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py index 345460ef40..46a4eb9bb6 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -60,6 +60,8 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertTrue(ready_future.done()) self.assertFalse(ready_future.running()) + channel.close() + def test_immediately_connectable_channel_connectivity(self): thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) server = grpc.server(thread_pool, options=(('grpc.so_reuseport', 0),)) @@ -84,6 +86,9 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertFalse(ready_future.running()) self.assertFalse(thread_pool.was_used()) + channel.close() + server.stop(None) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index 876d8e827e..87884a19dc 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -77,6 +77,9 @@ class CompressionTest(unittest.TestCase): self._port = self._server.add_insecure_port('[::]:0') self._server.start() + def tearDown(self): + self._server.stop(None) + def testUnary(self): request = b'\x00' * 100 @@ -102,6 +105,7 @@ class CompressionTest(unittest.TestCase): response = multi_callable( request, metadata=[('grpc-internal-encoding-request', 'gzip')]) self.assertEqual(request, response) + compressed_channel.close() def testStreaming(self): request = b'\x00' * 100 @@ -115,6 +119,7 @@ class CompressionTest(unittest.TestCase): call = multi_callable(iter([request] * test_constants.STREAM_LENGTH)) for response in call: self.assertEqual(request, response) + compressed_channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py index aeb02458a7..5a5dedd5f2 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py @@ -27,6 +27,7 @@ def _get_number_active_threads(): class ForkPosixTester(unittest.TestCase): def setUp(self): + self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT cygrpc._GRPC_ENABLE_FORK_SUPPORT = True def testForkManagedThread(self): @@ -50,6 +51,9 @@ class ForkPosixTester(unittest.TestCase): thread.join() self.assertEqual(0, _get_number_active_threads()) + def tearDown(self): + cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag + @unittest.skipUnless(os.name == 'nt', 'Windows-specific tests') class ForkWindowsTester(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index 3e8393b53c..f27ea422d0 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -96,6 +96,7 @@ class EmptyMessageTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index 6c551df3ec..81de1dae1d 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -71,6 +71,7 @@ class ErrorMessageEncodingTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py index 5226537579..b429ee089f 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_test.py +++ b/src/python/grpcio_tests/tests/unit/_exit_test.py @@ -71,7 +71,6 @@ def wait(process): process.wait() -@unittest.skip('https://github.com/grpc/grpc/issues/7311') class ExitTest(unittest.TestCase): def test_unstarted_server(self): @@ -130,6 +129,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_unary_unary_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], @@ -138,6 +139,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_unary_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], @@ -145,6 +148,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_stream_unary_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], @@ -153,6 +158,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_stream_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], @@ -161,6 +168,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_unary_stream_call(self): process = subprocess.Popen( BASE_COMMAND + @@ -169,6 +178,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_stream_unary_call(self): process = subprocess.Popen( BASE_COMMAND + @@ -178,6 +189,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_stream_stream_call(self): process = subprocess.Popen( BASE_COMMAND + diff --git a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py index ad847ae03e..1ada25382d 100644 --- a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py +++ b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py @@ -14,7 +14,7 @@ _BEFORE_IMPORT = tuple(globals()) -from grpc import * # pylint: disable=wildcard-import +from grpc import * # pylint: disable=wildcard-import,unused-wildcard-import _AFTER_IMPORT = tuple(globals()) diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 99db0ac58b..a647e5e720 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -337,6 +337,7 @@ class InterceptorTest(unittest.TestCase): def tearDown(self): self._server.stop(None) self._server_pool.shutdown(wait=True) + self._channel.close() def testTripleRequestMessagesClientInterceptor(self): diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 0ff49490d5..7ed7c83893 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -62,6 +62,9 @@ class InvalidMetadataTest(unittest.TestCase): self._stream_unary = _stream_unary_multi_callable(self._channel) self._stream_stream = _stream_stream_multi_callable(self._channel) + def tearDown(self): + self._channel.close() + def testUnaryRequestBlockingUnaryResponse(self): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index 00949e2236..e89b521cc5 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -215,6 +215,7 @@ class InvocationDefectsTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testIterableStreamRequestBlockingUnaryResponse(self): requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)] diff --git a/src/python/grpcio_tests/tests/unit/_logging_test.py b/src/python/grpcio_tests/tests/unit/_logging_test.py index 631b9de9db..8ff127f506 100644 --- a/src/python/grpcio_tests/tests/unit/_logging_test.py +++ b/src/python/grpcio_tests/tests/unit/_logging_test.py @@ -14,66 +14,86 @@ """Test of gRPC Python's interaction with the python logging module""" import unittest -import six -from six.moves import reload_module import logging import grpc -import functools +import subprocess import sys +INTERPRETER = sys.executable -def patch_stderr(f): - @functools.wraps(f) - def _impl(*args, **kwargs): - old_stderr = sys.stderr - sys.stderr = six.StringIO() - try: - f(*args, **kwargs) - finally: - sys.stderr = old_stderr +class LoggingTest(unittest.TestCase): - return _impl + def test_logger_not_occupied(self): + script = """if True: + import logging + import grpc -def isolated_logging(f): + if len(logging.getLogger().handlers) != 0: + raise Exception('expected 0 logging handlers') - @functools.wraps(f) - def _impl(*args, **kwargs): - reload_module(logging) - reload_module(grpc) - try: - f(*args, **kwargs) - finally: - reload_module(logging) + """ + self._verifyScriptSucceeds(script) - return _impl + def test_handler_found(self): + script = """if True: + import logging + import grpc + """ + out, err = self._verifyScriptSucceeds(script) + self.assertEqual(0, len(err), 'unexpected output to stderr') -class LoggingTest(unittest.TestCase): + def test_can_configure_logger(self): + script = """if True: + import logging + import six - @isolated_logging - def test_logger_not_occupied(self): - self.assertEqual(0, len(logging.getLogger().handlers)) + import grpc - @patch_stderr - @isolated_logging - def test_handler_found(self): - self.assertEqual(0, len(sys.stderr.getvalue())) - @isolated_logging - def test_can_configure_logger(self): - intended_stream = six.StringIO() - logging.basicConfig(stream=intended_stream) - self.assertEqual(1, len(logging.getLogger().handlers)) - self.assertIs(logging.getLogger().handlers[0].stream, intended_stream) + intended_stream = six.StringIO() + logging.basicConfig(stream=intended_stream) + + if len(logging.getLogger().handlers) != 1: + raise Exception('expected 1 logging handler') + + if logging.getLogger().handlers[0].stream is not intended_stream: + raise Exception('wrong handler stream') + + """ + self._verifyScriptSucceeds(script) - @isolated_logging def test_grpc_logger(self): - self.assertIn("grpc", logging.Logger.manager.loggerDict) - root_logger = logging.getLogger("grpc") - self.assertEqual(1, len(root_logger.handlers)) - self.assertIsInstance(root_logger.handlers[0], logging.NullHandler) + script = """if True: + import logging + + import grpc + + if "grpc" not in logging.Logger.manager.loggerDict: + raise Exception('grpc logger not found') + + root_logger = logging.getLogger("grpc") + if len(root_logger.handlers) != 1: + raise Exception('expected 1 root logger handler') + if not isinstance(root_logger.handlers[0], logging.NullHandler): + raise Exception('expected logging.NullHandler') + + """ + self._verifyScriptSucceeds(script) + + def _verifyScriptSucceeds(self, script): + process = subprocess.Popen( + [INTERPRETER, '-c', script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = process.communicate() + self.assertEqual( + 0, process.returncode, + 'process failed with exit code %d (stdout: %s, stderr: %s)' % + (process.returncode, out, err)) + return out, err if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 0dafab827a..a63664ac5d 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -198,8 +198,8 @@ class MetadataCodeDetailsTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) - self._unary_unary = channel.unary_unary( + self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = self._channel.unary_unary( '/'.join(( '', _SERVICE, @@ -208,17 +208,17 @@ class MetadataCodeDetailsTest(unittest.TestCase): request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) - self._unary_stream = channel.unary_stream('/'.join(( + self._unary_stream = self._channel.unary_stream('/'.join(( '', _SERVICE, _UNARY_STREAM, )),) - self._stream_unary = channel.stream_unary('/'.join(( + self._stream_unary = self._channel.stream_unary('/'.join(( '', _SERVICE, _STREAM_UNARY, )),) - self._stream_stream = channel.stream_stream( + self._stream_stream = self._channel.stream_stream( '/'.join(( '', _SERVICE, @@ -228,6 +228,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): response_deserializer=_RESPONSE_DESERIALIZER, ) + def tearDown(self): + self._server.stop(None) + self._channel.close() + def testSuccessfulUnaryUnary(self): self._servicer.set_details(_DETAILS) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 2d352e99d4..7b32b5b5f3 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -187,13 +187,14 @@ class MetadataFlagsTest(unittest.TestCase): def test_call_wait_for_ready_default(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast(perform_call, - create_dummy_channel()) + with create_dummy_channel() as channel: + self.check_connection_does_failfast(perform_call, channel) def test_call_wait_for_ready_disabled(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast( - perform_call, create_dummy_channel(), wait_for_ready=False) + with create_dummy_channel() as channel: + self.check_connection_does_failfast( + perform_call, channel, wait_for_ready=False) def test_call_wait_for_ready_enabled(self): # To test the wait mechanism, Python thread is required to make @@ -210,16 +211,16 @@ class MetadataFlagsTest(unittest.TestCase): wg.done() def test_call(perform_call): - try: - channel = grpc.insecure_channel(addr) - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The channel - # object can be collected before calling the callback, which - # will result in a deadlock. - wg.done() - unhandled_exceptions.put(e, True) + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) test_threads = [] for perform_call in _ALL_CALL_CASES: diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 777ab683e3..892df3df08 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -186,6 +186,7 @@ class MetadataTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index f6d4fcbd0a..d4ea126e2b 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -98,6 +98,8 @@ class ReconnectTest(unittest.TestCase): server.add_insecure_port('[::]:{}'.format(port)) server.start() self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + server.stop(None) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index 4fead8fcd5..517c2d2f97 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -148,6 +148,7 @@ class ResourceExhaustedTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py index a768d6c7c1..a99121cee5 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py @@ -193,6 +193,7 @@ class RPCTest(unittest.TestCase): def tearDown(self): self._server.stop(None) + self._channel.close() def testUnrecognizedMethod(self): request = b'abc' diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py new file mode 100644 index 0000000000..1d1fdba11e --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -0,0 +1,97 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines a number of module-scope gRPC scenarios to test server shutdown.""" + +import argparse +import os +import threading +import time +import logging + +import grpc +from tests.unit import test_common + +from concurrent import futures +from six.moves import queue + +WAIT_TIME = 1000 + +REQUEST = b'request' +RESPONSE = b'response' + +SERVER_RAISES_EXCEPTION = 'server_raises_exception' +SERVER_DEALLOCATED = 'server_deallocated' +SERVER_FORK_CAN_EXIT = 'server_fork_can_exit' + +FORK_EXIT = '/test/ForkExit' + + +def fork_and_exit(request, servicer_context): + pid = os.fork() + if pid == 0: + os._exit(0) + return RESPONSE + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == FORK_EXIT: + return grpc.unary_unary_rpc_method_handler(fork_and_exit) + else: + return None + + +def run_server(port_queue): + server = test_common.test_server() + port = server.add_insecure_port('[::]:0') + port_queue.put(port) + server.add_generic_rpc_handlers((GenericHandler(),)) + server.start() + # threading.Event.wait() does not exhibit the bug identified in + # https://github.com/grpc/grpc/issues/17093, sleep instead + time.sleep(WAIT_TIME) + + +def run_test(args): + if args.scenario == SERVER_RAISES_EXCEPTION: + server = test_common.test_server() + server.start() + raise Exception() + elif args.scenario == SERVER_DEALLOCATED: + server = test_common.test_server() + server.start() + server.__del__() + while server._state.stage != grpc._server._ServerStage.STOPPED: + pass + elif args.scenario == SERVER_FORK_CAN_EXIT: + port_queue = queue.Queue() + thread = threading.Thread(target=run_server, args=(port_queue,)) + thread.daemon = True + thread.start() + port = port_queue.get() + channel = grpc.insecure_channel('localhost:%d' % port) + multi_callable = channel.unary_unary(FORK_EXIT) + result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) + os.wait() + else: + raise ValueError('unknown test scenario') + + +if __name__ == '__main__': + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + args = parser.parse_args() + run_test(args) diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py new file mode 100644 index 0000000000..47446d65a5 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py @@ -0,0 +1,90 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests clean shutdown of server on various interpreter exit conditions. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import subprocess +import sys +import threading +import unittest +import logging + +from tests.unit import _server_shutdown_scenarios + +SCENARIO_FILE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '_server_shutdown_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: # pylint: disable=broad-except + pass + + +atexit.register(cleanup_processes) + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ServerShutdown(unittest.TestCase): + + # Currently we shut down a server (if possible) after the Python server + # instance is garbage collected. This behavior may change in the future. + def test_deallocated_server_stops(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + def test_server_exception_exits(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + @unittest.skipIf(os.name == 'nt', 'fork not supported on windows') + def test_server_fork_can_exit(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_version_test.py b/src/python/grpcio_tests/tests/unit/_version_test.py new file mode 100644 index 0000000000..3d37b319e5 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_version_test.py @@ -0,0 +1,30 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for grpc.__version__""" + +import unittest +import grpc +import logging +from grpc import _grpcio_metadata + + +class VersionTest(unittest.TestCase): + + def test_get_version(self): + self.assertEqual(grpc.__version__, _grpcio_metadata.__version__) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/ruby/end2end/graceful_sig_handling_client.rb b/src/ruby/end2end/graceful_sig_handling_client.rb new file mode 100755 index 0000000000..14a67a62cc --- /dev/null +++ b/src/ruby/end2end/graceful_sig_handling_client.rb @@ -0,0 +1,61 @@ +#!/usr/bin/env ruby + +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +require_relative './end2end_common' + +# Test client. Sends RPC's as normal but process also has signal handlers +class SigHandlingClientController < ClientControl::ClientController::Service + def initialize(stub) + @stub = stub + end + + def do_echo_rpc(req, _) + response = @stub.echo(Echo::EchoRequest.new(request: req.request)) + fail 'bad response' unless response.response == req.request + ClientControl::Void.new + end +end + +def main + client_control_port = '' + server_port = '' + OptionParser.new do |opts| + opts.on('--client_control_port=P', String) do |p| + client_control_port = p + end + opts.on('--server_port=P', String) do |p| + server_port = p + end + end.parse! + + # Allow a few seconds to be safe. + srv = new_rpc_server_for_testing + srv.add_http2_port("0.0.0.0:#{client_control_port}", + :this_port_is_insecure) + stub = Echo::EchoServer::Stub.new("localhost:#{server_port}", + :this_channel_is_insecure) + control_service = SigHandlingClientController.new(stub) + srv.handle(control_service) + server_thread = Thread.new do + srv.run_till_terminated_or_interrupted(['int']) + end + srv.wait_till_running + # send a first RPC to notify the parent process that we've started + stub.echo(Echo::EchoRequest.new(request: 'client/child started')) + server_thread.join +end + +main diff --git a/src/ruby/end2end/graceful_sig_handling_driver.rb b/src/ruby/end2end/graceful_sig_handling_driver.rb new file mode 100755 index 0000000000..e12ae28485 --- /dev/null +++ b/src/ruby/end2end/graceful_sig_handling_driver.rb @@ -0,0 +1,83 @@ +#!/usr/bin/env ruby + +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# smoke test for a grpc-using app that receives and +# handles process-ending signals + +require_relative './end2end_common' + +# A service that calls back it's received_rpc_callback +# upon receiving an RPC. Used for synchronization/waiting +# for child process to start. +class ClientStartedService < Echo::EchoServer::Service + def initialize(received_rpc_callback) + @received_rpc_callback = received_rpc_callback + end + + def echo(echo_req, _) + @received_rpc_callback.call unless @received_rpc_callback.nil? + @received_rpc_callback = nil + Echo::EchoReply.new(response: echo_req.request) + end +end + +def main + STDERR.puts 'start server' + client_started = false + client_started_mu = Mutex.new + client_started_cv = ConditionVariable.new + received_rpc_callback = proc do + client_started_mu.synchronize do + client_started = true + client_started_cv.signal + end + end + + client_started_service = ClientStartedService.new(received_rpc_callback) + server_runner = ServerRunner.new(client_started_service) + server_port = server_runner.run + STDERR.puts 'start client' + control_stub, client_pid = start_client('graceful_sig_handling_client.rb', server_port) + + client_started_mu.synchronize do + client_started_cv.wait(client_started_mu) until client_started + end + + control_stub.do_echo_rpc( + ClientControl::DoEchoRpcRequest.new(request: 'hello')) + + STDERR.puts 'killing client' + Process.kill('SIGINT', client_pid) + Process.wait(client_pid) + client_exit_status = $CHILD_STATUS + + if client_exit_status.exited? + if client_exit_status.exitstatus != 0 + STDERR.puts 'Client did not close gracefully' + exit(1) + end + else + STDERR.puts 'Client did not close gracefully' + exit(1) + end + + STDERR.puts 'Client ended gracefully' + + # no need to call cleanup, client should already be dead + server_runner.stop +end + +main diff --git a/src/ruby/end2end/graceful_sig_stop_client.rb b/src/ruby/end2end/graceful_sig_stop_client.rb new file mode 100755 index 0000000000..b672dc3f2a --- /dev/null +++ b/src/ruby/end2end/graceful_sig_stop_client.rb @@ -0,0 +1,78 @@ +#!/usr/bin/env ruby + +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +require_relative './end2end_common' + +# Test client. Sends RPC's as normal but process also has signal handlers +class SigHandlingClientController < ClientControl::ClientController::Service + def initialize(srv, stub) + @srv = srv + @stub = stub + end + + def do_echo_rpc(req, _) + response = @stub.echo(Echo::EchoRequest.new(request: req.request)) + fail 'bad response' unless response.response == req.request + ClientControl::Void.new + end + + def shutdown(_, _) + # Spawn a new thread because RpcServer#stop is + # synchronous and blocks until either this RPC has finished, + # or the server's "poll_period" seconds have passed. + @shutdown_thread = Thread.new do + @srv.stop + end + ClientControl::Void.new + end + + def join_shutdown_thread + @shutdown_thread.join + end +end + +def main + client_control_port = '' + server_port = '' + OptionParser.new do |opts| + opts.on('--client_control_port=P', String) do |p| + client_control_port = p + end + opts.on('--server_port=P', String) do |p| + server_port = p + end + end.parse! + + # The "shutdown" RPC should end very quickly. + # Allow a few seconds to be safe. + srv = new_rpc_server_for_testing(poll_period: 3) + srv.add_http2_port("0.0.0.0:#{client_control_port}", + :this_port_is_insecure) + stub = Echo::EchoServer::Stub.new("localhost:#{server_port}", + :this_channel_is_insecure) + control_service = SigHandlingClientController.new(srv, stub) + srv.handle(control_service) + server_thread = Thread.new do + srv.run_till_terminated_or_interrupted(['int']) + end + srv.wait_till_running + # send a first RPC to notify the parent process that we've started + stub.echo(Echo::EchoRequest.new(request: 'client/child started')) + server_thread.join + control_service.join_shutdown_thread +end + +main diff --git a/src/ruby/end2end/graceful_sig_stop_driver.rb b/src/ruby/end2end/graceful_sig_stop_driver.rb new file mode 100755 index 0000000000..7a132403eb --- /dev/null +++ b/src/ruby/end2end/graceful_sig_stop_driver.rb @@ -0,0 +1,62 @@ +#!/usr/bin/env ruby + +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# smoke test for a grpc-using app that receives and +# handles process-ending signals + +require_relative './end2end_common' + +# A service that calls back it's received_rpc_callback +# upon receiving an RPC. Used for synchronization/waiting +# for child process to start. +class ClientStartedService < Echo::EchoServer::Service + def initialize(received_rpc_callback) + @received_rpc_callback = received_rpc_callback + end + + def echo(echo_req, _) + @received_rpc_callback.call unless @received_rpc_callback.nil? + @received_rpc_callback = nil + Echo::EchoReply.new(response: echo_req.request) + end +end + +def main + STDERR.puts 'start server' + client_started = false + client_started_mu = Mutex.new + client_started_cv = ConditionVariable.new + received_rpc_callback = proc do + client_started_mu.synchronize do + client_started = true + client_started_cv.signal + end + end + + client_started_service = ClientStartedService.new(received_rpc_callback) + server_runner = ServerRunner.new(client_started_service) + server_port = server_runner.run + STDERR.puts 'start client' + control_stub, client_pid = start_client('./graceful_sig_stop_client.rb', server_port) + + client_started_mu.synchronize do + client_started_cv.wait(client_started_mu) until client_started + end + + cleanup(control_stub, client_pid, server_runner) +end + +main diff --git a/src/ruby/lib/grpc/generic/rpc_server.rb b/src/ruby/lib/grpc/generic/rpc_server.rb index 3b5a0ce27f..f0f73dc56e 100644 --- a/src/ruby/lib/grpc/generic/rpc_server.rb +++ b/src/ruby/lib/grpc/generic/rpc_server.rb @@ -240,6 +240,13 @@ module GRPC # the call has no impact if the server is already stopped, otherwise # server's current call loop is it's last. def stop + # if called via run_till_terminated_or_interrupted, + # signal stop_server_thread and dont do anything + if @stop_server.nil? == false && @stop_server == false + @stop_server = true + @stop_server_cv.broadcast + return + end @run_mutex.synchronize do fail 'Cannot stop before starting' if @running_state == :not_started return if @running_state != :running @@ -354,6 +361,60 @@ module GRPC alias_method :run_till_terminated, :run + # runs the server with signal handlers + # @param signals + # List of String, Integer or both representing signals that the user + # would like to send to the server for graceful shutdown + # @param wait_interval (optional) + # Integer seconds that user would like stop_server_thread to poll + # stop_server + def run_till_terminated_or_interrupted(signals, wait_interval = 60) + @stop_server = false + @stop_server_mu = Mutex.new + @stop_server_cv = ConditionVariable.new + + @stop_server_thread = Thread.new do + loop do + break if @stop_server + @stop_server_mu.synchronize do + @stop_server_cv.wait(@stop_server_mu, wait_interval) + end + end + + # stop is surrounded by mutex, should handle multiple calls to stop + # correctly + stop + end + + valid_signals = Signal.list + + # register signal handlers + signals.each do |sig| + # input validation + if sig.class == String + sig.upcase! + if sig.start_with?('SIG') + # cut out the SIG prefix to see if valid signal + sig = sig[3..-1] + end + end + + # register signal traps for all valid signals + if valid_signals.value?(sig) || valid_signals.key?(sig) + Signal.trap(sig) do + @stop_server = true + @stop_server_cv.broadcast + end + else + fail "#{sig} not a valid signal" + end + end + + run + + @stop_server_thread.join + end + # Sends RESOURCE_EXHAUSTED if there are too many unprocessed jobs def available?(an_rpc) return an_rpc if @pool.ready_for_work? diff --git a/src/ruby/lib/grpc/version.rb b/src/ruby/lib/grpc/version.rb index a4ed052d85..3b7f62d9f5 100644 --- a/src/ruby/lib/grpc/version.rb +++ b/src/ruby/lib/grpc/version.rb @@ -14,5 +14,5 @@ # GRPC contains the General RPC module. module GRPC - VERSION = '1.18.0.dev' + VERSION = '1.19.0.dev' end diff --git a/src/ruby/pb/test/client.rb b/src/ruby/pb/test/client.rb index b2303c6e14..03f3d9001a 100755 --- a/src/ruby/pb/test/client.rb +++ b/src/ruby/pb/test/client.rb @@ -111,10 +111,13 @@ def create_stub(opts) if opts.secure creds = ssl_creds(opts.use_test_ca) stub_opts = { - channel_args: { - GRPC::Core::Channel::SSL_TARGET => opts.server_host_override - } + channel_args: {} } + unless opts.server_host_override.empty? + stub_opts[:channel_args].merge!({ + GRPC::Core::Channel::SSL_TARGET => opts.server_host_override + }) + end # Add service account creds if specified wants_creds = %w(all compute_engine_creds service_account_creds) @@ -603,7 +606,7 @@ class NamedTests if not op.metadata.has_key?(initial_metadata_key) fail AssertionError, "Expected initial metadata. None received" elsif op.metadata[initial_metadata_key] != metadata[initial_metadata_key] - fail AssertionError, + fail AssertionError, "Expected initial metadata: #{metadata[initial_metadata_key]}. "\ "Received: #{op.metadata[initial_metadata_key]}" end @@ -611,7 +614,7 @@ class NamedTests fail AssertionError, "Expected trailing metadata. None received" elsif op.trailing_metadata[trailing_metadata_key] != metadata[trailing_metadata_key] - fail AssertionError, + fail AssertionError, "Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\ "Received: #{op.trailing_metadata[trailing_metadata_key]}" end @@ -639,7 +642,7 @@ class NamedTests fail AssertionError, "Expected trailing metadata. None received" elsif duplex_op.trailing_metadata[trailing_metadata_key] != metadata[trailing_metadata_key] - fail AssertionError, + fail AssertionError, "Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\ "Received: #{duplex_op.trailing_metadata[trailing_metadata_key]}" end @@ -710,7 +713,7 @@ Args = Struct.new(:default_service_account, :server_host, :server_host_override, # validates the command line options, returning them as a Hash. def parse_args args = Args.new - args.server_host_override = 'foo.test.google.fr' + args.server_host_override = '' OptionParser.new do |opts| opts.on('--oauth_scope scope', 'Scope for OAuth tokens') { |v| args['oauth_scope'] = v } diff --git a/src/ruby/tools/version.rb b/src/ruby/tools/version.rb index 389fb70684..2ad685a7eb 100644 --- a/src/ruby/tools/version.rb +++ b/src/ruby/tools/version.rb @@ -14,6 +14,6 @@ module GRPC module Tools - VERSION = '1.18.0.dev' + VERSION = '1.19.0.dev' end end |