diff options
author | Yash Tibrewal <yashkt@google.com> | 2018-12-26 15:10:41 -0800 |
---|---|---|
committer | Yash Tibrewal <yashkt@google.com> | 2018-12-26 15:10:41 -0800 |
commit | 24e37e249a4db24ff2c886960e3a00311e2591dd (patch) | |
tree | 7bcb4a002432e7ef04054750b1efe8dfd9ba8ade /src | |
parent | 0911e489e3fe22e2ca5d7c927dac83358f2f05b7 (diff) | |
parent | fc7d0911a3a44d7bc926d3db99b7300a0c0f33dc (diff) |
Merge branch 'master' into failhijackedrecv
Diffstat (limited to 'src')
400 files changed, 11328 insertions, 7093 deletions
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 7986aca696..b004687250 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -132,6 +132,7 @@ grpc::string GetHeaderIncludes(grpc_generator::File* file, "grpcpp/impl/codegen/async_generic_service.h", "grpcpp/impl/codegen/async_stream.h", "grpcpp/impl/codegen/async_unary_call.h", + "grpcpp/impl/codegen/client_callback.h", "grpcpp/impl/codegen/method_handler_impl.h", "grpcpp/impl/codegen/proto_utils.h", "grpcpp/impl/codegen/rpc_method.h", @@ -580,11 +581,22 @@ void PrintHeaderClientMethodCallbackInterfaces( "const $Request$* request, $Response$* response, " "std::function<void(::grpc::Status)>) = 0;\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add support for client-side streaming + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor< $Request$>* " + "reactor) = 0;\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add support for server-side streaming + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor< $Response$>* " + "reactor) = 0;\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add support for bidi streaming + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< " + "$Request$,$Response$>* reactor) = 0;\n"); } } @@ -631,11 +643,23 @@ void PrintHeaderClientMethodCallback(grpc_generator::Printer* printer, "const $Request$* request, $Response$* response, " "std::function<void(::grpc::Status)>) override;\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add support for client-side streaming + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor< $Request$>* " + "reactor) override;\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add support for server-side streaming + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor< $Response$>* " + "reactor) override;\n"); + } else if (method->BidiStreaming()) { - // TODO(vjpai): Add support for bidi streaming + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< " + "$Request$,$Response$>* reactor) override;\n"); } } @@ -865,6 +889,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerReadReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedReadReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } else if (ServerOnlyStreaming(method)) { printer->Print( *vars, @@ -876,6 +905,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerWriteReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedWriteReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } else if (method->BidiStreaming()) { printer->Print( *vars, @@ -887,6 +921,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerBidiReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedBidiReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } } @@ -915,22 +954,36 @@ void PrintHeaderServerMethodCallback( *vars, " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" " new ::grpc::internal::CallbackUnaryHandler< " - "ExperimentalWithCallbackMethod_$Method$<BaseClass>, $RealRequest$, " - "$RealResponse$>(\n" + "$RealRequest$, $RealResponse$>(\n" " [this](::grpc::ServerContext* context,\n" " const $RealRequest$* request,\n" " $RealResponse$* response,\n" " ::grpc::experimental::ServerCallbackRpcController* " "controller) {\n" - " this->$" + " return this->$" "Method$(context, request, response, controller);\n" - " }, this));\n"); + " }));\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } printer->Print(*vars, "}\n"); printer->Print(*vars, @@ -967,8 +1020,7 @@ void PrintHeaderServerMethodRawCallback( *vars, " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" " new ::grpc::internal::CallbackUnaryHandler< " - "ExperimentalWithRawCallbackMethod_$Method$<BaseClass>, $RealRequest$, " - "$RealResponse$>(\n" + "$RealRequest$, $RealResponse$>(\n" " [this](::grpc::ServerContext* context,\n" " const $RealRequest$* request,\n" " $RealResponse$* response,\n" @@ -976,13 +1028,28 @@ void PrintHeaderServerMethodRawCallback( "controller) {\n" " this->$" "Method$(context, request, response, controller);\n" - " }, this));\n"); + " }));\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } printer->Print(*vars, "}\n"); printer->Print(*vars, @@ -1607,7 +1674,19 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context, response);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print( + *vars, + "void $ns$$Service$::" + "Stub::experimental_async::$Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor< $Request$>* reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackWriterFactory< " + "$Request$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, response, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; @@ -1641,7 +1720,19 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context, request);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print( + *vars, + "void $ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor< $Response$>* reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackReaderFactory< " + "$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, request, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; @@ -1675,7 +1766,19 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print( + *vars, + "void $ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< $Request$,$Response$>* " + "reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackReaderWriterFactory< " + "$Request$,$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; diff --git a/src/compiler/csharp_generator.cc b/src/compiler/csharp_generator.cc index a923ce8e38..59ddbd82f6 100644 --- a/src/compiler/csharp_generator.cc +++ b/src/compiler/csharp_generator.cc @@ -609,6 +609,42 @@ void GenerateBindServiceMethod(Printer* out, const ServiceDescriptor* service) { out->Print("\n"); } +void GenerateBindServiceWithBinderMethod(Printer* out, + const ServiceDescriptor* service) { + out->Print( + "/// <summary>Register service method implementations with a service " + "binder. Useful when customizing the service binding logic.\n" + "/// Note: this method is part of an experimental API that can change or " + "be " + "removed without any prior notice.</summary>\n"); + out->Print( + "/// <param name=\"serviceBinder\">Service methods will be bound by " + "calling <c>AddMethod</c> on this object." + "</param>\n"); + out->Print( + "/// <param name=\"serviceImpl\">An object implementing the server-side" + " handling logic.</param>\n"); + out->Print( + "public static void BindService(grpc::ServiceBinderBase serviceBinder, " + "$implclass$ " + "serviceImpl)\n", + "implclass", GetServerClassName(service)); + out->Print("{\n"); + out->Indent(); + + for (int i = 0; i < service->method_count(); i++) { + const MethodDescriptor* method = service->method(i); + out->Print( + "serviceBinder.AddMethod($methodfield$, serviceImpl.$methodname$);\n", + "methodfield", GetMethodFieldName(method), "methodname", + method->name()); + } + + out->Outdent(); + out->Print("}\n"); + out->Print("\n"); +} + void GenerateService(Printer* out, const ServiceDescriptor* service, bool generate_client, bool generate_server, bool internal_access) { @@ -637,6 +673,7 @@ void GenerateService(Printer* out, const ServiceDescriptor* service, } if (generate_server) { GenerateBindServiceMethod(out, service); + GenerateBindServiceWithBinderMethod(out, service); } out->Outdent(); diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc index d42961fc60..dd741f1e2d 100644 --- a/src/core/ext/filters/client_channel/client_channel.cc +++ b/src/core/ext/filters/client_channel/client_channel.cc @@ -34,9 +34,10 @@ #include "src/core/ext/filters/client_channel/backup_poller.h" #include "src/core/ext/filters/client_channel/http_connect_handshaker.h" #include "src/core/ext/filters/client_channel/lb_policy_registry.h" -#include "src/core/ext/filters/client_channel/method_params.h" #include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/request_routing.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" #include "src/core/ext/filters/client_channel/retry_throttle.h" #include "src/core/ext/filters/client_channel/subchannel.h" #include "src/core/ext/filters/deadline/deadline_filter.h" @@ -63,6 +64,8 @@ #include "src/core/lib/transport/status_metadata.h" using grpc_core::internal::ClientChannelMethodParams; +using grpc_core::internal::ClientChannelMethodParamsTable; +using grpc_core::internal::ProcessedResolverResult; using grpc_core::internal::ServerRetryThrottleData; /* Client channel implementation */ @@ -83,36 +86,19 @@ grpc_core::TraceFlag grpc_client_channel_trace(false, "client_channel"); struct external_connectivity_watcher; -typedef grpc_core::SliceHashTable< - grpc_core::RefCountedPtr<ClientChannelMethodParams>> - MethodParamsTable; - typedef struct client_channel_channel_data { - grpc_core::OrphanablePtr<grpc_core::Resolver> resolver; - bool started_resolving; + grpc_core::ManualConstructor<grpc_core::RequestRouter> request_router; + bool deadline_checking_enabled; - grpc_client_channel_factory* client_channel_factory; bool enable_retries; size_t per_rpc_retry_buffer_size; /** combiner protecting all variables below in this data structure */ grpc_combiner* combiner; - /** currently active load balancer */ - grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> lb_policy; /** retry throttle data */ grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data; /** maps method names to method_parameters structs */ - grpc_core::RefCountedPtr<MethodParamsTable> method_params_table; - /** incoming resolver result - set by resolver.next() */ - grpc_channel_args* resolver_result; - /** a list of closures that are all waiting for resolver result to come in */ - grpc_closure_list waiting_for_resolver_result_closures; - /** resolver callback */ - grpc_closure on_resolver_result_changed; - /** connectivity state being tracked */ - grpc_connectivity_state_tracker state_tracker; - /** when an lb_policy arrives, should we try to exit idle */ - bool exit_idle_when_lb_policy_arrives; + grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> method_params_table; /** owning stack */ grpc_channel_stack* owning_stack; /** interested parties (owned) */ @@ -129,541 +115,40 @@ typedef struct client_channel_channel_data { grpc_core::UniquePtr<char> info_lb_policy_name; /** service config in JSON form */ grpc_core::UniquePtr<char> info_service_config_json; - /* backpointer to grpc_channel's channelz node */ - grpc_core::channelz::ClientChannelNode* channelz_channel; - /* caches if the last resolution event contained addresses */ - bool previous_resolution_contained_addresses; } channel_data; -typedef struct { - channel_data* chand; - /** used as an identifier, don't dereference it because the LB policy may be - * non-existing when the callback is run */ - grpc_core::LoadBalancingPolicy* lb_policy; - grpc_closure closure; -} reresolution_request_args; - -/** We create one watcher for each new lb_policy that is returned from a - resolver, to watch for state changes from the lb_policy. When a state - change is seen, we update the channel, and create a new watcher. */ -typedef struct { - channel_data* chand; - grpc_closure on_changed; - grpc_connectivity_state state; - grpc_core::LoadBalancingPolicy* lb_policy; -} lb_policy_connectivity_watcher; - -static void watch_lb_policy_locked(channel_data* chand, - grpc_core::LoadBalancingPolicy* lb_policy, - grpc_connectivity_state current_state); - -static const char* channel_connectivity_state_change_string( - grpc_connectivity_state state) { - switch (state) { - case GRPC_CHANNEL_IDLE: - return "Channel state change to IDLE"; - case GRPC_CHANNEL_CONNECTING: - return "Channel state change to CONNECTING"; - case GRPC_CHANNEL_READY: - return "Channel state change to READY"; - case GRPC_CHANNEL_TRANSIENT_FAILURE: - return "Channel state change to TRANSIENT_FAILURE"; - case GRPC_CHANNEL_SHUTDOWN: - return "Channel state change to SHUTDOWN"; - } - GPR_UNREACHABLE_CODE(return "UNKNOWN"); -} - -static void set_channel_connectivity_state_locked(channel_data* chand, - grpc_connectivity_state state, - grpc_error* error, - const char* reason) { - /* TODO: Improve failure handling: - * - Make it possible for policies to return GRPC_CHANNEL_TRANSIENT_FAILURE. - * - Hand over pending picks from old policies during the switch that happens - * when resolver provides an update. */ - if (chand->lb_policy != nullptr) { - if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { - /* cancel picks with wait_for_ready=false */ - chand->lb_policy->CancelMatchingPicksLocked( - /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY, - /* check= */ 0, GRPC_ERROR_REF(error)); - } else if (state == GRPC_CHANNEL_SHUTDOWN) { - /* cancel all picks */ - chand->lb_policy->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0, - GRPC_ERROR_REF(error)); - } - } - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: setting connectivity state to %s", chand, - grpc_connectivity_state_name(state)); - } - if (chand->channelz_channel != nullptr) { - chand->channelz_channel->AddTraceEvent( - grpc_core::channelz::ChannelTrace::Severity::Info, - grpc_slice_from_static_string( - channel_connectivity_state_change_string(state))); - } - grpc_connectivity_state_set(&chand->state_tracker, state, error, reason); -} - -static void on_lb_policy_state_changed_locked(void* arg, grpc_error* error) { - lb_policy_connectivity_watcher* w = - static_cast<lb_policy_connectivity_watcher*>(arg); - /* check if the notification is for the latest policy */ - if (w->lb_policy == w->chand->lb_policy.get()) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: lb_policy=%p state changed to %s", w->chand, - w->lb_policy, grpc_connectivity_state_name(w->state)); - } - set_channel_connectivity_state_locked(w->chand, w->state, - GRPC_ERROR_REF(error), "lb_changed"); - if (w->state != GRPC_CHANNEL_SHUTDOWN) { - watch_lb_policy_locked(w->chand, w->lb_policy, w->state); - } - } - GRPC_CHANNEL_STACK_UNREF(w->chand->owning_stack, "watch_lb_policy"); - gpr_free(w); -} - -static void watch_lb_policy_locked(channel_data* chand, - grpc_core::LoadBalancingPolicy* lb_policy, - grpc_connectivity_state current_state) { - lb_policy_connectivity_watcher* w = - static_cast<lb_policy_connectivity_watcher*>(gpr_malloc(sizeof(*w))); - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "watch_lb_policy"); - w->chand = chand; - GRPC_CLOSURE_INIT(&w->on_changed, on_lb_policy_state_changed_locked, w, - grpc_combiner_scheduler(chand->combiner)); - w->state = current_state; - w->lb_policy = lb_policy; - lb_policy->NotifyOnStateChangeLocked(&w->state, &w->on_changed); -} - -static void start_resolving_locked(channel_data* chand) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: starting name resolution", chand); - } - GPR_ASSERT(!chand->started_resolving); - chand->started_resolving = true; - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "resolver"); - chand->resolver->NextLocked(&chand->resolver_result, - &chand->on_resolver_result_changed); -} - -typedef struct { - char* server_name; - grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data; -} service_config_parsing_state; - -static void parse_retry_throttle_params( - const grpc_json* field, service_config_parsing_state* parsing_state) { - if (strcmp(field->key, "retryThrottling") == 0) { - if (parsing_state->retry_throttle_data != nullptr) return; // Duplicate. - if (field->type != GRPC_JSON_OBJECT) return; - int max_milli_tokens = 0; - int milli_token_ratio = 0; - for (grpc_json* sub_field = field->child; sub_field != nullptr; - sub_field = sub_field->next) { - if (sub_field->key == nullptr) return; - if (strcmp(sub_field->key, "maxTokens") == 0) { - if (max_milli_tokens != 0) return; // Duplicate. - if (sub_field->type != GRPC_JSON_NUMBER) return; - max_milli_tokens = gpr_parse_nonnegative_int(sub_field->value); - if (max_milli_tokens == -1) return; - max_milli_tokens *= 1000; - } else if (strcmp(sub_field->key, "tokenRatio") == 0) { - if (milli_token_ratio != 0) return; // Duplicate. - if (sub_field->type != GRPC_JSON_NUMBER) return; - // We support up to 3 decimal digits. - size_t whole_len = strlen(sub_field->value); - uint32_t multiplier = 1; - uint32_t decimal_value = 0; - const char* decimal_point = strchr(sub_field->value, '.'); - if (decimal_point != nullptr) { - whole_len = static_cast<size_t>(decimal_point - sub_field->value); - multiplier = 1000; - size_t decimal_len = strlen(decimal_point + 1); - if (decimal_len > 3) decimal_len = 3; - if (!gpr_parse_bytes_to_uint32(decimal_point + 1, decimal_len, - &decimal_value)) { - return; - } - uint32_t decimal_multiplier = 1; - for (size_t i = 0; i < (3 - decimal_len); ++i) { - decimal_multiplier *= 10; - } - decimal_value *= decimal_multiplier; - } - uint32_t whole_value; - if (!gpr_parse_bytes_to_uint32(sub_field->value, whole_len, - &whole_value)) { - return; - } - milli_token_ratio = - static_cast<int>((whole_value * multiplier) + decimal_value); - if (milli_token_ratio <= 0) return; - } - } - parsing_state->retry_throttle_data = - grpc_core::internal::ServerRetryThrottleMap::GetDataForServer( - parsing_state->server_name, max_milli_tokens, milli_token_ratio); - } -} - -// Invoked from the resolver NextLocked() callback when the resolver -// is shutting down. -static void on_resolver_shutdown_locked(channel_data* chand, - grpc_error* error) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down", chand); - } - if (chand->lb_policy != nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand, - chand->lb_policy.get()); - } - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } - if (chand->resolver != nullptr) { - // This should never happen; it can only be triggered by a resolver - // implementation spotaneously deciding to report shutdown without - // being orphaned. This code is included just to be defensive. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: spontaneous shutdown from resolver %p", - chand, chand->resolver.get()); - } - chand->resolver.reset(); - set_channel_connectivity_state_locked( - chand, GRPC_CHANNEL_SHUTDOWN, - GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Resolver spontaneous shutdown", &error, 1), - "resolver_spontaneous_shutdown"); - } - grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures, - GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Channel disconnected", &error, 1)); - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "resolver"); - grpc_channel_args_destroy(chand->resolver_result); - chand->resolver_result = nullptr; - GRPC_ERROR_UNREF(error); -} - -// Returns the LB policy name from the resolver result. -static grpc_core::UniquePtr<char> -get_lb_policy_name_from_resolver_result_locked(channel_data* chand) { - // Find LB policy name in channel args. - const grpc_arg* channel_arg = - grpc_channel_args_find(chand->resolver_result, GRPC_ARG_LB_POLICY_NAME); - const char* lb_policy_name = grpc_channel_arg_get_string(channel_arg); - // Special case: If at least one balancer address is present, we use - // the grpclb policy, regardless of what the resolver actually specified. - channel_arg = - grpc_channel_args_find(chand->resolver_result, GRPC_ARG_LB_ADDRESSES); - if (channel_arg != nullptr && channel_arg->type == GRPC_ARG_POINTER) { - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(channel_arg->value.pointer.p); - if (grpc_lb_addresses_contains_balancer_address(*addresses)) { - if (lb_policy_name != nullptr && - gpr_stricmp(lb_policy_name, "grpclb") != 0) { - gpr_log(GPR_INFO, - "resolver requested LB policy %s but provided at least one " - "balancer address -- forcing use of grpclb LB policy", - lb_policy_name); - } - lb_policy_name = "grpclb"; - } - } - // Use pick_first if nothing was specified and we didn't select grpclb - // above. - if (lb_policy_name == nullptr) lb_policy_name = "pick_first"; - return grpc_core::UniquePtr<char>(gpr_strdup(lb_policy_name)); -} - -static void request_reresolution_locked(void* arg, grpc_error* error) { - reresolution_request_args* args = - static_cast<reresolution_request_args*>(arg); - channel_data* chand = args->chand; - // If this invocation is for a stale LB policy, treat it as an LB shutdown - // signal. - if (args->lb_policy != chand->lb_policy.get() || error != GRPC_ERROR_NONE || - chand->resolver == nullptr) { - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "re-resolution"); - gpr_free(args); - return; - } - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: started name re-resolving", chand); - } - chand->resolver->RequestReresolutionLocked(); - // Give back the closure to the LB policy. - chand->lb_policy->SetReresolutionClosureLocked(&args->closure); -} - -using TraceStringVector = grpc_core::InlinedVector<char*, 3>; - -// Creates a new LB policy, replacing any previous one. -// If the new policy is created successfully, sets *connectivity_state and -// *connectivity_error to its initial connectivity state; otherwise, -// leaves them unchanged. -static void create_new_lb_policy_locked( - channel_data* chand, char* lb_policy_name, - grpc_connectivity_state* connectivity_state, - grpc_error** connectivity_error, TraceStringVector* trace_strings) { - grpc_core::LoadBalancingPolicy::Args lb_policy_args; - lb_policy_args.combiner = chand->combiner; - lb_policy_args.client_channel_factory = chand->client_channel_factory; - lb_policy_args.args = chand->resolver_result; - grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> new_lb_policy = - grpc_core::LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( - lb_policy_name, lb_policy_args); - if (GPR_UNLIKELY(new_lb_policy == nullptr)) { - gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name); - if (chand->channelz_channel != nullptr) { - char* str; - gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name); - trace_strings->push_back(str); - } - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: created new LB policy \"%s\" (%p)", chand, - lb_policy_name, new_lb_policy.get()); - } - if (chand->channelz_channel != nullptr) { - char* str; - gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name); - trace_strings->push_back(str); - } - // Swap out the LB policy and update the fds in - // chand->interested_parties. - if (chand->lb_policy != nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand, - chand->lb_policy.get()); - } - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy->HandOffPendingPicksLocked(new_lb_policy.get()); - } - chand->lb_policy = std::move(new_lb_policy); - grpc_pollset_set_add_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - // Set up re-resolution callback. - reresolution_request_args* args = - static_cast<reresolution_request_args*>(gpr_zalloc(sizeof(*args))); - args->chand = chand; - args->lb_policy = chand->lb_policy.get(); - GRPC_CLOSURE_INIT(&args->closure, request_reresolution_locked, args, - grpc_combiner_scheduler(chand->combiner)); - GRPC_CHANNEL_STACK_REF(chand->owning_stack, "re-resolution"); - chand->lb_policy->SetReresolutionClosureLocked(&args->closure); - // Get the new LB policy's initial connectivity state and start a - // connectivity watch. - GRPC_ERROR_UNREF(*connectivity_error); - *connectivity_state = - chand->lb_policy->CheckConnectivityLocked(connectivity_error); - if (chand->exit_idle_when_lb_policy_arrives) { - chand->lb_policy->ExitIdleLocked(); - chand->exit_idle_when_lb_policy_arrives = false; - } - watch_lb_policy_locked(chand, chand->lb_policy.get(), *connectivity_state); - } -} - -// Returns the service config (as a JSON string) from the resolver result. -// Also updates state in chand. -static grpc_core::UniquePtr<char> -get_service_config_from_resolver_result_locked(channel_data* chand) { - const grpc_arg* channel_arg = - grpc_channel_args_find(chand->resolver_result, GRPC_ARG_SERVICE_CONFIG); - const char* service_config_json = grpc_channel_arg_get_string(channel_arg); - if (service_config_json != nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"", - chand, service_config_json); - } - grpc_core::UniquePtr<grpc_core::ServiceConfig> service_config = - grpc_core::ServiceConfig::Create(service_config_json); - if (service_config != nullptr) { - if (chand->enable_retries) { - channel_arg = - grpc_channel_args_find(chand->resolver_result, GRPC_ARG_SERVER_URI); - const char* server_uri = grpc_channel_arg_get_string(channel_arg); - GPR_ASSERT(server_uri != nullptr); - grpc_uri* uri = grpc_uri_parse(server_uri, true); - GPR_ASSERT(uri->path[0] != '\0'); - service_config_parsing_state parsing_state; - parsing_state.server_name = - uri->path[0] == '/' ? uri->path + 1 : uri->path; - service_config->ParseGlobalParams(parse_retry_throttle_params, - &parsing_state); - grpc_uri_destroy(uri); - chand->retry_throttle_data = - std::move(parsing_state.retry_throttle_data); - } - chand->method_params_table = service_config->CreateMethodConfigTable( - ClientChannelMethodParams::CreateFromJson); - } - } - return grpc_core::UniquePtr<char>(gpr_strdup(service_config_json)); -} - -static void maybe_add_trace_message_for_address_changes_locked( - channel_data* chand, TraceStringVector* trace_strings) { - int resolution_contains_addresses = false; - const grpc_arg* channel_arg = - grpc_channel_args_find(chand->resolver_result, GRPC_ARG_LB_ADDRESSES); - if (channel_arg != nullptr && channel_arg->type == GRPC_ARG_POINTER) { - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(channel_arg->value.pointer.p); - if (addresses->num_addresses > 0) { - resolution_contains_addresses = true; - } - } - if (!resolution_contains_addresses && - chand->previous_resolution_contained_addresses) { - trace_strings->push_back(gpr_strdup("Address list became empty")); - } else if (resolution_contains_addresses && - !chand->previous_resolution_contained_addresses) { - trace_strings->push_back(gpr_strdup("Address list became non-empty")); - } - chand->previous_resolution_contained_addresses = - resolution_contains_addresses; -} - -static void concatenate_and_add_channel_trace_locked( - channel_data* chand, TraceStringVector* trace_strings) { - if (!trace_strings->empty()) { - gpr_strvec v; - gpr_strvec_init(&v); - gpr_strvec_add(&v, gpr_strdup("Resolution event: ")); - bool is_first = 1; - for (size_t i = 0; i < trace_strings->size(); ++i) { - if (!is_first) gpr_strvec_add(&v, gpr_strdup(", ")); - is_first = false; - gpr_strvec_add(&v, (*trace_strings)[i]); - } - char* flat; - size_t flat_len = 0; - flat = gpr_strvec_flatten(&v, &flat_len); - chand->channelz_channel->AddTraceEvent( - grpc_core::channelz::ChannelTrace::Severity::Info, - grpc_slice_new(flat, flat_len, gpr_free)); - gpr_strvec_destroy(&v); - } -} - -// Callback invoked when a resolver result is available. -static void on_resolver_result_changed_locked(void* arg, grpc_error* error) { +// Synchronous callback from chand->request_router to process a resolver +// result update. +static bool process_resolver_result_locked(void* arg, + const grpc_channel_args& args, + const char** lb_policy_name, + grpc_json** lb_policy_config) { channel_data* chand = static_cast<channel_data*>(arg); + ProcessedResolverResult resolver_result(args, chand->enable_retries); + grpc_core::UniquePtr<char> service_config_json = + resolver_result.service_config_json(); if (grpc_client_channel_trace.enabled()) { - const char* disposition = - chand->resolver_result != nullptr - ? "" - : (error == GRPC_ERROR_NONE ? " (transient error)" - : " (resolver shutdown)"); - gpr_log(GPR_INFO, - "chand=%p: got resolver result: resolver_result=%p error=%s%s", - chand, chand->resolver_result, grpc_error_string(error), - disposition); - } - // Handle shutdown. - if (error != GRPC_ERROR_NONE || chand->resolver == nullptr) { - on_resolver_shutdown_locked(chand, GRPC_ERROR_REF(error)); - return; - } - // Data used to set the channel's connectivity state. - bool set_connectivity_state = true; - // We only want to trace the address resolution in the follow cases: - // (a) Address resolution resulted in service config change. - // (b) Address resolution that causes number of backends to go from - // zero to non-zero. - // (c) Address resolution that causes number of backends to go from - // non-zero to zero. - // (d) Address resolution that causes a new LB policy to be created. - // - // we track a list of strings to eventually be concatenated and traced. - TraceStringVector trace_strings; - grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; - grpc_error* connectivity_error = - GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy"); - // chand->resolver_result will be null in the case of a transient - // resolution error. In that case, we don't have any new result to - // process, which means that we keep using the previous result (if any). - if (chand->resolver_result == nullptr) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: resolver transient failure", chand); - } - } else { - grpc_core::UniquePtr<char> lb_policy_name = - get_lb_policy_name_from_resolver_result_locked(chand); - // Check to see if we're already using the right LB policy. - // Note: It's safe to use chand->info_lb_policy_name here without - // taking a lock on chand->info_mu, because this function is the - // only thing that modifies its value, and it can only be invoked - // once at any given time. - bool lb_policy_name_changed = chand->info_lb_policy_name == nullptr || - gpr_stricmp(chand->info_lb_policy_name.get(), - lb_policy_name.get()) != 0; - if (chand->lb_policy != nullptr && !lb_policy_name_changed) { - // Continue using the same LB policy. Update with new addresses. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p: updating existing LB policy \"%s\" (%p)", - chand, lb_policy_name.get(), chand->lb_policy.get()); - } - chand->lb_policy->UpdateLocked(*chand->resolver_result); - // No need to set the channel's connectivity state; the existing - // watch on the LB policy will take care of that. - set_connectivity_state = false; - } else { - // Instantiate new LB policy. - create_new_lb_policy_locked(chand, lb_policy_name.get(), - &connectivity_state, &connectivity_error, - &trace_strings); - } - // Find service config. - grpc_core::UniquePtr<char> service_config_json = - get_service_config_from_resolver_result_locked(chand); - // Note: It's safe to use chand->info_service_config_json here without - // taking a lock on chand->info_mu, because this function is the - // only thing that modifies its value, and it can only be invoked - // once at any given time. - if (chand->channelz_channel != nullptr) { - if (((service_config_json == nullptr) != - (chand->info_service_config_json == nullptr)) || - (service_config_json != nullptr && - strcmp(service_config_json.get(), - chand->info_service_config_json.get()) != 0)) { - // TODO(ncteisen): might be worth somehow including a snippet of the - // config in the trace, at the risk of bloating the trace logs. - trace_strings.push_back(gpr_strdup("Service config changed")); - } - maybe_add_trace_message_for_address_changes_locked(chand, &trace_strings); - concatenate_and_add_channel_trace_locked(chand, &trace_strings); - } - // Swap out the data used by cc_get_channel_info(). - gpr_mu_lock(&chand->info_mu); - chand->info_lb_policy_name = std::move(lb_policy_name); - chand->info_service_config_json = std::move(service_config_json); - gpr_mu_unlock(&chand->info_mu); - // Clean up. - grpc_channel_args_destroy(chand->resolver_result); - chand->resolver_result = nullptr; - } - // Set the channel's connectivity state if needed. - if (set_connectivity_state) { - set_channel_connectivity_state_locked( - chand, connectivity_state, connectivity_error, "resolver_result"); - } else { - GRPC_ERROR_UNREF(connectivity_error); + gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"", + chand, service_config_json.get()); } - // Invoke closures that were waiting for results and renew the watch. - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - chand->resolver->NextLocked(&chand->resolver_result, - &chand->on_resolver_result_changed); + // Update channel state. + chand->retry_throttle_data = resolver_result.retry_throttle_data(); + chand->method_params_table = resolver_result.method_params_table(); + // Swap out the data used by cc_get_channel_info(). + gpr_mu_lock(&chand->info_mu); + chand->info_lb_policy_name = resolver_result.lb_policy_name(); + const bool service_config_changed = + ((service_config_json == nullptr) != + (chand->info_service_config_json == nullptr)) || + (service_config_json != nullptr && + strcmp(service_config_json.get(), + chand->info_service_config_json.get()) != 0); + chand->info_service_config_json = std::move(service_config_json); + gpr_mu_unlock(&chand->info_mu); + // Return results. + *lb_policy_name = chand->info_lb_policy_name.get(); + *lb_policy_config = resolver_result.lb_policy_config(); + return service_config_changed; } static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { @@ -673,15 +158,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); if (op->on_connectivity_state_change != nullptr) { - grpc_connectivity_state_notify_on_state_change( - &chand->state_tracker, op->connectivity_state, - op->on_connectivity_state_change); + chand->request_router->NotifyOnConnectivityStateChange( + op->connectivity_state, op->on_connectivity_state_change); op->on_connectivity_state_change = nullptr; op->connectivity_state = nullptr; } if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) { - if (chand->lb_policy == nullptr) { + if (chand->request_router->lb_policy() == nullptr) { grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("Ping with no load balancing"); GRPC_CLOSURE_SCHED(op->send_ping.on_initiate, GRPC_ERROR_REF(error)); @@ -689,14 +173,9 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { } else { grpc_error* error = GRPC_ERROR_NONE; grpc_core::LoadBalancingPolicy::PickState pick_state; - pick_state.initial_metadata = nullptr; - pick_state.initial_metadata_flags = 0; - pick_state.on_complete = nullptr; - memset(&pick_state.subchannel_call_context, 0, - sizeof(pick_state.subchannel_call_context)); - pick_state.user_data = nullptr; // Pick must return synchronously, because pick_state.on_complete is null. - GPR_ASSERT(chand->lb_policy->PickLocked(&pick_state, &error)); + GPR_ASSERT( + chand->request_router->lb_policy()->PickLocked(&pick_state, &error)); if (pick_state.connected_subchannel != nullptr) { pick_state.connected_subchannel->Ping(op->send_ping.on_initiate, op->send_ping.on_ack); @@ -715,37 +194,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) { } if (op->disconnect_with_error != GRPC_ERROR_NONE) { - if (chand->resolver != nullptr) { - set_channel_connectivity_state_locked( - chand, GRPC_CHANNEL_SHUTDOWN, - GRPC_ERROR_REF(op->disconnect_with_error), "disconnect"); - chand->resolver.reset(); - if (!chand->started_resolving) { - grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures, - GRPC_ERROR_REF(op->disconnect_with_error)); - GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures); - } - if (chand->lb_policy != nullptr) { - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } - } - GRPC_ERROR_UNREF(op->disconnect_with_error); + chand->request_router->ShutdownLocked(op->disconnect_with_error); } if (op->reset_connect_backoff) { - if (chand->resolver != nullptr) { - chand->resolver->ResetBackoffLocked(); - chand->resolver->RequestReresolutionLocked(); - } - if (chand->lb_policy != nullptr) { - chand->lb_policy->ResetBackoffLocked(); - } + chand->request_router->ResetConnectionBackoffLocked(); } GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "start_transport_op"); - GRPC_CLOSURE_SCHED(op->on_consumed, GRPC_ERROR_NONE); } @@ -796,12 +252,9 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); chand->owning_stack = args->channel_stack; - GRPC_CLOSURE_INIT(&chand->on_resolver_result_changed, - on_resolver_result_changed_locked, chand, - grpc_combiner_scheduler(chand->combiner)); + chand->deadline_checking_enabled = + grpc_deadline_checking_enabled(args->channel_args); chand->interested_parties = grpc_pollset_set_create(); - grpc_connectivity_state_init(&chand->state_tracker, GRPC_CHANNEL_IDLE, - "client_channel"); grpc_client_channel_start_backup_polling(chand->interested_parties); // Record max per-RPC retry buffer size. const grpc_arg* arg = grpc_channel_args_find( @@ -811,8 +264,6 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, // Record enable_retries. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_ENABLE_RETRIES); chand->enable_retries = grpc_channel_arg_get_bool(arg, true); - chand->channelz_channel = nullptr; - chand->previous_resolution_contained_addresses = false; // Record client channel factory. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_CLIENT_CHANNEL_FACTORY); @@ -824,9 +275,7 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "client channel factory arg must be a pointer"); } - grpc_client_channel_factory_ref( - static_cast<grpc_client_channel_factory*>(arg->value.pointer.p)); - chand->client_channel_factory = + grpc_client_channel_factory* client_channel_factory = static_cast<grpc_client_channel_factory*>(arg->value.pointer.p); // Get server name to resolve, using proxy mapper if needed. arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_SERVER_URI); @@ -842,39 +291,24 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem, grpc_channel_args* new_args = nullptr; grpc_proxy_mappers_map_name(arg->value.string, args->channel_args, &proxy_name, &new_args); - // Instantiate resolver. - chand->resolver = grpc_core::ResolverRegistry::CreateResolver( - proxy_name != nullptr ? proxy_name : arg->value.string, - new_args != nullptr ? new_args : args->channel_args, - chand->interested_parties, chand->combiner); - if (proxy_name != nullptr) gpr_free(proxy_name); - if (new_args != nullptr) grpc_channel_args_destroy(new_args); - if (chand->resolver == nullptr) { - return GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed"); - } - chand->deadline_checking_enabled = - grpc_deadline_checking_enabled(args->channel_args); - return GRPC_ERROR_NONE; + // Instantiate request router. + grpc_client_channel_factory_ref(client_channel_factory); + grpc_error* error = GRPC_ERROR_NONE; + chand->request_router.Init( + chand->owning_stack, chand->combiner, client_channel_factory, + chand->interested_parties, &grpc_client_channel_trace, + process_resolver_result_locked, chand, + proxy_name != nullptr ? proxy_name : arg->value.string /* target_uri */, + new_args != nullptr ? new_args : args->channel_args, &error); + gpr_free(proxy_name); + grpc_channel_args_destroy(new_args); + return error; } /* Destructor for channel_data */ static void cc_destroy_channel_elem(grpc_channel_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - if (chand->resolver != nullptr) { - // The only way we can get here is if we never started resolving, - // because we take a ref to the channel stack when we start - // resolving and do not release it until the resolver callback is - // invoked after the resolver shuts down. - chand->resolver.reset(); - } - if (chand->client_channel_factory != nullptr) { - grpc_client_channel_factory_unref(chand->client_channel_factory); - } - if (chand->lb_policy != nullptr) { - grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(), - chand->interested_parties); - chand->lb_policy.reset(); - } + chand->request_router.Destroy(); // TODO(roth): Once we convert the filter API to C++, there will no // longer be any need to explicitly reset these smart pointer data members. chand->info_lb_policy_name.reset(); @@ -882,7 +316,6 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) { chand->retry_throttle_data.reset(); chand->method_params_table.reset(); grpc_client_channel_stop_backup_polling(chand->interested_parties); - grpc_connectivity_state_destroy(&chand->state_tracker); grpc_pollset_set_destroy(chand->interested_parties); GRPC_COMBINER_UNREF(chand->combiner, "client_channel"); gpr_mu_destroy(&chand->info_mu); @@ -939,6 +372,7 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) { // - add census stats for retries namespace { + struct call_data; // State used for starting a retryable batch on a subchannel call. @@ -1024,21 +458,25 @@ struct subchannel_call_retry_state { bool started_recv_trailing_metadata : 1; bool completed_recv_trailing_metadata : 1; // State for callback processing. - bool retry_dispatched : 1; subchannel_batch_data* recv_initial_metadata_ready_deferred_batch = nullptr; grpc_error* recv_initial_metadata_error = GRPC_ERROR_NONE; subchannel_batch_data* recv_message_ready_deferred_batch = nullptr; grpc_error* recv_message_error = GRPC_ERROR_NONE; subchannel_batch_data* recv_trailing_metadata_internal_batch = nullptr; + // NOTE: Do not move this next to the metadata bitfields above. That would + // save space but will also result in a data race because compiler will + // generate a 2 byte store which overwrites the meta-data fields upon + // setting this field. + bool retry_dispatched : 1; }; // Pending batches stored in call data. -typedef struct { +struct pending_batch { // The pending batch. If nullptr, this slot is empty. grpc_transport_stream_op_batch* batch; // Indicates whether payload for send ops has been cached in call data. bool send_ops_cached; -} pending_batch; +}; /** Call data. Holds a pointer to grpc_subchannel_call and the associated machinery to create such a pointer. @@ -1075,11 +513,8 @@ struct call_data { for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches); ++i) { GPR_ASSERT(pending_batches[i].batch == nullptr); } - for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { - if (pick.subchannel_call_context[i].value != nullptr) { - pick.subchannel_call_context[i].destroy( - pick.subchannel_call_context[i].value); - } + if (have_request) { + request.Destroy(); } } @@ -1106,12 +541,11 @@ struct call_data { // Set when we get a cancel_stream op. grpc_error* cancel_error = GRPC_ERROR_NONE; - grpc_core::LoadBalancingPolicy::PickState pick; + grpc_core::ManualConstructor<grpc_core::RequestRouter::Request> request; + bool have_request = false; grpc_closure pick_closure; - grpc_closure pick_cancel_closure; grpc_polling_entity* pollent = nullptr; - bool pollent_added_to_interested_parties = false; // Batches are added to this list when received from above. // They are removed when we are done handling the batch (i.e., when @@ -1161,6 +595,7 @@ struct call_data { grpc_linked_mdelem* send_trailing_metadata_storage = nullptr; grpc_metadata_batch send_trailing_metadata; }; + } // namespace // Forward declarations. @@ -1563,8 +998,9 @@ static void do_retry(grpc_call_element* elem, "client_channel_call_retry"); calld->subchannel_call = nullptr; } - if (calld->pick.connected_subchannel != nullptr) { - calld->pick.connected_subchannel.reset(); + if (calld->have_request) { + calld->have_request = false; + calld->request.Destroy(); } // Compute backoff delay. grpc_millis next_attempt_time; @@ -1713,6 +1149,7 @@ static bool maybe_retry(grpc_call_element* elem, // namespace { + subchannel_batch_data::subchannel_batch_data(grpc_call_element* elem, call_data* calld, int refcount, bool set_on_complete) @@ -1753,6 +1190,7 @@ void subchannel_batch_data::destroy() { call_data* calld = static_cast<call_data*>(elem->call_data); GRPC_CALL_STACK_UNREF(calld->owning_call, "batch_data"); } + } // namespace // Creates a subchannel_batch_data object on the call's arena with the @@ -2769,17 +2207,18 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { const size_t parent_data_size = calld->enable_retries ? sizeof(subchannel_call_retry_state) : 0; const grpc_core::ConnectedSubchannel::CallArgs call_args = { - calld->pollent, // pollent - calld->path, // path - calld->call_start_time, // start_time - calld->deadline, // deadline - calld->arena, // arena - calld->pick.subchannel_call_context, // context - calld->call_combiner, // call_combiner - parent_data_size // parent_data_size + calld->pollent, // pollent + calld->path, // path + calld->call_start_time, // start_time + calld->deadline, // deadline + calld->arena, // arena + calld->request->pick()->subchannel_call_context, // context + calld->call_combiner, // call_combiner + parent_data_size // parent_data_size }; - grpc_error* new_error = calld->pick.connected_subchannel->CreateCall( - call_args, &calld->subchannel_call); + grpc_error* new_error = + calld->request->pick()->connected_subchannel->CreateCall( + call_args, &calld->subchannel_call); if (grpc_client_channel_trace.enabled()) { gpr_log(GPR_INFO, "chand=%p calld=%p: create subchannel_call=%p: error=%s", chand, calld, calld->subchannel_call, grpc_error_string(new_error)); @@ -2791,7 +2230,8 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) { if (parent_data_size > 0) { new (grpc_connected_subchannel_call_get_parent_data( calld->subchannel_call)) - subchannel_call_retry_state(calld->pick.subchannel_call_context); + subchannel_call_retry_state( + calld->request->pick()->subchannel_call_context); } pending_batches_resume(elem); } @@ -2803,7 +2243,7 @@ static void pick_done(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); channel_data* chand = static_cast<channel_data*>(elem->channel_data); call_data* calld = static_cast<call_data*>(elem->call_data); - if (GPR_UNLIKELY(calld->pick.connected_subchannel == nullptr)) { + if (GPR_UNLIKELY(calld->request->pick()->connected_subchannel == nullptr)) { // Failed to create subchannel. // If there was no error, this is an LB policy drop, in which case // we return an error; otherwise, we may retry. @@ -2832,135 +2272,27 @@ static void pick_done(void* arg, grpc_error* error) { } } -static void maybe_add_call_to_channel_interested_parties_locked( - grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (!calld->pollent_added_to_interested_parties) { - calld->pollent_added_to_interested_parties = true; - grpc_polling_entity_add_to_pollset_set(calld->pollent, - chand->interested_parties); - } -} - -static void maybe_del_call_from_channel_interested_parties_locked( - grpc_call_element* elem) { +// If the channel is in TRANSIENT_FAILURE and the call is not +// wait_for_ready=true, fails the call and returns true. +static bool fail_call_if_in_transient_failure(grpc_call_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); call_data* calld = static_cast<call_data*>(elem->call_data); - if (calld->pollent_added_to_interested_parties) { - calld->pollent_added_to_interested_parties = false; - grpc_polling_entity_del_from_pollset_set(calld->pollent, - chand->interested_parties); + grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch; + if (chand->request_router->GetConnectivityState() == + GRPC_CHANNEL_TRANSIENT_FAILURE && + (batch->payload->send_initial_metadata.send_initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) { + pending_batches_fail( + elem, + grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "channel is in state TRANSIENT_FAILURE"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), + true /* yield_call_combiner */); + return true; } + return false; } -// Invoked when a pick is completed to leave the client_channel combiner -// and continue processing in the call combiner. -// If needed, removes the call's polling entity from chand->interested_parties. -static void pick_done_locked(grpc_call_element* elem, grpc_error* error) { - call_data* calld = static_cast<call_data*>(elem->call_data); - maybe_del_call_from_channel_interested_parties_locked(elem); - GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_SCHED(&calld->pick_closure, error); -} - -namespace grpc_core { - -// Performs subchannel pick via LB policy. -class LbPicker { - public: - // Starts a pick on chand->lb_policy. - static void StartLocked(grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: starting pick on lb_policy=%p", - chand, calld, chand->lb_policy.get()); - } - // If this is a retry, use the send_initial_metadata payload that - // we've cached; otherwise, use the pending batch. The - // send_initial_metadata batch will be the first pending batch in the - // list, as set by get_batch_index() above. - calld->pick.initial_metadata = - calld->seen_send_initial_metadata - ? &calld->send_initial_metadata - : calld->pending_batches[0] - .batch->payload->send_initial_metadata.send_initial_metadata; - calld->pick.initial_metadata_flags = - calld->seen_send_initial_metadata - ? calld->send_initial_metadata_flags - : calld->pending_batches[0] - .batch->payload->send_initial_metadata - .send_initial_metadata_flags; - GRPC_CLOSURE_INIT(&calld->pick_closure, &LbPicker::DoneLocked, elem, - grpc_combiner_scheduler(chand->combiner)); - calld->pick.on_complete = &calld->pick_closure; - GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback"); - grpc_error* error = GRPC_ERROR_NONE; - const bool pick_done = chand->lb_policy->PickLocked(&calld->pick, &error); - if (GPR_LIKELY(pick_done)) { - // Pick completed synchronously. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed synchronously", - chand, calld); - } - pick_done_locked(elem, error); - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback"); - } else { - // Pick will be returned asynchronously. - // Add the polling entity from call_data to the channel_data's - // interested_parties, so that the I/O of the LB policy can be done - // under it. It will be removed in pick_done_locked(). - maybe_add_call_to_channel_interested_parties_locked(elem); - // Request notification on call cancellation. - GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback_cancel"); - grpc_call_combiner_set_notify_on_cancel( - calld->call_combiner, - GRPC_CLOSURE_INIT(&calld->pick_cancel_closure, - &LbPicker::CancelLocked, elem, - grpc_combiner_scheduler(chand->combiner))); - } - } - - private: - // Callback invoked by LoadBalancingPolicy::PickLocked() for async picks. - // Unrefs the LB policy and invokes pick_done_locked(). - static void DoneLocked(void* arg, grpc_error* error) { - grpc_call_element* elem = static_cast<grpc_call_element*>(arg); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed asynchronously", - chand, calld); - } - pick_done_locked(elem, GRPC_ERROR_REF(error)); - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback"); - } - - // Note: This runs under the client_channel combiner, but will NOT be - // holding the call combiner. - static void CancelLocked(void* arg, grpc_error* error) { - grpc_call_element* elem = static_cast<grpc_call_element*>(arg); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - // Note: chand->lb_policy may have changed since we started our pick, - // in which case we will be cancelling the pick on a policy other than - // the one we started it on. However, this will just be a no-op. - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE && chand->lb_policy != nullptr)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: cancelling pick from LB policy %p", chand, - calld, chand->lb_policy.get()); - } - chand->lb_policy->CancelPickLocked(&calld->pick, GRPC_ERROR_REF(error)); - } - GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback_cancel"); - } -}; - -} // namespace grpc_core - // Applies service config to the call. Must be invoked once we know // that the resolver has returned results to the channel. static void apply_service_config_to_call_locked(grpc_call_element* elem) { @@ -3017,224 +2349,66 @@ static void apply_service_config_to_call_locked(grpc_call_element* elem) { } } -// If the channel is in TRANSIENT_FAILURE and the call is not -// wait_for_ready=true, fails the call and returns true. -static bool fail_call_if_in_transient_failure(grpc_call_element* elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch; - if (grpc_connectivity_state_check(&chand->state_tracker) == - GRPC_CHANNEL_TRANSIENT_FAILURE && - (batch->payload->send_initial_metadata.send_initial_metadata_flags & - GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) { - pending_batches_fail( - elem, - grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "channel is in state TRANSIENT_FAILURE"), - GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE), - true /* yield_call_combiner */); - return true; - } - return false; -} - // Invoked once resolver results are available. -static void process_service_config_and_start_lb_pick_locked( - grpc_call_element* elem) { +static bool maybe_apply_service_config_to_call_locked(void* arg) { + grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); // Only get service config data on the first attempt. if (GPR_LIKELY(calld->num_attempts_completed == 0)) { apply_service_config_to_call_locked(elem); // Check this after applying service config, since it may have // affected the call's wait_for_ready value. - if (fail_call_if_in_transient_failure(elem)) return; + if (fail_call_if_in_transient_failure(elem)) return false; } - // Start LB pick. - grpc_core::LbPicker::StartLocked(elem); + return true; } -namespace grpc_core { - -// Handles waiting for a resolver result. -// Used only for the first call on an idle channel. -class ResolverResultWaiter { - public: - explicit ResolverResultWaiter(grpc_call_element* elem) : elem_(elem) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: deferring pick pending resolver result", - chand, calld); - } - // Add closure to be run when a resolver result is available. - GRPC_CLOSURE_INIT(&done_closure_, &ResolverResultWaiter::DoneLocked, this, - grpc_combiner_scheduler(chand->combiner)); - AddToWaitingList(); - // Set cancellation closure, so that we abort if the call is cancelled. - GRPC_CLOSURE_INIT(&cancel_closure_, &ResolverResultWaiter::CancelLocked, - this, grpc_combiner_scheduler(chand->combiner)); - grpc_call_combiner_set_notify_on_cancel(calld->call_combiner, - &cancel_closure_); - } - - private: - // Adds closure_ to chand->waiting_for_resolver_result_closures. - void AddToWaitingList() { - channel_data* chand = static_cast<channel_data*>(elem_->channel_data); - grpc_closure_list_append(&chand->waiting_for_resolver_result_closures, - &done_closure_, GRPC_ERROR_NONE); - } - - // Invoked when a resolver result is available. - static void DoneLocked(void* arg, grpc_error* error) { - ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); - // If CancelLocked() has already run, delete ourselves without doing - // anything. Note that the call stack may have already been destroyed, - // so it's not safe to access anything in elem_. - if (GPR_UNLIKELY(self->finished_)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "call cancelled before resolver result"); - } - Delete(self); - return; - } - // Otherwise, process the resolver result. - grpc_call_element* elem = self->elem_; - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver failed to return data", - chand, calld); - } - pick_done_locked(elem, GRPC_ERROR_REF(error)); - } else if (GPR_UNLIKELY(chand->resolver == nullptr)) { - // Shutting down. - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver disconnected", chand, - calld); - } - pick_done_locked(elem, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); - } else if (GPR_UNLIKELY(chand->lb_policy == nullptr)) { - // Transient resolver failure. - // If call has wait_for_ready=true, try again; otherwise, fail. - uint32_t send_initial_metadata_flags = - calld->seen_send_initial_metadata - ? calld->send_initial_metadata_flags - : calld->pending_batches[0] - .batch->payload->send_initial_metadata - .send_initial_metadata_flags; - if (send_initial_metadata_flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: resolver returned but no LB policy; " - "wait_for_ready=true; trying again", - chand, calld); - } - // Re-add ourselves to the waiting list. - self->AddToWaitingList(); - // Return early so that we don't set finished_ to true below. - return; - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: resolver returned but no LB policy; " - "wait_for_ready=false; failing", - chand, calld); - } - pick_done_locked( - elem, - grpc_error_set_int( - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"), - GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); - } - } else { - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, "chand=%p calld=%p: resolver returned, doing LB pick", - chand, calld); - } - process_service_config_and_start_lb_pick_locked(elem); - } - self->finished_ = true; - } - - // Invoked when the call is cancelled. - // Note: This runs under the client_channel combiner, but will NOT be - // holding the call combiner. - static void CancelLocked(void* arg, grpc_error* error) { - ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); - // If DoneLocked() has already run, delete ourselves without doing anything. - if (GPR_LIKELY(self->finished_)) { - Delete(self); - return; - } - // If we are being cancelled, immediately invoke pick_done_locked() - // to propagate the error back to the caller. - if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { - grpc_call_element* elem = self->elem_; - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - call_data* calld = static_cast<call_data*>(elem->call_data); - if (grpc_client_channel_trace.enabled()) { - gpr_log(GPR_INFO, - "chand=%p calld=%p: cancelling call waiting for name " - "resolution", - chand, calld); - } - // Note: Although we are not in the call combiner here, we are - // basically stealing the call combiner from the pending pick, so - // it's safe to call pick_done_locked() here -- we are essentially - // calling it here instead of calling it in DoneLocked(). - pick_done_locked(elem, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( - "Pick cancelled", &error, 1)); - } - self->finished_ = true; - } - - grpc_call_element* elem_; - grpc_closure done_closure_; - grpc_closure cancel_closure_; - bool finished_ = false; -}; - -} // namespace grpc_core - static void start_pick_locked(void* arg, grpc_error* ignored) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); - GPR_ASSERT(calld->pick.connected_subchannel == nullptr); + GPR_ASSERT(!calld->have_request); GPR_ASSERT(calld->subchannel_call == nullptr); - if (GPR_LIKELY(chand->lb_policy != nullptr)) { - // We already have resolver results, so process the service config - // and start an LB pick. - process_service_config_and_start_lb_pick_locked(elem); - } else if (GPR_UNLIKELY(chand->resolver == nullptr)) { - pick_done_locked(elem, - GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); - } else { - // We do not yet have an LB policy, so wait for a resolver result. - if (GPR_UNLIKELY(!chand->started_resolving)) { - start_resolving_locked(chand); - } else { - // Normally, we want to do this check in - // process_service_config_and_start_lb_pick_locked(), so that we - // can honor the wait_for_ready setting in the service config. - // However, if the channel is in TRANSIENT_FAILURE at this point, that - // means that the resolver has returned a failure, so we're not going - // to get a service config right away. In that case, we fail the - // call now based on the wait_for_ready value passed in from the - // application. - if (fail_call_if_in_transient_failure(elem)) return; - } - // Create a new waiter, which will delete itself when done. - grpc_core::New<grpc_core::ResolverResultWaiter>(elem); - // Add the polling entity from call_data to the channel_data's - // interested_parties, so that the I/O of the resolver can be done - // under it. It will be removed in pick_done_locked(). - maybe_add_call_to_channel_interested_parties_locked(elem); + // Normally, we want to do this check until after we've processed the + // service config, so that we can honor the wait_for_ready setting in + // the service config. However, if the channel is in TRANSIENT_FAILURE + // and we don't have an LB policy at this point, that means that the + // resolver has returned a failure, so we're not going to get a service + // config right away. In that case, we fail the call now based on the + // wait_for_ready value passed in from the application. + if (chand->request_router->lb_policy() == nullptr && + fail_call_if_in_transient_failure(elem)) { + return; } + // If this is a retry, use the send_initial_metadata payload that + // we've cached; otherwise, use the pending batch. The + // send_initial_metadata batch will be the first pending batch in the + // list, as set by get_batch_index() above. + // TODO(roth): What if the LB policy needs to add something to the + // call's initial metadata, and then there's a retry? We don't want + // the new metadata to be added twice. We might need to somehow + // allocate the subchannel batch earlier so that we can give the + // subchannel's copy of the metadata batch (which is copied for each + // attempt) to the LB policy instead the one from the parent channel. + grpc_metadata_batch* initial_metadata = + calld->seen_send_initial_metadata + ? &calld->send_initial_metadata + : calld->pending_batches[0] + .batch->payload->send_initial_metadata.send_initial_metadata; + uint32_t* initial_metadata_flags = + calld->seen_send_initial_metadata + ? &calld->send_initial_metadata_flags + : &calld->pending_batches[0] + .batch->payload->send_initial_metadata + .send_initial_metadata_flags; + GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem, + grpc_schedule_on_exec_ctx); + calld->request.Init(calld->owning_call, calld->call_combiner, calld->pollent, + initial_metadata, initial_metadata_flags, + maybe_apply_service_config_to_call_locked, elem, + &calld->pick_closure); + calld->have_request = true; + chand->request_router->RouteCallLocked(calld->request.get()); } // @@ -3374,23 +2548,10 @@ const grpc_channel_filter grpc_client_channel_filter = { "client-channel", }; -static void try_to_connect_locked(void* arg, grpc_error* error_ignored) { - channel_data* chand = static_cast<channel_data*>(arg); - if (chand->lb_policy != nullptr) { - chand->lb_policy->ExitIdleLocked(); - } else { - chand->exit_idle_when_lb_policy_arrives = true; - if (!chand->started_resolving && chand->resolver != nullptr) { - start_resolving_locked(chand); - } - } - GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect"); -} - void grpc_client_channel_set_channelz_node( grpc_channel_element* elem, grpc_core::channelz::ClientChannelNode* node) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - chand->channelz_channel = node; + chand->request_router->set_channelz_node(node); } void grpc_client_channel_populate_child_refs( @@ -3398,17 +2559,22 @@ void grpc_client_channel_populate_child_refs( grpc_core::channelz::ChildRefsList* child_subchannels, grpc_core::channelz::ChildRefsList* child_channels) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - if (chand->lb_policy != nullptr) { - chand->lb_policy->FillChildRefsForChannelz(child_subchannels, - child_channels); + if (chand->request_router->lb_policy() != nullptr) { + chand->request_router->lb_policy()->FillChildRefsForChannelz( + child_subchannels, child_channels); } } +static void try_to_connect_locked(void* arg, grpc_error* error_ignored) { + channel_data* chand = static_cast<channel_data*>(arg); + chand->request_router->ExitIdleLocked(); + GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect"); +} + grpc_connectivity_state grpc_client_channel_check_connectivity_state( grpc_channel_element* elem, int try_to_connect) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - grpc_connectivity_state out = - grpc_connectivity_state_check(&chand->state_tracker); + grpc_connectivity_state out = chand->request_router->GetConnectivityState(); if (out == GRPC_CHANNEL_IDLE && try_to_connect) { GRPC_CHANNEL_STACK_REF(chand->owning_stack, "try_to_connect"); GRPC_CLOSURE_SCHED( @@ -3453,19 +2619,19 @@ static void external_connectivity_watcher_list_append( } static void external_connectivity_watcher_list_remove( - channel_data* chand, external_connectivity_watcher* too_remove) { + channel_data* chand, external_connectivity_watcher* to_remove) { GPR_ASSERT( - lookup_external_connectivity_watcher(chand, too_remove->on_complete)); + lookup_external_connectivity_watcher(chand, to_remove->on_complete)); gpr_mu_lock(&chand->external_connectivity_watcher_list_mu); - if (too_remove == chand->external_connectivity_watcher_list_head) { - chand->external_connectivity_watcher_list_head = too_remove->next; + if (to_remove == chand->external_connectivity_watcher_list_head) { + chand->external_connectivity_watcher_list_head = to_remove->next; gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); return; } external_connectivity_watcher* w = chand->external_connectivity_watcher_list_head; while (w != nullptr) { - if (w->next == too_remove) { + if (w->next == to_remove) { w->next = w->next->next; gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu); return; @@ -3517,15 +2683,15 @@ static void watch_connectivity_state_locked(void* arg, GRPC_CLOSURE_RUN(w->watcher_timer_init, GRPC_ERROR_NONE); GRPC_CLOSURE_INIT(&w->my_closure, on_external_watch_complete_locked, w, grpc_combiner_scheduler(w->chand->combiner)); - grpc_connectivity_state_notify_on_state_change(&w->chand->state_tracker, - w->state, &w->my_closure); + w->chand->request_router->NotifyOnConnectivityStateChange(w->state, + &w->my_closure); } else { GPR_ASSERT(w->watcher_timer_init == nullptr); found = lookup_external_connectivity_watcher(w->chand, w->on_complete); if (found) { GPR_ASSERT(found->on_complete == w->on_complete); - grpc_connectivity_state_notify_on_state_change( - &found->chand->state_tracker, nullptr, &found->my_closure); + found->chand->request_router->NotifyOnConnectivityStateChange( + nullptr, &found->my_closure); } grpc_polling_entity_del_from_pollset_set(&w->pollent, w->chand->interested_parties); diff --git a/src/core/ext/filters/client_channel/health/health_check_client.cc b/src/core/ext/filters/client_channel/health/health_check_client.cc index 587919596f..2232c57120 100644 --- a/src/core/ext/filters/client_channel/health/health_check_client.cc +++ b/src/core/ext/filters/client_channel/health/health_check_client.cc @@ -51,8 +51,7 @@ HealthCheckClient::HealthCheckClient( RefCountedPtr<ConnectedSubchannel> connected_subchannel, grpc_pollset_set* interested_parties, grpc_core::RefCountedPtr<grpc_core::channelz::SubchannelNode> channelz_node) - : InternallyRefCountedWithTracing<HealthCheckClient>( - &grpc_health_check_client_trace), + : InternallyRefCounted<HealthCheckClient>(&grpc_health_check_client_trace), service_name_(service_name), connected_subchannel_(std::move(connected_subchannel)), interested_parties_(interested_parties), @@ -281,8 +280,7 @@ bool DecodeResponse(grpc_slice_buffer* slice_buffer, grpc_error** error) { HealthCheckClient::CallState::CallState( RefCountedPtr<HealthCheckClient> health_check_client, grpc_pollset_set* interested_parties) - : InternallyRefCountedWithTracing<CallState>( - &grpc_health_check_client_trace), + : InternallyRefCounted<CallState>(&grpc_health_check_client_trace), health_check_client_(std::move(health_check_client)), pollent_(grpc_polling_entity_create_from_pollset_set(interested_parties)), arena_(gpr_arena_create(health_check_client_->connected_subchannel_ diff --git a/src/core/ext/filters/client_channel/health/health_check_client.h b/src/core/ext/filters/client_channel/health/health_check_client.h index f6babef7d6..2369b73fea 100644 --- a/src/core/ext/filters/client_channel/health/health_check_client.h +++ b/src/core/ext/filters/client_channel/health/health_check_client.h @@ -41,8 +41,7 @@ namespace grpc_core { -class HealthCheckClient - : public InternallyRefCountedWithTracing<HealthCheckClient> { +class HealthCheckClient : public InternallyRefCounted<HealthCheckClient> { public: HealthCheckClient(const char* service_name, RefCountedPtr<ConnectedSubchannel> connected_subchannel, @@ -61,7 +60,7 @@ class HealthCheckClient private: // Contains a call to the backend and all the data related to the call. - class CallState : public InternallyRefCountedWithTracing<CallState> { + class CallState : public InternallyRefCounted<CallState> { public: CallState(RefCountedPtr<HealthCheckClient> health_check_client, grpc_pollset_set* interested_parties_); diff --git a/src/core/ext/filters/client_channel/lb_policy.cc b/src/core/ext/filters/client_channel/lb_policy.cc index e065f45639..b4e803689e 100644 --- a/src/core/ext/filters/client_channel/lb_policy.cc +++ b/src/core/ext/filters/client_channel/lb_policy.cc @@ -27,7 +27,7 @@ grpc_core::DebugOnlyTraceFlag grpc_trace_lb_policy_refcount( namespace grpc_core { LoadBalancingPolicy::LoadBalancingPolicy(const Args& args) - : InternallyRefCountedWithTracing(&grpc_trace_lb_policy_refcount), + : InternallyRefCounted(&grpc_trace_lb_policy_refcount), combiner_(GRPC_COMBINER_REF(args.combiner, "lb_policy")), client_channel_factory_(args.client_channel_factory), interested_parties_(grpc_pollset_set_create()), diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h index b0040457a6..293d8e960c 100644 --- a/src/core/ext/filters/client_channel/lb_policy.h +++ b/src/core/ext/filters/client_channel/lb_policy.h @@ -42,8 +42,7 @@ namespace grpc_core { /// /// Any I/O done by the LB policy should be done under the pollset_set /// returned by \a interested_parties(). -class LoadBalancingPolicy - : public InternallyRefCountedWithTracing<LoadBalancingPolicy> { +class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> { public: struct Args { /// The combiner under which all LB policy calls will be run. @@ -56,18 +55,20 @@ class LoadBalancingPolicy grpc_client_channel_factory* client_channel_factory = nullptr; /// Channel args from the resolver. /// Note that the LB policy gets the set of addresses from the - /// GRPC_ARG_LB_ADDRESSES channel arg. + /// GRPC_ARG_SERVER_ADDRESS_LIST channel arg. grpc_channel_args* args = nullptr; + /// Load balancing config from the resolver. + grpc_json* lb_config = nullptr; }; /// State used for an LB pick. struct PickState { /// Initial metadata associated with the picking call. grpc_metadata_batch* initial_metadata = nullptr; - /// Bitmask used for selective cancelling. See + /// Pointer to bitmask used for selective cancelling. See /// \a CancelMatchingPicksLocked() and \a GRPC_INITIAL_METADATA_* in /// grpc_types.h. - uint32_t initial_metadata_flags = 0; + uint32_t* initial_metadata_flags = nullptr; /// Storage for LB token in \a initial_metadata, or nullptr if not used. grpc_linked_mdelem lb_token_mdelem_storage; /// Closure to run when pick is complete, if not completed synchronously. @@ -79,11 +80,6 @@ class LoadBalancingPolicy /// Will be populated with context to pass to the subchannel call, if /// needed. grpc_call_context_element subchannel_call_context[GRPC_CONTEXT_COUNT] = {}; - /// Upon success, \a *user_data will be set to whatever opaque information - /// may need to be propagated from the LB policy, or nullptr if not needed. - // TODO(roth): As part of revamping our metadata APIs, try to find a - // way to clean this up and C++-ify it. - void** user_data = nullptr; /// Next pointer. For internal use by LB policy. PickState* next = nullptr; }; @@ -92,10 +88,14 @@ class LoadBalancingPolicy LoadBalancingPolicy(const LoadBalancingPolicy&) = delete; LoadBalancingPolicy& operator=(const LoadBalancingPolicy&) = delete; - /// Updates the policy with a new set of \a args from the resolver. - /// Note that the LB policy gets the set of addresses from the - /// GRPC_ARG_LB_ADDRESSES channel arg. - virtual void UpdateLocked(const grpc_channel_args& args) GRPC_ABSTRACT; + /// Returns the name of the LB policy. + virtual const char* name() const GRPC_ABSTRACT; + + /// Updates the policy with a new set of \a args and a new \a lb_config from + /// the resolver. Note that the LB policy gets the set of addresses from the + /// GRPC_ARG_SERVER_ADDRESS_LIST channel arg. + virtual void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) GRPC_ABSTRACT; /// Finds an appropriate subchannel for a call, based on data in \a pick. /// \a pick must remain alive until the pick is complete. @@ -208,12 +208,6 @@ class LoadBalancingPolicy grpc_pollset_set* interested_parties_; /// Callback to force a re-resolution. grpc_closure* request_reresolution_; - - // Dummy classes needed for alignment issues. - // See https://github.com/grpc/grpc/issues/16032 for context. - // TODO(ncteisen): remove this as soon as the issue is resolved. - channelz::ChildRefsList dummy_list_foo; - channelz::ChildRefsList dummy_list_bar; }; } // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc index dbb90b438c..ba40febd53 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc @@ -84,6 +84,7 @@ #include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel_index.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channel_args.h" @@ -113,17 +114,24 @@ #define GRPC_GRPCLB_RECONNECT_JITTER 0.2 #define GRPC_GRPCLB_DEFAULT_FALLBACK_TIMEOUT_MS 10000 +#define GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN "grpc.grpclb_address_lb_token" + namespace grpc_core { TraceFlag grpc_lb_glb_trace(false, "glb"); namespace { +constexpr char kGrpclb[] = "grpclb"; + class GrpcLb : public LoadBalancingPolicy { public: - GrpcLb(const grpc_lb_addresses* addresses, const Args& args); + explicit GrpcLb(const Args& args); + + const char* name() const override { return kGrpclb; } - void UpdateLocked(const grpc_channel_args& args) override; + void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; void CancelPickLocked(PickState* pick, grpc_error* error) override; void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, @@ -160,9 +168,6 @@ class GrpcLb : public LoadBalancingPolicy { // Our on_complete closure and the original one. grpc_closure on_complete; grpc_closure* original_on_complete; - // The LB token associated with the pick. This is set via user_data in - // the pick. - grpc_mdelem lb_token; // Stats for client-side load reporting. RefCountedPtr<GrpcLbClientStats> client_stats; // Next pending pick. @@ -170,8 +175,7 @@ class GrpcLb : public LoadBalancingPolicy { }; /// Contains a call to the LB server and all the data related to the call. - class BalancerCallState - : public InternallyRefCountedWithTracing<BalancerCallState> { + class BalancerCallState : public InternallyRefCounted<BalancerCallState> { public: explicit BalancerCallState( RefCountedPtr<LoadBalancingPolicy> parent_grpclb_policy); @@ -329,7 +333,7 @@ class GrpcLb : public LoadBalancingPolicy { // 0 means not using fallback. int lb_fallback_timeout_ms_ = 0; // The backend addresses from the resolver. - grpc_lb_addresses* fallback_backend_addresses_ = nullptr; + UniquePtr<ServerAddressList> fallback_backend_addresses_; // Fallback timer. bool fallback_timer_callback_pending_ = false; grpc_timer lb_fallback_timer_; @@ -349,7 +353,7 @@ class GrpcLb : public LoadBalancingPolicy { // serverlist parsing code // -// vtable for LB tokens in grpc_lb_addresses +// vtable for LB token channel arg. void* lb_token_copy(void* token) { return token == nullptr ? nullptr @@ -361,38 +365,13 @@ void lb_token_destroy(void* token) { } } int lb_token_cmp(void* token1, void* token2) { - if (token1 > token2) return 1; - if (token1 < token2) return -1; + // Always indicate a match, since we don't want this channel arg to + // affect the subchannel's key in the index. return 0; } -const grpc_lb_user_data_vtable lb_token_vtable = { +const grpc_arg_pointer_vtable lb_token_arg_vtable = { lb_token_copy, lb_token_destroy, lb_token_cmp}; -// Returns the backend addresses extracted from the given addresses. -grpc_lb_addresses* ExtractBackendAddresses(const grpc_lb_addresses* addresses) { - // First pass: count the number of backend addresses. - size_t num_backends = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (!addresses->addresses[i].is_balancer) { - ++num_backends; - } - } - // Second pass: actually populate the addresses and (empty) LB tokens. - grpc_lb_addresses* backend_addresses = - grpc_lb_addresses_create(num_backends, &lb_token_vtable); - size_t num_copied = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) continue; - const grpc_resolved_address* addr = &addresses->addresses[i].address; - grpc_lb_addresses_set_address(backend_addresses, num_copied, &addr->addr, - addr->len, false /* is_balancer */, - nullptr /* balancer_name */, - (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload); - ++num_copied; - } - return backend_addresses; -} - bool IsServerValid(const grpc_grpclb_server* server, size_t idx, bool log) { if (server->drop) return false; const grpc_grpclb_ip_address* ip = &server->ip_address; @@ -440,30 +419,16 @@ void ParseServer(const grpc_grpclb_server* server, } // Returns addresses extracted from \a serverlist. -grpc_lb_addresses* ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { - size_t num_valid = 0; - /* first pass: count how many are valid in order to allocate the necessary - * memory in a single block */ +ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { + ServerAddressList addresses; for (size_t i = 0; i < serverlist->num_servers; ++i) { - if (IsServerValid(serverlist->servers[i], i, true)) ++num_valid; - } - grpc_lb_addresses* lb_addresses = - grpc_lb_addresses_create(num_valid, &lb_token_vtable); - /* second pass: actually populate the addresses and LB tokens (aka user data - * to the outside world) to be read by the RR policy during its creation. - * Given that the validity tests are very cheap, they are performed again - * instead of marking the valid ones during the first pass, as this would - * incurr in an allocation due to the arbitrary number of server */ - size_t addr_idx = 0; - for (size_t sl_idx = 0; sl_idx < serverlist->num_servers; ++sl_idx) { - const grpc_grpclb_server* server = serverlist->servers[sl_idx]; - if (!IsServerValid(serverlist->servers[sl_idx], sl_idx, false)) continue; - GPR_ASSERT(addr_idx < num_valid); - /* address processing */ + const grpc_grpclb_server* server = serverlist->servers[i]; + if (!IsServerValid(serverlist->servers[i], i, false)) continue; + // Address processing. grpc_resolved_address addr; ParseServer(server, &addr); - /* lb token processing */ - void* user_data; + // LB token processing. + grpc_mdelem lb_token; if (server->has_load_balance_token) { const size_t lb_token_max_length = GPR_ARRAY_SIZE(server->load_balance_token); @@ -471,9 +436,7 @@ grpc_lb_addresses* ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { strnlen(server->load_balance_token, lb_token_max_length); grpc_slice lb_token_mdstr = grpc_slice_from_copied_buffer( server->load_balance_token, lb_token_length); - user_data = - (void*)grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr) - .payload; + lb_token = grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr); } else { char* uri = grpc_sockaddr_to_uri(&addr); gpr_log(GPR_INFO, @@ -481,15 +444,18 @@ grpc_lb_addresses* ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { "be used instead", uri); gpr_free(uri); - user_data = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload; + lb_token = GRPC_MDELEM_LB_TOKEN_EMPTY; } - grpc_lb_addresses_set_address(lb_addresses, addr_idx, &addr.addr, addr.len, - false /* is_balancer */, - nullptr /* balancer_name */, user_data); - ++addr_idx; - } - GPR_ASSERT(addr_idx == num_valid); - return lb_addresses; + // Add address. + grpc_arg arg = grpc_channel_arg_pointer_create( + const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), + (void*)lb_token.payload, &lb_token_arg_vtable); + grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1); + addresses.emplace_back(addr, args); + // Clean up. + GRPC_MDELEM_UNREF(lb_token); + } + return addresses; } // @@ -498,7 +464,7 @@ grpc_lb_addresses* ProcessServerlist(const grpc_grpclb_serverlist* serverlist) { GrpcLb::BalancerCallState::BalancerCallState( RefCountedPtr<LoadBalancingPolicy> parent_grpclb_policy) - : InternallyRefCountedWithTracing<BalancerCallState>(&grpc_lb_glb_trace), + : InternallyRefCounted<BalancerCallState>(&grpc_lb_glb_trace), grpclb_policy_(std::move(parent_grpclb_policy)) { GPR_ASSERT(grpclb_policy_ != nullptr); GPR_ASSERT(!grpclb_policy()->shutting_down_); @@ -565,8 +531,7 @@ void GrpcLb::BalancerCallState::Orphan() { void GrpcLb::BalancerCallState::StartQuery() { GPR_ASSERT(lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, - "[grpclb %p] Starting LB call (lb_calld: %p, lb_call: %p)", + gpr_log(GPR_INFO, "[grpclb %p] lb_calld=%p: Starting LB call %p", grpclb_policy_.get(), this, lb_call_); } // Create the ops. @@ -710,8 +675,9 @@ void GrpcLb::BalancerCallState::SendClientLoadReportLocked() { grpc_call_error call_error = grpc_call_start_batch_and_execute( lb_call_, &op, 1, &client_load_report_closure_); if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { - gpr_log(GPR_ERROR, "[grpclb %p] call_error=%d", grpclb_policy_.get(), - call_error); + gpr_log(GPR_ERROR, + "[grpclb %p] lb_calld=%p call_error=%d sending client load report", + grpclb_policy_.get(), this, call_error); GPR_ASSERT(GRPC_CALL_OK == call_error); } } @@ -748,7 +714,7 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( void* arg, grpc_error* error) { BalancerCallState* lb_calld = static_cast<BalancerCallState*>(arg); GrpcLb* grpclb_policy = lb_calld->grpclb_policy(); - // Empty payload means the LB call was cancelled. + // Null payload means the LB call was cancelled. if (lb_calld != grpclb_policy->lb_calld_.get() || lb_calld->recv_message_payload_ == nullptr) { lb_calld->Unref(DEBUG_LOCATION, "on_message_received"); @@ -772,15 +738,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( &initial_response->client_stats_report_interval)); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; " - "client load reporting interval = %" PRId64 " milliseconds", - grpclb_policy, lb_calld->client_stats_report_interval_); + "[grpclb %p] lb_calld=%p: Received initial LB response " + "message; client load reporting interval = %" PRId64 + " milliseconds", + grpclb_policy, lb_calld, + lb_calld->client_stats_report_interval_); } } else if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Received initial LB response message; client load " - "reporting NOT enabled", - grpclb_policy); + "[grpclb %p] lb_calld=%p: Received initial LB response message; " + "client load reporting NOT enabled", + grpclb_policy, lb_calld); } grpc_grpclb_initial_response_destroy(initial_response); lb_calld->seen_initial_response_ = true; @@ -790,74 +758,67 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked( GPR_ASSERT(lb_calld->lb_call_ != nullptr); if (grpc_lb_glb_trace.enabled()) { gpr_log(GPR_INFO, - "[grpclb %p] Serverlist with %" PRIuPTR " servers received", - grpclb_policy, serverlist->num_servers); + "[grpclb %p] lb_calld=%p: Serverlist with %" PRIuPTR + " servers received", + grpclb_policy, lb_calld, serverlist->num_servers); for (size_t i = 0; i < serverlist->num_servers; ++i) { grpc_resolved_address addr; ParseServer(serverlist->servers[i], &addr); char* ipport; grpc_sockaddr_to_string(&ipport, &addr, false); - gpr_log(GPR_INFO, "[grpclb %p] Serverlist[%" PRIuPTR "]: %s", - grpclb_policy, i, ipport); + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Serverlist[%" PRIuPTR "]: %s", + grpclb_policy, lb_calld, i, ipport); gpr_free(ipport); } } - /* update serverlist */ - if (serverlist->num_servers > 0) { - // Start sending client load report only after we start using the - // serverlist returned from the current LB call. - if (lb_calld->client_stats_report_interval_ > 0 && - lb_calld->client_stats_ == nullptr) { - lb_calld->client_stats_.reset(New<GrpcLbClientStats>()); - // TODO(roth): We currently track this ref manually. Once the - // ClosureRef API is ready, we should pass the RefCountedPtr<> along - // with the callback. - auto self = lb_calld->Ref(DEBUG_LOCATION, "client_load_report"); - self.release(); - lb_calld->ScheduleNextClientLoadReportLocked(); - } - if (grpc_grpclb_serverlist_equals(grpclb_policy->serverlist_, - serverlist)) { - if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, - "[grpclb %p] Incoming server list identical to current, " - "ignoring.", - grpclb_policy); - } - grpc_grpclb_destroy_serverlist(serverlist); - } else { /* new serverlist */ - if (grpclb_policy->serverlist_ != nullptr) { - /* dispose of the old serverlist */ - grpc_grpclb_destroy_serverlist(grpclb_policy->serverlist_); - } else { - /* or dispose of the fallback */ - grpc_lb_addresses_destroy(grpclb_policy->fallback_backend_addresses_); - grpclb_policy->fallback_backend_addresses_ = nullptr; - if (grpclb_policy->fallback_timer_callback_pending_) { - grpc_timer_cancel(&grpclb_policy->lb_fallback_timer_); - } - } - // and update the copy in the GrpcLb instance. This - // serverlist instance will be destroyed either upon the next - // update or when the GrpcLb instance is destroyed. - grpclb_policy->serverlist_ = serverlist; - grpclb_policy->serverlist_index_ = 0; - grpclb_policy->CreateOrUpdateRoundRobinPolicyLocked(); - } - } else { + // Start sending client load report only after we start using the + // serverlist returned from the current LB call. + if (lb_calld->client_stats_report_interval_ > 0 && + lb_calld->client_stats_ == nullptr) { + lb_calld->client_stats_.reset(New<GrpcLbClientStats>()); + // TODO(roth): We currently track this ref manually. Once the + // ClosureRef API is ready, we should pass the RefCountedPtr<> along + // with the callback. + auto self = lb_calld->Ref(DEBUG_LOCATION, "client_load_report"); + self.release(); + lb_calld->ScheduleNextClientLoadReportLocked(); + } + // Check if the serverlist differs from the previous one. + if (grpc_grpclb_serverlist_equals(grpclb_policy->serverlist_, serverlist)) { if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, "[grpclb %p] Received empty server list, ignoring.", - grpclb_policy); + gpr_log(GPR_INFO, + "[grpclb %p] lb_calld=%p: Incoming server list identical to " + "current, ignoring.", + grpclb_policy, lb_calld); } grpc_grpclb_destroy_serverlist(serverlist); + } else { // New serverlist. + if (grpclb_policy->serverlist_ != nullptr) { + // Dispose of the old serverlist. + grpc_grpclb_destroy_serverlist(grpclb_policy->serverlist_); + } else { + // Dispose of the fallback. + grpclb_policy->fallback_backend_addresses_.reset(); + if (grpclb_policy->fallback_timer_callback_pending_) { + grpc_timer_cancel(&grpclb_policy->lb_fallback_timer_); + } + } + // Update the serverlist in the GrpcLb instance. This serverlist + // instance will be destroyed either upon the next update or when the + // GrpcLb instance is destroyed. + grpclb_policy->serverlist_ = serverlist; + grpclb_policy->serverlist_index_ = 0; + grpclb_policy->CreateOrUpdateRoundRobinPolicyLocked(); } } else { // No valid initial response or serverlist found. char* response_slice_str = grpc_dump_slice(response_slice, GPR_DUMP_ASCII | GPR_DUMP_HEX); gpr_log(GPR_ERROR, - "[grpclb %p] Invalid LB response received: '%s'. Ignoring.", - grpclb_policy, response_slice_str); + "[grpclb %p] lb_calld=%p: Invalid LB response received: '%s'. " + "Ignoring.", + grpclb_policy, lb_calld, response_slice_str); gpr_free(response_slice_str); } grpc_slice_unref_internal(response_slice); @@ -888,9 +849,9 @@ void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked( char* status_details = grpc_slice_to_c_string(lb_calld->lb_call_status_details_); gpr_log(GPR_INFO, - "[grpclb %p] Status from LB server received. Status = %d, details " - "= '%s', (lb_calld: %p, lb_call: %p), error '%s'", - grpclb_policy, lb_calld->lb_call_status_, status_details, lb_calld, + "[grpclb %p] lb_calld=%p: Status from LB server received. " + "Status = %d, details = '%s', (lb_call: %p), error '%s'", + grpclb_policy, lb_calld, lb_calld->lb_call_status_, status_details, lb_calld->lb_call_, grpc_error_string(error)); gpr_free(status_details); } @@ -919,31 +880,25 @@ void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked( // helper code for creating balancer channel // -grpc_lb_addresses* ExtractBalancerAddresses( - const grpc_lb_addresses* addresses) { - size_t num_grpclb_addrs = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) ++num_grpclb_addrs; - } - // There must be at least one balancer address, or else the - // client_channel would not have chosen this LB policy. - GPR_ASSERT(num_grpclb_addrs > 0); - grpc_lb_addresses* lb_addresses = - grpc_lb_addresses_create(num_grpclb_addrs, nullptr); - size_t lb_addresses_idx = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (!addresses->addresses[i].is_balancer) continue; - if (GPR_UNLIKELY(addresses->addresses[i].user_data != nullptr)) { - gpr_log(GPR_ERROR, - "This LB policy doesn't support user data. It will be ignored"); +ServerAddressList ExtractBalancerAddresses(const ServerAddressList& addresses) { + ServerAddressList balancer_addresses; + for (size_t i = 0; i < addresses.size(); ++i) { + if (addresses[i].IsBalancer()) { + // Strip out the is_balancer channel arg, since we don't want to + // recursively use the grpclb policy in the channel used to talk to + // the balancers. Note that we do NOT strip out the balancer_name + // channel arg, since we need that to set the authority correctly + // to talk to the balancers. + static const char* args_to_remove[] = { + GRPC_ARG_ADDRESS_IS_BALANCER, + }; + balancer_addresses.emplace_back( + addresses[i].address(), + grpc_channel_args_copy_and_remove(addresses[i].args(), args_to_remove, + GPR_ARRAY_SIZE(args_to_remove))); } - grpc_lb_addresses_set_address( - lb_addresses, lb_addresses_idx++, addresses->addresses[i].address.addr, - addresses->addresses[i].address.len, false /* is balancer */, - addresses->addresses[i].balancer_name, nullptr /* user data */); } - GPR_ASSERT(num_grpclb_addrs == lb_addresses_idx); - return lb_addresses; + return balancer_addresses; } /* Returns the channel args for the LB channel, used to create a bidirectional @@ -955,10 +910,10 @@ grpc_lb_addresses* ExtractBalancerAddresses( * above the grpclb policy. * - \a args: other args inherited from the grpclb policy. */ grpc_channel_args* BuildBalancerChannelArgs( - const grpc_lb_addresses* addresses, + const ServerAddressList& addresses, FakeResolverResponseGenerator* response_generator, const grpc_channel_args* args) { - grpc_lb_addresses* lb_addresses = ExtractBalancerAddresses(addresses); + ServerAddressList balancer_addresses = ExtractBalancerAddresses(addresses); // Channel args to remove. static const char* args_to_remove[] = { // LB policy name, since we want to use the default (pick_first) in @@ -976,7 +931,7 @@ grpc_channel_args* BuildBalancerChannelArgs( // is_balancer=true. We need the LB channel to return addresses with // is_balancer=false so that it does not wind up recursively using the // grpclb LB policy, as per the special case logic in client_channel.c. - GRPC_ARG_LB_ADDRESSES, + GRPC_ARG_SERVER_ADDRESS_LIST, // The fake resolver response generator, because we are replacing it // with the one from the grpclb policy, used to propagate updates to // the LB channel. @@ -992,10 +947,10 @@ grpc_channel_args* BuildBalancerChannelArgs( }; // Channel args to add. const grpc_arg args_to_add[] = { - // New LB addresses. + // New address list. // Note that we pass these in both when creating the LB channel // and via the fake resolver. The latter is what actually gets used. - grpc_lb_addresses_create_channel_arg(lb_addresses), + CreateServerAddressListChannelArg(&balancer_addresses), // The fake resolver response generator, which we use to inject // address updates into the LB channel. grpc_core::FakeResolverResponseGenerator::MakeChannelArg( @@ -1013,18 +968,14 @@ grpc_channel_args* BuildBalancerChannelArgs( args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), args_to_add, GPR_ARRAY_SIZE(args_to_add)); // Make any necessary modifications for security. - new_args = grpc_lb_policy_grpclb_modify_lb_channel_args(new_args); - // Clean up. - grpc_lb_addresses_destroy(lb_addresses); - return new_args; + return grpc_lb_policy_grpclb_modify_lb_channel_args(new_args); } // // ctor and dtor // -GrpcLb::GrpcLb(const grpc_lb_addresses* addresses, - const LoadBalancingPolicy::Args& args) +GrpcLb::GrpcLb(const LoadBalancingPolicy::Args& args) : LoadBalancingPolicy(args), response_generator_(MakeRefCounted<FakeResolverResponseGenerator>()), lb_call_backoff_( @@ -1081,9 +1032,6 @@ GrpcLb::~GrpcLb() { if (serverlist_ != nullptr) { grpc_grpclb_destroy_serverlist(serverlist_); } - if (fallback_backend_addresses_ != nullptr) { - grpc_lb_addresses_destroy(fallback_backend_addresses_); - } grpc_subchannel_index_unref(); } @@ -1131,7 +1079,6 @@ void GrpcLb::HandOffPendingPicksLocked(LoadBalancingPolicy* new_policy) { while ((pp = pending_picks_) != nullptr) { pending_picks_ = pp->next; pp->pick->on_complete = pp->original_on_complete; - pp->pick->user_data = nullptr; grpc_error* error = GRPC_ERROR_NONE; if (new_policy->PickLocked(pp->pick, &error)) { // Synchronous return; schedule closure. @@ -1193,7 +1140,7 @@ void GrpcLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pp != nullptr) { PendingPick* next = pp->next; - if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { // Note: pp is deleted in this callback. GRPC_CLOSURE_SCHED(&pp->on_complete, @@ -1285,9 +1232,27 @@ void GrpcLb::NotifyOnStateChangeLocked(grpc_connectivity_state* current, notify); } +// Returns the backend addresses extracted from the given addresses. +UniquePtr<ServerAddressList> ExtractBackendAddresses( + const ServerAddressList& addresses) { + void* lb_token = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload; + grpc_arg arg = grpc_channel_arg_pointer_create( + const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), lb_token, + &lb_token_arg_vtable); + auto backend_addresses = MakeUnique<ServerAddressList>(); + for (size_t i = 0; i < addresses.size(); ++i) { + if (!addresses[i].IsBalancer()) { + backend_addresses->emplace_back( + addresses[i].address(), + grpc_channel_args_copy_and_add(addresses[i].args(), &arg, 1)); + } + } + return backend_addresses; +} + void GrpcLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { - const grpc_arg* arg = grpc_channel_args_find(&args, GRPC_ARG_LB_ADDRESSES); - if (GPR_UNLIKELY(arg == nullptr || arg->type != GRPC_ARG_POINTER)) { + const ServerAddressList* addresses = FindServerAddressListChannelArg(&args); + if (addresses == nullptr) { // Ignore this update. gpr_log( GPR_ERROR, @@ -1295,13 +1260,8 @@ void GrpcLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { this); return; } - const grpc_lb_addresses* addresses = - static_cast<const grpc_lb_addresses*>(arg->value.pointer.p); // Update fallback address list. - if (fallback_backend_addresses_ != nullptr) { - grpc_lb_addresses_destroy(fallback_backend_addresses_); - } - fallback_backend_addresses_ = ExtractBackendAddresses(addresses); + fallback_backend_addresses_ = ExtractBackendAddresses(*addresses); // Make sure that GRPC_ARG_LB_POLICY_NAME is set in channel args, // since we use this to trigger the client_load_reporting filter. static const char* args_to_remove[] = {GRPC_ARG_LB_POLICY_NAME}; @@ -1312,7 +1272,7 @@ void GrpcLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { &args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), &new_arg, 1); // Construct args for balancer channel. grpc_channel_args* lb_channel_args = - BuildBalancerChannelArgs(addresses, response_generator_.get(), &args); + BuildBalancerChannelArgs(*addresses, response_generator_.get(), &args); // Create balancer channel if needed. if (lb_channel_ == nullptr) { char* uri_str; @@ -1331,7 +1291,7 @@ void GrpcLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { grpc_channel_args_destroy(lb_channel_args); } -void GrpcLb::UpdateLocked(const grpc_channel_args& args) { +void GrpcLb::UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) { ProcessChannelArgsLocked(args); // Update the existing RR policy. if (rr_policy_ != nullptr) CreateOrUpdateRoundRobinPolicyLocked(); @@ -1518,12 +1478,17 @@ void DestroyClientStats(void* arg) { } void GrpcLb::PendingPickSetMetadataAndContext(PendingPick* pp) { - /* if connected_subchannel is nullptr, no pick has been made by the RR - * policy (e.g., all addresses failed to connect). There won't be any - * user_data/token available */ + // If connected_subchannel is nullptr, no pick has been made by the RR + // policy (e.g., all addresses failed to connect). There won't be any + // LB token available. if (pp->pick->connected_subchannel != nullptr) { - if (GPR_LIKELY(!GRPC_MDISNULL(pp->lb_token))) { - AddLbTokenToInitialMetadata(GRPC_MDELEM_REF(pp->lb_token), + const grpc_arg* arg = + grpc_channel_args_find(pp->pick->connected_subchannel->args(), + GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN); + if (arg != nullptr) { + grpc_mdelem lb_token = { + reinterpret_cast<uintptr_t>(arg->value.pointer.p)}; + AddLbTokenToInitialMetadata(GRPC_MDELEM_REF(lb_token), &pp->pick->lb_token_mdelem_storage, pp->pick->initial_metadata); } else { @@ -1582,7 +1547,7 @@ void GrpcLb::AddPendingPick(PendingPick* pp) { bool GrpcLb::PickFromRoundRobinPolicyLocked(bool force_async, PendingPick* pp, grpc_error** error) { // Check for drops if we are not using fallback backend addresses. - if (serverlist_ != nullptr) { + if (serverlist_ != nullptr && serverlist_->num_servers > 0) { // Look at the index into the serverlist to see if we should drop this call. grpc_grpclb_server* server = serverlist_->servers[serverlist_index_++]; if (serverlist_index_ == serverlist_->num_servers) { @@ -1607,12 +1572,10 @@ bool GrpcLb::PickFromRoundRobinPolicyLocked(bool force_async, PendingPick* pp, return true; } } - // Set client_stats and user_data. + // Set client_stats. if (lb_calld_ != nullptr && lb_calld_->client_stats() != nullptr) { pp->client_stats = lb_calld_->client_stats()->Ref(); } - GPR_ASSERT(pp->pick->user_data == nullptr); - pp->pick->user_data = (void**)&pp->lb_token; // Pick via the RR policy. bool pick_done = rr_policy_->PickLocked(pp->pick, error); if (pick_done) { @@ -1640,6 +1603,10 @@ void GrpcLb::CreateRoundRobinPolicyLocked(const Args& args) { this); return; } + if (grpc_lb_glb_trace.enabled()) { + gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, + rr_policy_.get()); + } // TODO(roth): We currently track this ref manually. Once the new // ClosureRef API is done, pass the RefCountedPtr<> along with the closure. auto self = Ref(DEBUG_LOCATION, "on_rr_reresolution_requested"); @@ -1677,11 +1644,11 @@ void GrpcLb::CreateRoundRobinPolicyLocked(const Args& args) { } grpc_channel_args* GrpcLb::CreateRoundRobinPolicyArgsLocked() { - grpc_lb_addresses* addresses; + ServerAddressList tmp_addresses; + ServerAddressList* addresses = &tmp_addresses; bool is_backend_from_grpclb_load_balancer = false; if (serverlist_ != nullptr) { - GPR_ASSERT(serverlist_->num_servers > 0); - addresses = ProcessServerlist(serverlist_); + tmp_addresses = ProcessServerlist(serverlist_); is_backend_from_grpclb_load_balancer = true; } else { // If CreateOrUpdateRoundRobinPolicyLocked() is invoked when we haven't @@ -1690,14 +1657,14 @@ grpc_channel_args* GrpcLb::CreateRoundRobinPolicyArgsLocked() { // empty, in which case the new round_robin policy will keep the requested // picks pending. GPR_ASSERT(fallback_backend_addresses_ != nullptr); - addresses = grpc_lb_addresses_copy(fallback_backend_addresses_); + addresses = fallback_backend_addresses_.get(); } GPR_ASSERT(addresses != nullptr); - // Replace the LB addresses in the channel args that we pass down to + // Replace the server address list in the channel args that we pass down to // the subchannel. - static const char* keys_to_remove[] = {GRPC_ARG_LB_ADDRESSES}; + static const char* keys_to_remove[] = {GRPC_ARG_SERVER_ADDRESS_LIST}; grpc_arg args_to_add[3] = { - grpc_lb_addresses_create_channel_arg(addresses), + CreateServerAddressListChannelArg(addresses), // A channel arg indicating if the target is a backend inferred from a // grpclb load balancer. grpc_channel_arg_integer_create( @@ -1714,7 +1681,6 @@ grpc_channel_args* GrpcLb::CreateRoundRobinPolicyArgsLocked() { grpc_channel_args* args = grpc_channel_args_copy_and_add_and_remove( args_, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), args_to_add, num_args_to_add); - grpc_lb_addresses_destroy(addresses); return args; } @@ -1727,17 +1693,13 @@ void GrpcLb::CreateOrUpdateRoundRobinPolicyLocked() { gpr_log(GPR_INFO, "[grpclb %p] Updating RR policy %p", this, rr_policy_.get()); } - rr_policy_->UpdateLocked(*args); + rr_policy_->UpdateLocked(*args, nullptr); } else { LoadBalancingPolicy::Args lb_policy_args; lb_policy_args.combiner = combiner(); lb_policy_args.client_channel_factory = client_channel_factory(); lb_policy_args.args = args; CreateRoundRobinPolicyLocked(lb_policy_args); - if (grpc_lb_glb_trace.enabled()) { - gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this, - rr_policy_.get()); - } } grpc_channel_args_destroy(args); } @@ -1847,22 +1809,21 @@ class GrpcLbFactory : public LoadBalancingPolicyFactory { OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy( const LoadBalancingPolicy::Args& args) const override { /* Count the number of gRPC-LB addresses. There must be at least one. */ - const grpc_arg* arg = - grpc_channel_args_find(args.args, GRPC_ARG_LB_ADDRESSES); - if (arg == nullptr || arg->type != GRPC_ARG_POINTER) { - return nullptr; - } - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(arg->value.pointer.p); - size_t num_grpclb_addrs = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) ++num_grpclb_addrs; + const ServerAddressList* addresses = + FindServerAddressListChannelArg(args.args); + if (addresses == nullptr) return nullptr; + bool found_balancer = false; + for (size_t i = 0; i < addresses->size(); ++i) { + if ((*addresses)[i].IsBalancer()) { + found_balancer = true; + break; + } } - if (num_grpclb_addrs == 0) return nullptr; - return OrphanablePtr<LoadBalancingPolicy>(New<GrpcLb>(addresses, args)); + if (!found_balancer) return nullptr; + return OrphanablePtr<LoadBalancingPolicy>(New<GrpcLb>(args)); } - const char* name() const override { return "grpclb"; } + const char* name() const override { return kGrpclb; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h index 825065a9c3..3b2dc370eb 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.h @@ -21,7 +21,7 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include <grpc/impl/codegen/grpc_types.h> /// Makes any necessary modifications to \a args for use in the grpclb /// balancer channel. diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc index 441efd5e23..657ff69312 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc @@ -26,6 +26,7 @@ #include <grpc/support/string_util.h> #include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/iomgr/sockaddr_utils.h" @@ -42,22 +43,23 @@ int BalancerNameCmp(const grpc_core::UniquePtr<char>& a, } RefCountedPtr<TargetAuthorityTable> CreateTargetAuthorityTable( - grpc_lb_addresses* addresses) { + const ServerAddressList& addresses) { TargetAuthorityTable::Entry* target_authority_entries = - static_cast<TargetAuthorityTable::Entry*>(gpr_zalloc( - sizeof(*target_authority_entries) * addresses->num_addresses)); - for (size_t i = 0; i < addresses->num_addresses; ++i) { + static_cast<TargetAuthorityTable::Entry*>( + gpr_zalloc(sizeof(*target_authority_entries) * addresses.size())); + for (size_t i = 0; i < addresses.size(); ++i) { char* addr_str; - GPR_ASSERT(grpc_sockaddr_to_string( - &addr_str, &addresses->addresses[i].address, true) > 0); + GPR_ASSERT( + grpc_sockaddr_to_string(&addr_str, &addresses[i].address(), true) > 0); target_authority_entries[i].key = grpc_slice_from_copied_string(addr_str); - target_authority_entries[i].value.reset( - gpr_strdup(addresses->addresses[i].balancer_name)); gpr_free(addr_str); + char* balancer_name = grpc_channel_arg_get_string(grpc_channel_args_find( + addresses[i].args(), GRPC_ARG_ADDRESS_BALANCER_NAME)); + target_authority_entries[i].value.reset(gpr_strdup(balancer_name)); } RefCountedPtr<TargetAuthorityTable> target_authority_table = - TargetAuthorityTable::Create(addresses->num_addresses, - target_authority_entries, BalancerNameCmp); + TargetAuthorityTable::Create(addresses.size(), target_authority_entries, + BalancerNameCmp); gpr_free(target_authority_entries); return target_authority_table; } @@ -72,13 +74,12 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args( grpc_arg args_to_add[2]; size_t num_args_to_add = 0; // Add arg for targets info table. - const grpc_arg* arg = grpc_channel_args_find(args, GRPC_ARG_LB_ADDRESSES); - GPR_ASSERT(arg != nullptr); - GPR_ASSERT(arg->type == GRPC_ARG_POINTER); - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(arg->value.pointer.p); + grpc_core::ServerAddressList* addresses = + grpc_core::FindServerAddressListChannelArg(args); + GPR_ASSERT(addresses != nullptr); grpc_core::RefCountedPtr<grpc_core::TargetAuthorityTable> - target_authority_table = grpc_core::CreateTargetAuthorityTable(addresses); + target_authority_table = + grpc_core::CreateTargetAuthorityTable(*addresses); args_to_add[num_args_to_add++] = grpc_core::CreateTargetAuthorityTableChannelArg( target_authority_table.get()); @@ -87,22 +88,18 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h b/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h index 9ca7b28d8e..71d371c880 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.h @@ -25,7 +25,7 @@ #include "src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.h" #include "src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.h" -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/lib/iomgr/exec_ctx.h" #define GRPC_GRPCLB_SERVICE_NAME_MAX_LENGTH 128 diff --git a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc index eb494486b9..d6ff74ec7f 100644 --- a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc +++ b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc @@ -24,6 +24,7 @@ #include "src/core/ext/filters/client_channel/lb_policy/subchannel_list.h" #include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel.h" #include "src/core/ext/filters/client_channel/subchannel_index.h" #include "src/core/lib/channel/channel_args.h" @@ -42,11 +43,16 @@ namespace { // pick_first LB policy // +constexpr char kPickFirst[] = "pick_first"; + class PickFirst : public LoadBalancingPolicy { public: explicit PickFirst(const Args& args); - void UpdateLocked(const grpc_channel_args& args) override; + const char* name() const override { return kPickFirst; } + + void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; void CancelPickLocked(PickState* pick, grpc_error* error) override; void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, @@ -74,11 +80,9 @@ class PickFirst : public LoadBalancingPolicy { PickFirstSubchannelData( SubchannelList<PickFirstSubchannelList, PickFirstSubchannelData>* subchannel_list, - const grpc_lb_user_data_vtable* user_data_vtable, - const grpc_lb_address& address, grpc_subchannel* subchannel, + const ServerAddress& address, grpc_subchannel* subchannel, grpc_combiner* combiner) - : SubchannelData(subchannel_list, user_data_vtable, address, subchannel, - combiner) {} + : SubchannelData(subchannel_list, address, subchannel, combiner) {} void ProcessConnectivityChangeLocked( grpc_connectivity_state connectivity_state, grpc_error* error) override; @@ -94,7 +98,7 @@ class PickFirst : public LoadBalancingPolicy { PickFirstSubchannelData> { public: PickFirstSubchannelList(PickFirst* policy, TraceFlag* tracer, - const grpc_lb_addresses* addresses, + const ServerAddressList& addresses, grpc_combiner* combiner, grpc_client_channel_factory* client_channel_factory, const grpc_channel_args& args) @@ -159,7 +163,7 @@ PickFirst::PickFirst(const Args& args) : LoadBalancingPolicy(args) { if (grpc_lb_pick_first_trace.enabled()) { gpr_log(GPR_INFO, "Pick First %p created.", this); } - UpdateLocked(*args.args); + UpdateLocked(*args.args, args.lb_config); grpc_subchannel_index_ref(); } @@ -234,7 +238,7 @@ void PickFirst::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pick != nullptr) { PickState* next = pick->next; - if ((pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { GRPC_CLOSURE_SCHED(pick->on_complete, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( @@ -333,10 +337,11 @@ void PickFirst::UpdateChildRefsLocked() { child_subchannels_ = std::move(cs); } -void PickFirst::UpdateLocked(const grpc_channel_args& args) { +void PickFirst::UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) { AutoChildRefsUpdater guard(this); - const grpc_arg* arg = grpc_channel_args_find(&args, GRPC_ARG_LB_ADDRESSES); - if (arg == nullptr || arg->type != GRPC_ARG_POINTER) { + const ServerAddressList* addresses = FindServerAddressListChannelArg(&args); + if (addresses == nullptr) { if (subchannel_list_ == nullptr) { // If we don't have a current subchannel list, go into TRANSIENT FAILURE. grpc_connectivity_state_set( @@ -352,19 +357,17 @@ void PickFirst::UpdateLocked(const grpc_channel_args& args) { } return; } - const grpc_lb_addresses* addresses = - static_cast<const grpc_lb_addresses*>(arg->value.pointer.p); if (grpc_lb_pick_first_trace.enabled()) { gpr_log(GPR_INFO, "Pick First %p received update with %" PRIuPTR " addresses", this, - addresses->num_addresses); + addresses->size()); } grpc_arg new_arg = grpc_channel_arg_integer_create( const_cast<char*>(GRPC_ARG_INHIBIT_HEALTH_CHECKING), 1); grpc_channel_args* new_args = grpc_channel_args_copy_and_add(&args, &new_arg, 1); auto subchannel_list = MakeOrphanable<PickFirstSubchannelList>( - this, &grpc_lb_pick_first_trace, addresses, combiner(), + this, &grpc_lb_pick_first_trace, *addresses, combiner(), client_channel_factory(), *new_args); grpc_channel_args_destroy(new_args); if (subchannel_list->num_subchannels() == 0) { @@ -378,6 +381,31 @@ void PickFirst::UpdateLocked(const grpc_channel_args& args) { selected_ = nullptr; return; } + // If one of the subchannels in the new list is already in state + // READY, then select it immediately. This can happen when the + // currently selected subchannel is also present in the update. It + // can also happen if one of the subchannels in the update is already + // in the subchannel index because it's in use by another channel. + for (size_t i = 0; i < subchannel_list->num_subchannels(); ++i) { + PickFirstSubchannelData* sd = subchannel_list->subchannel(i); + grpc_error* error = GRPC_ERROR_NONE; + grpc_connectivity_state state = sd->CheckConnectivityStateLocked(&error); + GRPC_ERROR_UNREF(error); + if (state == GRPC_CHANNEL_READY) { + subchannel_list_ = std::move(subchannel_list); + sd->ProcessUnselectedReadyLocked(); + sd->StartConnectivityWatchLocked(); + // If there was a previously pending update (which may or may + // not have contained the currently selected subchannel), drop + // it, so that it doesn't override what we've done here. + latest_pending_subchannel_list_.reset(); + // Make sure that subsequent calls to ExitIdleLocked() don't cause + // us to start watching a subchannel other than the one we've + // selected. + started_picking_ = true; + return; + } + } if (selected_ == nullptr) { // We don't yet have a selected subchannel, so replace the current // subchannel list immediately. @@ -385,46 +413,14 @@ void PickFirst::UpdateLocked(const grpc_channel_args& args) { // If we've started picking, start trying to connect to the first // subchannel in the new list. if (started_picking_) { - subchannel_list_->subchannel(0) - ->CheckConnectivityStateAndStartWatchingLocked(); + // Note: No need to use CheckConnectivityStateAndStartWatchingLocked() + // here, since we've already checked the initial connectivity + // state of all subchannels above. + subchannel_list_->subchannel(0)->StartConnectivityWatchLocked(); } } else { - // We do have a selected subchannel. - // Check if it's present in the new list. If so, we're done. - for (size_t i = 0; i < subchannel_list->num_subchannels(); ++i) { - PickFirstSubchannelData* sd = subchannel_list->subchannel(i); - if (sd->subchannel() == selected_->subchannel()) { - // The currently selected subchannel is in the update: we are done. - if (grpc_lb_pick_first_trace.enabled()) { - gpr_log(GPR_INFO, - "Pick First %p found already selected subchannel %p " - "at update index %" PRIuPTR " of %" PRIuPTR "; update done", - this, selected_->subchannel(), i, - subchannel_list->num_subchannels()); - } - // Make sure it's in state READY. It might not be if we grabbed - // the combiner while a connectivity state notification - // informing us otherwise is pending. - // Note that CheckConnectivityStateLocked() also takes a ref to - // the connected subchannel. - grpc_error* error = GRPC_ERROR_NONE; - if (sd->CheckConnectivityStateLocked(&error) == GRPC_CHANNEL_READY) { - selected_ = sd; - subchannel_list_ = std::move(subchannel_list); - sd->StartConnectivityWatchLocked(); - // If there was a previously pending update (which may or may - // not have contained the currently selected subchannel), drop - // it, so that it doesn't override what we've done here. - latest_pending_subchannel_list_.reset(); - return; - } - GRPC_ERROR_UNREF(error); - } - } - // Not keeping the previous selected subchannel, so set the latest - // pending subchannel list to the new subchannel list. We will wait - // for it to report READY before swapping it into the current - // subchannel list. + // We do have a selected subchannel, so keep using it until one of + // the subchannels in the new list reports READY. if (latest_pending_subchannel_list_ != nullptr) { if (grpc_lb_pick_first_trace.enabled()) { gpr_log(GPR_INFO, @@ -438,8 +434,11 @@ void PickFirst::UpdateLocked(const grpc_channel_args& args) { // If we've started picking, start trying to connect to the first // subchannel in the new list. if (started_picking_) { + // Note: No need to use CheckConnectivityStateAndStartWatchingLocked() + // here, since we've already checked the initial connectivity + // state of all subchannels above. latest_pending_subchannel_list_->subchannel(0) - ->CheckConnectivityStateAndStartWatchingLocked(); + ->StartConnectivityWatchLocked(); } } } @@ -627,7 +626,7 @@ class PickFirstFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<PickFirst>(args)); } - const char* name() const override { return "pick_first"; } + const char* name() const override { return kPickFirst; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc index e9ed85cf66..3bcb33ef11 100644 --- a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc +++ b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc @@ -53,11 +53,16 @@ namespace { // round_robin LB policy // +constexpr char kRoundRobin[] = "round_robin"; + class RoundRobin : public LoadBalancingPolicy { public: explicit RoundRobin(const Args& args); - void UpdateLocked(const grpc_channel_args& args) override; + const char* name() const override { return kRoundRobin; } + + void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; void CancelPickLocked(PickState* pick, grpc_error* error) override; void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, @@ -81,8 +86,6 @@ class RoundRobin : public LoadBalancingPolicy { // Data for a particular subchannel in a subchannel list. // This subclass adds the following functionality: - // - Tracks user_data associated with each address, which will be - // returned along with picks that select the subchannel. // - Tracks the previous connectivity state of the subchannel, so that // we know how many subchannels are in each state. class RoundRobinSubchannelData @@ -92,26 +95,9 @@ class RoundRobin : public LoadBalancingPolicy { RoundRobinSubchannelData( SubchannelList<RoundRobinSubchannelList, RoundRobinSubchannelData>* subchannel_list, - const grpc_lb_user_data_vtable* user_data_vtable, - const grpc_lb_address& address, grpc_subchannel* subchannel, + const ServerAddress& address, grpc_subchannel* subchannel, grpc_combiner* combiner) - : SubchannelData(subchannel_list, user_data_vtable, address, subchannel, - combiner), - user_data_vtable_(user_data_vtable), - user_data_(user_data_vtable_ != nullptr - ? user_data_vtable_->copy(address.user_data) - : nullptr) {} - - void UnrefSubchannelLocked(const char* reason) override { - SubchannelData::UnrefSubchannelLocked(reason); - if (user_data_ != nullptr) { - GPR_ASSERT(user_data_vtable_ != nullptr); - user_data_vtable_->destroy(user_data_); - user_data_ = nullptr; - } - } - - void* user_data() const { return user_data_; } + : SubchannelData(subchannel_list, address, subchannel, combiner) {} grpc_connectivity_state connectivity_state() const { return last_connectivity_state_; @@ -124,8 +110,6 @@ class RoundRobin : public LoadBalancingPolicy { void ProcessConnectivityChangeLocked( grpc_connectivity_state connectivity_state, grpc_error* error) override; - const grpc_lb_user_data_vtable* user_data_vtable_; - void* user_data_ = nullptr; grpc_connectivity_state last_connectivity_state_ = GRPC_CHANNEL_IDLE; }; @@ -136,7 +120,7 @@ class RoundRobin : public LoadBalancingPolicy { public: RoundRobinSubchannelList( RoundRobin* policy, TraceFlag* tracer, - const grpc_lb_addresses* addresses, grpc_combiner* combiner, + const ServerAddressList& addresses, grpc_combiner* combiner, grpc_client_channel_factory* client_channel_factory, const grpc_channel_args& args) : SubchannelList(policy, tracer, addresses, combiner, @@ -232,7 +216,7 @@ RoundRobin::RoundRobin(const Args& args) : LoadBalancingPolicy(args) { gpr_mu_init(&child_refs_mu_); grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE, "round_robin"); - UpdateLocked(*args.args); + UpdateLocked(*args.args, args.lb_config); if (grpc_lb_round_robin_trace.enabled()) { gpr_log(GPR_INFO, "[RR %p] Created with %" PRIuPTR " subchannels", this, subchannel_list_->num_subchannels()); @@ -311,7 +295,7 @@ void RoundRobin::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pick != nullptr) { PickState* next = pick->next; - if ((pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { pick->connected_subchannel.reset(); GRPC_CLOSURE_SCHED(pick->on_complete, @@ -353,9 +337,6 @@ bool RoundRobin::DoPickLocked(PickState* pick) { subchannel_list_->subchannel(next_ready_index); GPR_ASSERT(sd->connected_subchannel() != nullptr); pick->connected_subchannel = sd->connected_subchannel()->Ref(); - if (pick->user_data != nullptr) { - *pick->user_data = sd->user_data(); - } if (grpc_lb_round_robin_trace.enabled()) { gpr_log(GPR_INFO, "[RR %p] Picked target <-- Subchannel %p (connected %p) (sl %p, " @@ -664,10 +645,11 @@ void RoundRobin::NotifyOnStateChangeLocked(grpc_connectivity_state* current, notify); } -void RoundRobin::UpdateLocked(const grpc_channel_args& args) { - const grpc_arg* arg = grpc_channel_args_find(&args, GRPC_ARG_LB_ADDRESSES); +void RoundRobin::UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) { AutoChildRefsUpdater guard(this); - if (GPR_UNLIKELY(arg == nullptr || arg->type != GRPC_ARG_POINTER)) { + const ServerAddressList* addresses = FindServerAddressListChannelArg(&args); + if (addresses == nullptr) { gpr_log(GPR_ERROR, "[RR %p] update provided no addresses; ignoring", this); // If we don't have a current subchannel list, go into TRANSIENT_FAILURE. // Otherwise, keep using the current subchannel list (ignore this update). @@ -679,11 +661,9 @@ void RoundRobin::UpdateLocked(const grpc_channel_args& args) { } return; } - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(arg->value.pointer.p); if (grpc_lb_round_robin_trace.enabled()) { gpr_log(GPR_INFO, "[RR %p] received update with %" PRIuPTR " addresses", - this, addresses->num_addresses); + this, addresses->size()); } // Replace latest_pending_subchannel_list_. if (latest_pending_subchannel_list_ != nullptr) { @@ -694,7 +674,7 @@ void RoundRobin::UpdateLocked(const grpc_channel_args& args) { } } latest_pending_subchannel_list_ = MakeOrphanable<RoundRobinSubchannelList>( - this, &grpc_lb_round_robin_trace, addresses, combiner(), + this, &grpc_lb_round_robin_trace, *addresses, combiner(), client_channel_factory(), args); // If we haven't started picking yet or the new list is empty, // immediately promote the new list to the current list. @@ -724,7 +704,7 @@ class RoundRobinFactory : public LoadBalancingPolicyFactory { return OrphanablePtr<LoadBalancingPolicy>(New<RoundRobin>(args)); } - const char* name() const override { return "round_robin"; } + const char* name() const override { return kRoundRobin; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h index 4ec9e935ed..6f31a643c1 100644 --- a/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h +++ b/src/core/ext/filters/client_channel/lb_policy/subchannel_list.h @@ -26,6 +26,7 @@ #include <grpc/support/alloc.h> #include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/debug/trace.h" @@ -141,8 +142,7 @@ class SubchannelData { protected: SubchannelData( SubchannelList<SubchannelListType, SubchannelDataType>* subchannel_list, - const grpc_lb_user_data_vtable* user_data_vtable, - const grpc_lb_address& address, grpc_subchannel* subchannel, + const ServerAddress& address, grpc_subchannel* subchannel, grpc_combiner* combiner); virtual ~SubchannelData(); @@ -156,9 +156,8 @@ class SubchannelData { grpc_connectivity_state connectivity_state, grpc_error* error) GRPC_ABSTRACT; - // Unrefs the subchannel. May be overridden by subclasses that need - // to perform extra cleanup when unreffing the subchannel. - virtual void UnrefSubchannelLocked(const char* reason); + // Unrefs the subchannel. + void UnrefSubchannelLocked(const char* reason); private: // Updates connected_subchannel_ based on pending_connectivity_state_unsafe_. @@ -186,8 +185,7 @@ class SubchannelData { // A list of subchannels. template <typename SubchannelListType, typename SubchannelDataType> -class SubchannelList - : public InternallyRefCountedWithTracing<SubchannelListType> { +class SubchannelList : public InternallyRefCounted<SubchannelListType> { public: typedef InlinedVector<SubchannelDataType, 10> SubchannelVector; @@ -226,15 +224,14 @@ class SubchannelList // Note: Caller must ensure that this is invoked inside of the combiner. void Orphan() override { ShutdownLocked(); - InternallyRefCountedWithTracing<SubchannelListType>::Unref(DEBUG_LOCATION, - "shutdown"); + InternallyRefCounted<SubchannelListType>::Unref(DEBUG_LOCATION, "shutdown"); } GRPC_ABSTRACT_BASE_CLASS protected: SubchannelList(LoadBalancingPolicy* policy, TraceFlag* tracer, - const grpc_lb_addresses* addresses, grpc_combiner* combiner, + const ServerAddressList& addresses, grpc_combiner* combiner, grpc_client_channel_factory* client_channel_factory, const grpc_channel_args& args); @@ -279,8 +276,7 @@ class SubchannelList template <typename SubchannelListType, typename SubchannelDataType> SubchannelData<SubchannelListType, SubchannelDataType>::SubchannelData( SubchannelList<SubchannelListType, SubchannelDataType>* subchannel_list, - const grpc_lb_user_data_vtable* user_data_vtable, - const grpc_lb_address& address, grpc_subchannel* subchannel, + const ServerAddress& address, grpc_subchannel* subchannel, grpc_combiner* combiner) : subchannel_list_(subchannel_list), subchannel_(subchannel), @@ -490,19 +486,19 @@ void SubchannelData<SubchannelListType, SubchannelDataType>::ShutdownLocked() { template <typename SubchannelListType, typename SubchannelDataType> SubchannelList<SubchannelListType, SubchannelDataType>::SubchannelList( LoadBalancingPolicy* policy, TraceFlag* tracer, - const grpc_lb_addresses* addresses, grpc_combiner* combiner, + const ServerAddressList& addresses, grpc_combiner* combiner, grpc_client_channel_factory* client_channel_factory, const grpc_channel_args& args) - : InternallyRefCountedWithTracing<SubchannelListType>(tracer), + : InternallyRefCounted<SubchannelListType>(tracer), policy_(policy), tracer_(tracer), combiner_(GRPC_COMBINER_REF(combiner, "subchannel_list")) { if (tracer_->enabled()) { gpr_log(GPR_INFO, "[%s %p] Creating subchannel list %p for %" PRIuPTR " subchannels", - tracer_->name(), policy, this, addresses->num_addresses); + tracer_->name(), policy, this, addresses.size()); } - subchannels_.reserve(addresses->num_addresses); + subchannels_.reserve(addresses.size()); // We need to remove the LB addresses in order to be able to compare the // subchannel keys of subchannels from a different batch of addresses. // We also remove the inhibit-health-checking arg, since we are @@ -510,19 +506,27 @@ SubchannelList<SubchannelListType, SubchannelDataType>::SubchannelList( inhibit_health_checking_ = grpc_channel_arg_get_bool( grpc_channel_args_find(&args, GRPC_ARG_INHIBIT_HEALTH_CHECKING), false); static const char* keys_to_remove[] = {GRPC_ARG_SUBCHANNEL_ADDRESS, - GRPC_ARG_LB_ADDRESSES, + GRPC_ARG_SERVER_ADDRESS_LIST, GRPC_ARG_INHIBIT_HEALTH_CHECKING}; // Create a subchannel for each address. grpc_subchannel_args sc_args; - for (size_t i = 0; i < addresses->num_addresses; i++) { - // If there were any balancer, we would have chosen grpclb policy instead. - GPR_ASSERT(!addresses->addresses[i].is_balancer); + for (size_t i = 0; i < addresses.size(); i++) { + // If there were any balancer addresses, we would have chosen grpclb + // policy, which does not use a SubchannelList. + GPR_ASSERT(!addresses[i].IsBalancer()); memset(&sc_args, 0, sizeof(grpc_subchannel_args)); - grpc_arg addr_arg = - grpc_create_subchannel_address_arg(&addresses->addresses[i].address); + InlinedVector<grpc_arg, 4> args_to_add; + args_to_add.emplace_back( + grpc_create_subchannel_address_arg(&addresses[i].address())); + if (addresses[i].args() != nullptr) { + for (size_t j = 0; j < addresses[i].args()->num_args; ++j) { + args_to_add.emplace_back(addresses[i].args()->args[j]); + } + } grpc_channel_args* new_args = grpc_channel_args_copy_and_add_and_remove( - &args, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), &addr_arg, 1); - gpr_free(addr_arg.value.string); + &args, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), + args_to_add.data(), args_to_add.size()); + gpr_free(args_to_add[0].value.string); sc_args.args = new_args; grpc_subchannel* subchannel = grpc_client_channel_factory_create_subchannel( client_channel_factory, &sc_args); @@ -530,8 +534,7 @@ SubchannelList<SubchannelListType, SubchannelDataType>::SubchannelList( if (subchannel == nullptr) { // Subchannel could not be created. if (tracer_->enabled()) { - char* address_uri = - grpc_sockaddr_to_uri(&addresses->addresses[i].address); + char* address_uri = grpc_sockaddr_to_uri(&addresses[i].address()); gpr_log(GPR_INFO, "[%s %p] could not create subchannel for address uri %s, " "ignoring", @@ -541,8 +544,7 @@ SubchannelList<SubchannelListType, SubchannelDataType>::SubchannelList( continue; } if (tracer_->enabled()) { - char* address_uri = - grpc_sockaddr_to_uri(&addresses->addresses[i].address); + char* address_uri = grpc_sockaddr_to_uri(&addresses[i].address()); gpr_log(GPR_INFO, "[%s %p] subchannel list %p index %" PRIuPTR ": Created subchannel %p for address uri %s", @@ -550,8 +552,7 @@ SubchannelList<SubchannelListType, SubchannelDataType>::SubchannelList( address_uri); gpr_free(address_uri); } - subchannels_.emplace_back(this, addresses->user_data_vtable, - addresses->addresses[i], subchannel, combiner); + subchannels_.emplace_back(this, addresses[i], subchannel, combiner); } } diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc index 59923ab861..8787f5bcc2 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc @@ -26,30 +26,26 @@ /// channel that uses pick_first to select from the list of balancer /// addresses. /// -/// The first time the policy gets a request for a pick, a ping, or to exit -/// the idle state, \a StartPickingLocked() is called. This method is -/// responsible for instantiating the internal *streaming* call to the LB -/// server (whichever address pick_first chose). The call will be complete -/// when either the balancer sends status or when we cancel the call (e.g., -/// because we are shutting down). In needed, we retry the call. If we -/// received at least one valid message from the server, a new call attempt -/// will be made immediately; otherwise, we apply back-off delays between -/// attempts. +/// The first time the xDS policy gets a request for a pick or to exit the idle +/// state, \a StartPickingLocked() is called. This method is responsible for +/// instantiating the internal *streaming* call to the LB server (whichever +/// address pick_first chose). The call will be complete when either the +/// balancer sends status or when we cancel the call (e.g., because we are +/// shutting down). In needed, we retry the call. If we received at least one +/// valid message from the server, a new call attempt will be made immediately; +/// otherwise, we apply back-off delays between attempts. /// -/// We maintain an internal round_robin policy instance for distributing +/// We maintain an internal child policy (round_robin) instance for distributing /// requests across backends. Whenever we receive a new serverlist from -/// the balancer, we update the round_robin policy with the new list of -/// addresses. If we cannot communicate with the balancer on startup, -/// however, we may enter fallback mode, in which case we will populate -/// the RR policy's addresses from the backend addresses returned by the -/// resolver. +/// the balancer, we update the child policy with the new list of +/// addresses. /// -/// Once an RR policy instance is in place (and getting updated as described), -/// calls for a pick, a ping, or a cancellation will be serviced right -/// away by forwarding them to the RR instance. Any time there's no RR -/// policy available (i.e., right after the creation of the gRPCLB policy), -/// pick and ping requests are added to a list of pending picks and pings -/// to be flushed and serviced when the RR policy instance becomes available. +/// Once a child policy instance is in place (and getting updated as +/// described), calls for a pick, or a cancellation will be serviced right away +/// by forwarding them to the child policy instance. Any time there's no child +/// policy available (i.e., right after the creation of the xDS policy), pick +/// requests are added to a list of pending picks to be flushed and serviced +/// when the child policy instance becomes available. /// /// \see https://github.com/grpc/grpc/blob/master/doc/load-balancing.md for the /// high level design and details. @@ -83,6 +79,7 @@ #include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/ext/filters/client_channel/subchannel_index.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channel_args.h" @@ -118,11 +115,16 @@ TraceFlag grpc_lb_xds_trace(false, "xds"); namespace { +constexpr char kXds[] = "xds_experimental"; + class XdsLb : public LoadBalancingPolicy { public: - XdsLb(const grpc_lb_addresses* addresses, const Args& args); + explicit XdsLb(const Args& args); + + const char* name() const override { return kXds; } - void UpdateLocked(const grpc_channel_args& args) override; + void UpdateLocked(const grpc_channel_args& args, + grpc_json* lb_config) override; bool PickLocked(PickState* pick, grpc_error** error) override; void CancelPickLocked(PickState* pick, grpc_error* error) override; void CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, @@ -141,10 +143,10 @@ class XdsLb : public LoadBalancingPolicy { private: /// Linked list of pending pick requests. It stores all information needed to - /// eventually call (Round Robin's) pick() on them. They mainly stay pending - /// waiting for the RR policy to be created. + /// eventually call pick() on them. They mainly stay pending waiting for the + /// child policy to be created. /// - /// Note that when a pick is sent to the RR policy, we inject our own + /// Note that when a pick is sent to the child policy, we inject our own /// on_complete callback, so that we can intercept the result before /// invoking the original on_complete callback. This allows us to set the /// LB token metadata and add client_stats to the call context. @@ -159,9 +161,6 @@ class XdsLb : public LoadBalancingPolicy { // Our on_complete closure and the original one. grpc_closure on_complete; grpc_closure* original_on_complete; - // The LB token associated with the pick. This is set via user_data in - // the pick. - grpc_mdelem lb_token; // Stats for client-side load reporting. RefCountedPtr<XdsLbClientStats> client_stats; // Next pending pick. @@ -169,8 +168,7 @@ class XdsLb : public LoadBalancingPolicy { }; /// Contains a call to the LB server and all the data related to the call. - class BalancerCallState - : public InternallyRefCountedWithTracing<BalancerCallState> { + class BalancerCallState : public InternallyRefCounted<BalancerCallState> { public: explicit BalancerCallState( RefCountedPtr<LoadBalancingPolicy> parent_xdslb_policy); @@ -202,7 +200,6 @@ class XdsLb : public LoadBalancingPolicy { static bool LoadReportCountersAreZero(xds_grpclb_request* request); static void MaybeSendClientLoadReportLocked(void* arg, grpc_error* error); - static void ClientLoadReportDoneLocked(void* arg, grpc_error* error); static void OnInitialRequestSentLocked(void* arg, grpc_error* error); static void OnBalancerMessageReceivedLocked(void* arg, grpc_error* error); static void OnBalancerStatusReceivedLocked(void* arg, grpc_error* error); @@ -261,23 +258,23 @@ class XdsLb : public LoadBalancingPolicy { grpc_error* error); // Pending pick methods. - static void PendingPickSetMetadataAndContext(PendingPick* pp); + static void PendingPickCleanup(PendingPick* pp); PendingPick* PendingPickCreate(PickState* pick); void AddPendingPick(PendingPick* pp); static void OnPendingPickComplete(void* arg, grpc_error* error); - // Methods for dealing with the RR policy. - void CreateOrUpdateRoundRobinPolicyLocked(); - grpc_channel_args* CreateRoundRobinPolicyArgsLocked(); - void CreateRoundRobinPolicyLocked(const Args& args); - bool PickFromRoundRobinPolicyLocked(bool force_async, PendingPick* pp, - grpc_error** error); - void UpdateConnectivityStateFromRoundRobinPolicyLocked( - grpc_error* rr_state_error); - static void OnRoundRobinConnectivityChangedLocked(void* arg, - grpc_error* error); - static void OnRoundRobinRequestReresolutionLocked(void* arg, - grpc_error* error); + // Methods for dealing with the child policy. + void CreateOrUpdateChildPolicyLocked(); + grpc_channel_args* CreateChildPolicyArgsLocked(); + void CreateChildPolicyLocked(const Args& args); + bool PickFromChildPolicyLocked(bool force_async, PendingPick* pp, + grpc_error** error); + void UpdateConnectivityStateFromChildPolicyLocked( + grpc_error* child_state_error); + static void OnChildPolicyConnectivityChangedLocked(void* arg, + grpc_error* error); + static void OnChildPolicyRequestReresolutionLocked(void* arg, + grpc_error* error); // Who the client is trying to communicate with. const char* server_name_ = nullptr; @@ -324,67 +321,35 @@ class XdsLb : public LoadBalancingPolicy { // 0 means not using fallback. int lb_fallback_timeout_ms_ = 0; // The backend addresses from the resolver. - grpc_lb_addresses* fallback_backend_addresses_ = nullptr; + UniquePtr<ServerAddressList> fallback_backend_addresses_; // Fallback timer. bool fallback_timer_callback_pending_ = false; grpc_timer lb_fallback_timer_; grpc_closure lb_on_fallback_; - // Pending picks that are waiting on the RR policy's connectivity. + // Pending picks that are waiting on the xDS policy's connectivity. PendingPick* pending_picks_ = nullptr; - // The RR policy to use for the backends. - OrphanablePtr<LoadBalancingPolicy> rr_policy_; - grpc_connectivity_state rr_connectivity_state_; - grpc_closure on_rr_connectivity_changed_; - grpc_closure on_rr_request_reresolution_; + // The policy to use for the backends. + OrphanablePtr<LoadBalancingPolicy> child_policy_; + grpc_connectivity_state child_connectivity_state_; + grpc_closure on_child_connectivity_changed_; + grpc_closure on_child_request_reresolution_; }; // // serverlist parsing code // -// vtable for LB tokens in grpc_lb_addresses -void* lb_token_copy(void* token) { - return token == nullptr - ? nullptr - : (void*)GRPC_MDELEM_REF(grpc_mdelem{(uintptr_t)token}).payload; -} -void lb_token_destroy(void* token) { - if (token != nullptr) { - GRPC_MDELEM_UNREF(grpc_mdelem{(uintptr_t)token}); - } -} -int lb_token_cmp(void* token1, void* token2) { - if (token1 > token2) return 1; - if (token1 < token2) return -1; - return 0; -} -const grpc_lb_user_data_vtable lb_token_vtable = { - lb_token_copy, lb_token_destroy, lb_token_cmp}; - // Returns the backend addresses extracted from the given addresses. -grpc_lb_addresses* ExtractBackendAddresses(const grpc_lb_addresses* addresses) { - // First pass: count the number of backend addresses. - size_t num_backends = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (!addresses->addresses[i].is_balancer) { - ++num_backends; +UniquePtr<ServerAddressList> ExtractBackendAddresses( + const ServerAddressList& addresses) { + auto backend_addresses = MakeUnique<ServerAddressList>(); + for (size_t i = 0; i < addresses.size(); ++i) { + if (!addresses[i].IsBalancer()) { + backend_addresses->emplace_back(addresses[i]); } } - // Second pass: actually populate the addresses and (empty) LB tokens. - grpc_lb_addresses* backend_addresses = - grpc_lb_addresses_create(num_backends, &lb_token_vtable); - size_t num_copied = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) continue; - const grpc_resolved_address* addr = &addresses->addresses[i].address; - grpc_lb_addresses_set_address(backend_addresses, num_copied, &addr->addr, - addr->len, false /* is_balancer */, - nullptr /* balancer_name */, - (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload); - ++num_copied; - } return backend_addresses; } @@ -434,56 +399,17 @@ void ParseServer(const xds_grpclb_server* server, grpc_resolved_address* addr) { } // Returns addresses extracted from \a serverlist. -grpc_lb_addresses* ProcessServerlist(const xds_grpclb_serverlist* serverlist) { - size_t num_valid = 0; - /* first pass: count how many are valid in order to allocate the necessary - * memory in a single block */ +UniquePtr<ServerAddressList> ProcessServerlist( + const xds_grpclb_serverlist* serverlist) { + auto addresses = MakeUnique<ServerAddressList>(); for (size_t i = 0; i < serverlist->num_servers; ++i) { - if (IsServerValid(serverlist->servers[i], i, true)) ++num_valid; - } - grpc_lb_addresses* lb_addresses = - grpc_lb_addresses_create(num_valid, &lb_token_vtable); - /* second pass: actually populate the addresses and LB tokens (aka user data - * to the outside world) to be read by the RR policy during its creation. - * Given that the validity tests are very cheap, they are performed again - * instead of marking the valid ones during the first pass, as this would - * incurr in an allocation due to the arbitrary number of server */ - size_t addr_idx = 0; - for (size_t sl_idx = 0; sl_idx < serverlist->num_servers; ++sl_idx) { - const xds_grpclb_server* server = serverlist->servers[sl_idx]; - if (!IsServerValid(serverlist->servers[sl_idx], sl_idx, false)) continue; - GPR_ASSERT(addr_idx < num_valid); - /* address processing */ + const xds_grpclb_server* server = serverlist->servers[i]; + if (!IsServerValid(serverlist->servers[i], i, false)) continue; grpc_resolved_address addr; ParseServer(server, &addr); - /* lb token processing */ - void* user_data; - if (server->has_load_balance_token) { - const size_t lb_token_max_length = - GPR_ARRAY_SIZE(server->load_balance_token); - const size_t lb_token_length = - strnlen(server->load_balance_token, lb_token_max_length); - grpc_slice lb_token_mdstr = grpc_slice_from_copied_buffer( - server->load_balance_token, lb_token_length); - user_data = - (void*)grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr) - .payload; - } else { - char* uri = grpc_sockaddr_to_uri(&addr); - gpr_log(GPR_INFO, - "Missing LB token for backend address '%s'. The empty token will " - "be used instead", - uri); - gpr_free(uri); - user_data = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload; - } - grpc_lb_addresses_set_address(lb_addresses, addr_idx, &addr.addr, addr.len, - false /* is_balancer */, - nullptr /* balancer_name */, user_data); - ++addr_idx; + addresses->emplace_back(addr, nullptr); } - GPR_ASSERT(addr_idx == num_valid); - return lb_addresses; + return addresses; } // @@ -492,7 +418,7 @@ grpc_lb_addresses* ProcessServerlist(const xds_grpclb_serverlist* serverlist) { XdsLb::BalancerCallState::BalancerCallState( RefCountedPtr<LoadBalancingPolicy> parent_xdslb_policy) - : InternallyRefCountedWithTracing<BalancerCallState>(&grpc_lb_xds_trace), + : InternallyRefCounted<BalancerCallState>(&grpc_lb_xds_trace), xdslb_policy_(std::move(parent_xdslb_policy)) { GPR_ASSERT(xdslb_policy_ != nullptr); GPR_ASSERT(!xdslb_policy()->shutting_down_); @@ -671,6 +597,7 @@ bool XdsLb::BalancerCallState::LoadReportCountersAreZero( (drop_entries == nullptr || drop_entries->empty()); } +// TODO(vpowar): Use LRS to send the client Load Report. void XdsLb::BalancerCallState::SendClientLoadReportLocked() { // Construct message payload. GPR_ASSERT(send_message_payload_ == nullptr); @@ -688,38 +615,8 @@ void XdsLb::BalancerCallState::SendClientLoadReportLocked() { } else { last_client_load_report_counters_were_zero_ = false; } - grpc_slice request_payload_slice = xds_grpclb_request_encode(request); - send_message_payload_ = - grpc_raw_byte_buffer_create(&request_payload_slice, 1); - grpc_slice_unref_internal(request_payload_slice); + // TODO(vpowar): Send the report on LRS stream. xds_grpclb_request_destroy(request); - // Send the report. - grpc_op op; - memset(&op, 0, sizeof(op)); - op.op = GRPC_OP_SEND_MESSAGE; - op.data.send_message.send_message = send_message_payload_; - GRPC_CLOSURE_INIT(&client_load_report_closure_, ClientLoadReportDoneLocked, - this, grpc_combiner_scheduler(xdslb_policy()->combiner())); - grpc_call_error call_error = grpc_call_start_batch_and_execute( - lb_call_, &op, 1, &client_load_report_closure_); - if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) { - gpr_log(GPR_ERROR, "[xdslb %p] call_error=%d", xdslb_policy_.get(), - call_error); - GPR_ASSERT(GRPC_CALL_OK == call_error); - } -} - -void XdsLb::BalancerCallState::ClientLoadReportDoneLocked(void* arg, - grpc_error* error) { - BalancerCallState* lb_calld = static_cast<BalancerCallState*>(arg); - XdsLb* xdslb_policy = lb_calld->xdslb_policy(); - grpc_byte_buffer_destroy(lb_calld->send_message_payload_); - lb_calld->send_message_payload_ = nullptr; - if (error != GRPC_ERROR_NONE || lb_calld != xdslb_policy->lb_calld_.get()) { - lb_calld->Unref(DEBUG_LOCATION, "client_load_report"); - return; - } - lb_calld->ScheduleNextClientLoadReportLocked(); } void XdsLb::BalancerCallState::OnInitialRequestSentLocked(void* arg, @@ -823,8 +720,7 @@ void XdsLb::BalancerCallState::OnBalancerMessageReceivedLocked( xds_grpclb_destroy_serverlist(xdslb_policy->serverlist_); } else { /* or dispose of the fallback */ - grpc_lb_addresses_destroy(xdslb_policy->fallback_backend_addresses_); - xdslb_policy->fallback_backend_addresses_ = nullptr; + xdslb_policy->fallback_backend_addresses_.reset(); if (xdslb_policy->fallback_timer_callback_pending_) { grpc_timer_cancel(&xdslb_policy->lb_fallback_timer_); } @@ -833,7 +729,7 @@ void XdsLb::BalancerCallState::OnBalancerMessageReceivedLocked( // serverlist instance will be destroyed either upon the next // update or when the XdsLb instance is destroyed. xdslb_policy->serverlist_ = serverlist; - xdslb_policy->CreateOrUpdateRoundRobinPolicyLocked(); + xdslb_policy->CreateOrUpdateChildPolicyLocked(); } } else { if (grpc_lb_xds_trace.enabled()) { @@ -866,7 +762,7 @@ void XdsLb::BalancerCallState::OnBalancerMessageReceivedLocked( &lb_calld->lb_on_balancer_message_received_); GPR_ASSERT(GRPC_CALL_OK == call_error); } else { - lb_calld->Unref(DEBUG_LOCATION, "on_message_received+grpclb_shutdown"); + lb_calld->Unref(DEBUG_LOCATION, "on_message_received+xds_shutdown"); } } @@ -910,31 +806,15 @@ void XdsLb::BalancerCallState::OnBalancerStatusReceivedLocked( // helper code for creating balancer channel // -grpc_lb_addresses* ExtractBalancerAddresses( - const grpc_lb_addresses* addresses) { - size_t num_grpclb_addrs = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) ++num_grpclb_addrs; - } - // There must be at least one balancer address, or else the - // client_channel would not have chosen this LB policy. - GPR_ASSERT(num_grpclb_addrs > 0); - grpc_lb_addresses* lb_addresses = - grpc_lb_addresses_create(num_grpclb_addrs, nullptr); - size_t lb_addresses_idx = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (!addresses->addresses[i].is_balancer) continue; - if (GPR_UNLIKELY(addresses->addresses[i].user_data != nullptr)) { - gpr_log(GPR_ERROR, - "This LB policy doesn't support user data. It will be ignored"); +UniquePtr<ServerAddressList> ExtractBalancerAddresses( + const ServerAddressList& addresses) { + auto balancer_addresses = MakeUnique<ServerAddressList>(); + for (size_t i = 0; i < addresses.size(); ++i) { + if (addresses[i].IsBalancer()) { + balancer_addresses->emplace_back(addresses[i]); } - grpc_lb_addresses_set_address( - lb_addresses, lb_addresses_idx++, addresses->addresses[i].address.addr, - addresses->addresses[i].address.len, false /* is balancer */, - addresses->addresses[i].balancer_name, nullptr /* user data */); } - GPR_ASSERT(num_grpclb_addrs == lb_addresses_idx); - return lb_addresses; + return balancer_addresses; } /* Returns the channel args for the LB channel, used to create a bidirectional @@ -944,12 +824,13 @@ grpc_lb_addresses* ExtractBalancerAddresses( * - \a addresses: corresponding to the balancers. * - \a response_generator: in order to propagate updates from the resolver * above the grpclb policy. - * - \a args: other args inherited from the grpclb policy. */ + * - \a args: other args inherited from the xds policy. */ grpc_channel_args* BuildBalancerChannelArgs( - const grpc_lb_addresses* addresses, + const ServerAddressList& addresses, FakeResolverResponseGenerator* response_generator, const grpc_channel_args* args) { - grpc_lb_addresses* lb_addresses = ExtractBalancerAddresses(addresses); + UniquePtr<ServerAddressList> balancer_addresses = + ExtractBalancerAddresses(addresses); // Channel args to remove. static const char* args_to_remove[] = { // LB policy name, since we want to use the default (pick_first) in @@ -966,10 +847,10 @@ grpc_channel_args* BuildBalancerChannelArgs( // resolver will have is_balancer=false, whereas our own addresses have // is_balancer=true. We need the LB channel to return addresses with // is_balancer=false so that it does not wind up recursively using the - // grpclb LB policy, as per the special case logic in client_channel.c. - GRPC_ARG_LB_ADDRESSES, + // xds LB policy, as per the special case logic in client_channel.c. + GRPC_ARG_SERVER_ADDRESS_LIST, // The fake resolver response generator, because we are replacing it - // with the one from the grpclb policy, used to propagate updates to + // with the one from the xds policy, used to propagate updates to // the LB channel. GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR, // The LB channel should use the authority indicated by the target @@ -983,15 +864,15 @@ grpc_channel_args* BuildBalancerChannelArgs( }; // Channel args to add. const grpc_arg args_to_add[] = { - // New LB addresses. + // New server address list. // Note that we pass these in both when creating the LB channel // and via the fake resolver. The latter is what actually gets used. - grpc_lb_addresses_create_channel_arg(lb_addresses), + CreateServerAddressListChannelArg(balancer_addresses.get()), // The fake resolver response generator, which we use to inject // address updates into the LB channel. grpc_core::FakeResolverResponseGenerator::MakeChannelArg( response_generator), - // A channel arg indicating the target is a grpclb load balancer. + // A channel arg indicating the target is a xds load balancer. grpc_channel_arg_integer_create( const_cast<char*>(GRPC_ARG_ADDRESS_IS_XDS_LOAD_BALANCER), 1), // A channel arg indicating this is an internal channels, aka it is @@ -1004,18 +885,15 @@ grpc_channel_args* BuildBalancerChannelArgs( args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), args_to_add, GPR_ARRAY_SIZE(args_to_add)); // Make any necessary modifications for security. - new_args = grpc_lb_policy_xds_modify_lb_channel_args(new_args); - // Clean up. - grpc_lb_addresses_destroy(lb_addresses); - return new_args; + return grpc_lb_policy_xds_modify_lb_channel_args(new_args); } // // ctor and dtor // -XdsLb::XdsLb(const grpc_lb_addresses* addresses, - const LoadBalancingPolicy::Args& args) +// TODO(vishalpowar): Use lb_config in args to configure LB policy. +XdsLb::XdsLb(const LoadBalancingPolicy::Args& args) : LoadBalancingPolicy(args), response_generator_(MakeRefCounted<FakeResolverResponseGenerator>()), lb_call_backoff_( @@ -1031,11 +909,11 @@ XdsLb::XdsLb(const grpc_lb_addresses* addresses, GRPC_CLOSURE_INIT(&lb_channel_on_connectivity_changed_, &XdsLb::OnBalancerChannelConnectivityChangedLocked, this, grpc_combiner_scheduler(args.combiner)); - GRPC_CLOSURE_INIT(&on_rr_connectivity_changed_, - &XdsLb::OnRoundRobinConnectivityChangedLocked, this, + GRPC_CLOSURE_INIT(&on_child_connectivity_changed_, + &XdsLb::OnChildPolicyConnectivityChangedLocked, this, grpc_combiner_scheduler(args.combiner)); - GRPC_CLOSURE_INIT(&on_rr_request_reresolution_, - &XdsLb::OnRoundRobinRequestReresolutionLocked, this, + GRPC_CLOSURE_INIT(&on_child_request_reresolution_, + &XdsLb::OnChildPolicyRequestReresolutionLocked, this, grpc_combiner_scheduler(args.combiner)); grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE, "xds"); // Record server name. @@ -1071,9 +949,6 @@ XdsLb::~XdsLb() { if (serverlist_ != nullptr) { xds_grpclb_destroy_serverlist(serverlist_); } - if (fallback_backend_addresses_ != nullptr) { - grpc_lb_addresses_destroy(fallback_backend_addresses_); - } grpc_subchannel_index_unref(); } @@ -1087,7 +962,7 @@ void XdsLb::ShutdownLocked() { if (fallback_timer_callback_pending_) { grpc_timer_cancel(&lb_fallback_timer_); } - rr_policy_.reset(); + child_policy_.reset(); TryReresolutionLocked(&grpc_lb_xds_trace, GRPC_ERROR_CANCELLED); // We destroy the LB channel here instead of in our destructor because // destroying the channel triggers a last callback to @@ -1100,7 +975,7 @@ void XdsLb::ShutdownLocked() { gpr_mu_unlock(&lb_channel_mu_); } grpc_connectivity_state_set(&state_tracker_, GRPC_CHANNEL_SHUTDOWN, - GRPC_ERROR_REF(error), "grpclb_shutdown"); + GRPC_ERROR_REF(error), "xds_shutdown"); // Clear pending picks. PendingPick* pp; while ((pp = pending_picks_) != nullptr) { @@ -1121,7 +996,6 @@ void XdsLb::HandOffPendingPicksLocked(LoadBalancingPolicy* new_policy) { while ((pp = pending_picks_) != nullptr) { pending_picks_ = pp->next; pp->pick->on_complete = pp->original_on_complete; - pp->pick->user_data = nullptr; grpc_error* error = GRPC_ERROR_NONE; if (new_policy->PickLocked(pp->pick, &error)) { // Synchronous return; schedule closure. @@ -1133,13 +1007,13 @@ void XdsLb::HandOffPendingPicksLocked(LoadBalancingPolicy* new_policy) { // Cancel a specific pending pick. // -// A grpclb pick progresses as follows: -// - If there's a Round Robin policy (rr_policy_) available, it'll be -// handed over to the RR policy (in CreateRoundRobinPolicyLocked()). From -// that point onwards, it'll be RR's responsibility. For cancellations, that -// implies the pick needs also be cancelled by the RR instance. -// - Otherwise, without an RR instance, picks stay pending at this policy's -// level (grpclb), inside the pending_picks_ list. To cancel these, +// A pick progresses as follows: +// - If there's a child policy available, it'll be handed over to child policy +// (in CreateChildPolicyLocked()). From that point onwards, it'll be the +// child policy's responsibility. For cancellations, that implies the pick +// needs to be also cancelled by the child policy instance. +// - Otherwise, without a child policy instance, picks stay pending at this +// policy's level (xds), inside the pending_picks_ list. To cancel these, // we invoke the completion closure and set the pick's connected // subchannel to nullptr right here. void XdsLb::CancelPickLocked(PickState* pick, grpc_error* error) { @@ -1159,21 +1033,21 @@ void XdsLb::CancelPickLocked(PickState* pick, grpc_error* error) { } pp = next; } - if (rr_policy_ != nullptr) { - rr_policy_->CancelPickLocked(pick, GRPC_ERROR_REF(error)); + if (child_policy_ != nullptr) { + child_policy_->CancelPickLocked(pick, GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); } // Cancel all pending picks. // -// A grpclb pick progresses as follows: -// - If there's a Round Robin policy (rr_policy_) available, it'll be -// handed over to the RR policy (in CreateRoundRobinPolicyLocked()). From -// that point onwards, it'll be RR's responsibility. For cancellations, that -// implies the pick needs also be cancelled by the RR instance. -// - Otherwise, without an RR instance, picks stay pending at this policy's -// level (grpclb), inside the pending_picks_ list. To cancel these, +// A pick progresses as follows: +// - If there's a child policy available, it'll be handed over to child policy +// (in CreateChildPolicyLocked()). From that point onwards, it'll be the +// child policy's responsibility. For cancellations, that implies the pick +// needs to be also cancelled by the child policy instance. +// - Otherwise, without a child policy instance, picks stay pending at this +// policy's level (xds), inside the pending_picks_ list. To cancel these, // we invoke the completion closure and set the pick's connected // subchannel to nullptr right here. void XdsLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, @@ -1183,7 +1057,7 @@ void XdsLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, pending_picks_ = nullptr; while (pp != nullptr) { PendingPick* next = pp->next; - if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == + if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) == initial_metadata_flags_eq) { // Note: pp is deleted in this callback. GRPC_CLOSURE_SCHED(&pp->on_complete, @@ -1195,10 +1069,10 @@ void XdsLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask, } pp = next; } - if (rr_policy_ != nullptr) { - rr_policy_->CancelMatchingPicksLocked(initial_metadata_flags_mask, - initial_metadata_flags_eq, - GRPC_ERROR_REF(error)); + if (child_policy_ != nullptr) { + child_policy_->CancelMatchingPicksLocked(initial_metadata_flags_mask, + initial_metadata_flags_eq, + GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); } @@ -1213,22 +1087,21 @@ void XdsLb::ResetBackoffLocked() { if (lb_channel_ != nullptr) { grpc_channel_reset_connect_backoff(lb_channel_); } - if (rr_policy_ != nullptr) { - rr_policy_->ResetBackoffLocked(); + if (child_policy_ != nullptr) { + child_policy_->ResetBackoffLocked(); } } bool XdsLb::PickLocked(PickState* pick, grpc_error** error) { PendingPick* pp = PendingPickCreate(pick); bool pick_done = false; - if (rr_policy_ != nullptr) { + if (child_policy_ != nullptr) { if (grpc_lb_xds_trace.enabled()) { - gpr_log(GPR_INFO, "[xdslb %p] about to PICK from RR %p", this, - rr_policy_.get()); + gpr_log(GPR_INFO, "[xdslb %p] about to PICK from policy %p", this, + child_policy_.get()); } - pick_done = - PickFromRoundRobinPolicyLocked(false /* force_async */, pp, error); - } else { // rr_policy_ == NULL + pick_done = PickFromChildPolicyLocked(false /* force_async */, pp, error); + } else { // child_policy_ == NULL if (pick->on_complete == nullptr) { *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( "No pick result available but synchronous result required."); @@ -1236,7 +1109,7 @@ bool XdsLb::PickLocked(PickState* pick, grpc_error** error) { } else { if (grpc_lb_xds_trace.enabled()) { gpr_log(GPR_INFO, - "[xdslb %p] No RR policy. Adding to grpclb's pending picks", + "[xdslb %p] No child policy. Adding to xds's pending picks", this); } AddPendingPick(pp); @@ -1251,8 +1124,8 @@ bool XdsLb::PickLocked(PickState* pick, grpc_error** error) { void XdsLb::FillChildRefsForChannelz(channelz::ChildRefsList* child_subchannels, channelz::ChildRefsList* child_channels) { - // delegate to the RoundRobin to fill the children subchannels. - rr_policy_->FillChildRefsForChannelz(child_subchannels, child_channels); + // delegate to the child_policy_ to fill the children subchannels. + child_policy_->FillChildRefsForChannelz(child_subchannels, child_channels); MutexLock lock(&lb_channel_mu_); if (lb_channel_ != nullptr) { grpc_core::channelz::ChannelNode* channel_node = @@ -1275,21 +1148,16 @@ void XdsLb::NotifyOnStateChangeLocked(grpc_connectivity_state* current, } void XdsLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { - const grpc_arg* arg = grpc_channel_args_find(&args, GRPC_ARG_LB_ADDRESSES); - if (GPR_UNLIKELY(arg == nullptr || arg->type != GRPC_ARG_POINTER)) { + const ServerAddressList* addresses = FindServerAddressListChannelArg(&args); + if (addresses == nullptr) { // Ignore this update. gpr_log(GPR_ERROR, "[xdslb %p] No valid LB addresses channel arg in update, ignoring.", this); return; } - const grpc_lb_addresses* addresses = - static_cast<const grpc_lb_addresses*>(arg->value.pointer.p); // Update fallback address list. - if (fallback_backend_addresses_ != nullptr) { - grpc_lb_addresses_destroy(fallback_backend_addresses_); - } - fallback_backend_addresses_ = ExtractBackendAddresses(addresses); + fallback_backend_addresses_ = ExtractBackendAddresses(*addresses); // Make sure that GRPC_ARG_LB_POLICY_NAME is set in channel args, // since we use this to trigger the client_load_reporting filter. static const char* args_to_remove[] = {GRPC_ARG_LB_POLICY_NAME}; @@ -1300,7 +1168,7 @@ void XdsLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { &args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), &new_arg, 1); // Construct args for balancer channel. grpc_channel_args* lb_channel_args = - BuildBalancerChannelArgs(addresses, response_generator_.get(), &args); + BuildBalancerChannelArgs(*addresses, response_generator_.get(), &args); // Create balancer channel if needed. if (lb_channel_ == nullptr) { char* uri_str; @@ -1319,12 +1187,15 @@ void XdsLb::ProcessChannelArgsLocked(const grpc_channel_args& args) { grpc_channel_args_destroy(lb_channel_args); } -void XdsLb::UpdateLocked(const grpc_channel_args& args) { +// TODO(vishalpowar): Use lb_config to configure LB policy. +void XdsLb::UpdateLocked(const grpc_channel_args& args, grpc_json* lb_config) { ProcessChannelArgsLocked(args); - // Note: We have disabled fallback mode in the code, so we don't need to - // handle fallback address changes. + // Update the existing child policy. + // Note: We have disabled fallback mode in the code, so this child policy must + // have been created from a serverlist. // TODO(vpowar): Handle the fallback_address changes when we add support for // fallback in xDS. + if (child_policy_ != nullptr) CreateOrUpdateChildPolicyLocked(); // Start watching the LB channel connectivity for connection, if not // already doing so. if (!watching_lb_channel_) { @@ -1445,8 +1316,8 @@ void XdsLb::OnBalancerChannelConnectivityChangedLocked(void* arg, XdsLb* xdslb_policy = static_cast<XdsLb*>(arg); if (xdslb_policy->shutting_down_) goto done; // Re-initialize the lb_call. This should also take care of updating the - // embedded RR policy. Note that the current RR policy, if any, will stay in - // effect until an update from the new lb_call is received. + // child policy. Note that the current child policy, if any, will + // stay in effect until an update from the new lb_call is received. switch (xdslb_policy->lb_channel_connectivity_) { case GRPC_CHANNEL_CONNECTING: case GRPC_CHANNEL_TRANSIENT_FAILURE: { @@ -1488,37 +1359,15 @@ void XdsLb::OnBalancerChannelConnectivityChangedLocked(void* arg, // PendingPick // -// Adds lb_token of selected subchannel (address) to the call's initial -// metadata. -grpc_error* AddLbTokenToInitialMetadata( - grpc_mdelem lb_token, grpc_linked_mdelem* lb_token_mdelem_storage, - grpc_metadata_batch* initial_metadata) { - GPR_ASSERT(lb_token_mdelem_storage != nullptr); - GPR_ASSERT(!GRPC_MDISNULL(lb_token)); - return grpc_metadata_batch_add_tail(initial_metadata, lb_token_mdelem_storage, - lb_token); -} - // Destroy function used when embedding client stats in call context. void DestroyClientStats(void* arg) { static_cast<XdsLbClientStats*>(arg)->Unref(); } -void XdsLb::PendingPickSetMetadataAndContext(PendingPick* pp) { - /* if connected_subchannel is nullptr, no pick has been made by the RR - * policy (e.g., all addresses failed to connect). There won't be any - * user_data/token available */ +void XdsLb::PendingPickCleanup(PendingPick* pp) { + // If connected_subchannel is nullptr, no pick has been made by the + // child policy (e.g., all addresses failed to connect). if (pp->pick->connected_subchannel != nullptr) { - if (GPR_LIKELY(!GRPC_MDISNULL(pp->lb_token))) { - AddLbTokenToInitialMetadata(GRPC_MDELEM_REF(pp->lb_token), - &pp->pick->lb_token_mdelem_storage, - pp->pick->initial_metadata); - } else { - gpr_log(GPR_ERROR, - "[xdslb %p] No LB token for connected subchannel pick %p", - pp->xdslb_policy, pp->pick); - abort(); - } // Pass on client stats via context. Passes ownership of the reference. if (pp->client_stats != nullptr) { pp->pick->subchannel_call_context[GRPC_GRPCLB_CLIENT_STATS].value = @@ -1532,11 +1381,11 @@ void XdsLb::PendingPickSetMetadataAndContext(PendingPick* pp) { } /* The \a on_complete closure passed as part of the pick requires keeping a - * reference to its associated round robin instance. We wrap this closure in - * order to unref the round robin instance upon its invocation */ + * reference to its associated child policy instance. We wrap this closure in + * order to unref the child policy instance upon its invocation */ void XdsLb::OnPendingPickComplete(void* arg, grpc_error* error) { PendingPick* pp = static_cast<PendingPick*>(arg); - PendingPickSetMetadataAndContext(pp); + PendingPickCleanup(pp); GRPC_CLOSURE_SCHED(pp->original_on_complete, GRPC_ERROR_REF(error)); Delete(pp); } @@ -1558,26 +1407,24 @@ void XdsLb::AddPendingPick(PendingPick* pp) { } // -// code for interacting with the RR policy +// code for interacting with the child policy // -// Performs a pick over \a rr_policy_. Given that a pick can return +// Performs a pick over \a child_policy_. Given that a pick can return // immediately (ignoring its completion callback), we need to perform the // cleanups this callback would otherwise be responsible for. // If \a force_async is true, then we will manually schedule the // completion callback even if the pick is available immediately. -bool XdsLb::PickFromRoundRobinPolicyLocked(bool force_async, PendingPick* pp, - grpc_error** error) { - // Set client_stats and user_data. +bool XdsLb::PickFromChildPolicyLocked(bool force_async, PendingPick* pp, + grpc_error** error) { + // Set client_stats. if (lb_calld_ != nullptr && lb_calld_->client_stats() != nullptr) { pp->client_stats = lb_calld_->client_stats()->Ref(); } - GPR_ASSERT(pp->pick->user_data == nullptr); - pp->pick->user_data = (void**)&pp->lb_token; - // Pick via the RR policy. - bool pick_done = rr_policy_->PickLocked(pp->pick, error); + // Pick via the child policy. + bool pick_done = child_policy_->PickLocked(pp->pick, error); if (pick_done) { - PendingPickSetMetadataAndContext(pp); + PendingPickCleanup(pp); if (force_async) { GRPC_CLOSURE_SCHED(pp->original_on_complete, *error); *error = GRPC_ERROR_NONE; @@ -1586,71 +1433,72 @@ bool XdsLb::PickFromRoundRobinPolicyLocked(bool force_async, PendingPick* pp, Delete(pp); } // else, the pending pick will be registered and taken care of by the - // pending pick list inside the RR policy. Eventually, + // pending pick list inside the child policy. Eventually, // OnPendingPickComplete() will be called, which will (among other // things) add the LB token to the call's initial metadata. return pick_done; } -void XdsLb::CreateRoundRobinPolicyLocked(const Args& args) { - GPR_ASSERT(rr_policy_ == nullptr); - rr_policy_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( +void XdsLb::CreateChildPolicyLocked(const Args& args) { + GPR_ASSERT(child_policy_ == nullptr); + child_policy_ = LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( "round_robin", args); - if (GPR_UNLIKELY(rr_policy_ == nullptr)) { - gpr_log(GPR_ERROR, "[xdslb %p] Failure creating a RoundRobin policy", this); + if (GPR_UNLIKELY(child_policy_ == nullptr)) { + gpr_log(GPR_ERROR, "[xdslb %p] Failure creating a child policy", this); return; } // TODO(roth): We currently track this ref manually. Once the new // ClosureRef API is done, pass the RefCountedPtr<> along with the closure. - auto self = Ref(DEBUG_LOCATION, "on_rr_reresolution_requested"); + auto self = Ref(DEBUG_LOCATION, "on_child_reresolution_requested"); self.release(); - rr_policy_->SetReresolutionClosureLocked(&on_rr_request_reresolution_); - grpc_error* rr_state_error = nullptr; - rr_connectivity_state_ = rr_policy_->CheckConnectivityLocked(&rr_state_error); - // Connectivity state is a function of the RR policy updated/created. - UpdateConnectivityStateFromRoundRobinPolicyLocked(rr_state_error); - // Add the gRPC LB's interested_parties pollset_set to that of the newly - // created RR policy. This will make the RR policy progress upon activity on - // gRPC LB, which in turn is tied to the application's call. - grpc_pollset_set_add_pollset_set(rr_policy_->interested_parties(), + child_policy_->SetReresolutionClosureLocked(&on_child_request_reresolution_); + grpc_error* child_state_error = nullptr; + child_connectivity_state_ = + child_policy_->CheckConnectivityLocked(&child_state_error); + // Connectivity state is a function of the child policy updated/created. + UpdateConnectivityStateFromChildPolicyLocked(child_state_error); + // Add the xDS's interested_parties pollset_set to that of the newly created + // child policy. This will make the child policy progress upon activity on + // xDS LB, which in turn is tied to the application's call. + grpc_pollset_set_add_pollset_set(child_policy_->interested_parties(), interested_parties()); - // Subscribe to changes to the connectivity of the new RR. + // Subscribe to changes to the connectivity of the new child policy. // TODO(roth): We currently track this ref manually. Once the new // ClosureRef API is done, pass the RefCountedPtr<> along with the closure. - self = Ref(DEBUG_LOCATION, "on_rr_connectivity_changed"); + self = Ref(DEBUG_LOCATION, "on_child_connectivity_changed"); self.release(); - rr_policy_->NotifyOnStateChangeLocked(&rr_connectivity_state_, - &on_rr_connectivity_changed_); - rr_policy_->ExitIdleLocked(); - // Send pending picks to RR policy. + child_policy_->NotifyOnStateChangeLocked(&child_connectivity_state_, + &on_child_connectivity_changed_); + child_policy_->ExitIdleLocked(); + // Send pending picks to child policy. PendingPick* pp; while ((pp = pending_picks_)) { pending_picks_ = pp->next; if (grpc_lb_xds_trace.enabled()) { - gpr_log(GPR_INFO, - "[xdslb %p] Pending pick about to (async) PICK from RR %p", this, - rr_policy_.get()); + gpr_log( + GPR_INFO, + "[xdslb %p] Pending pick about to (async) PICK from child policy %p", + this, child_policy_.get()); } grpc_error* error = GRPC_ERROR_NONE; - PickFromRoundRobinPolicyLocked(true /* force_async */, pp, &error); + PickFromChildPolicyLocked(true /* force_async */, pp, &error); } } -grpc_channel_args* XdsLb::CreateRoundRobinPolicyArgsLocked() { - grpc_lb_addresses* addresses; +grpc_channel_args* XdsLb::CreateChildPolicyArgsLocked() { bool is_backend_from_grpclb_load_balancer = false; // This should never be invoked if we do not have serverlist_, as fallback // mode is disabled for xDS plugin. GPR_ASSERT(serverlist_ != nullptr); GPR_ASSERT(serverlist_->num_servers > 0); - addresses = ProcessServerlist(serverlist_); - is_backend_from_grpclb_load_balancer = true; + UniquePtr<ServerAddressList> addresses = ProcessServerlist(serverlist_); GPR_ASSERT(addresses != nullptr); - // Replace the LB addresses in the channel args that we pass down to + is_backend_from_grpclb_load_balancer = true; + // Replace the server address list in the channel args that we pass down to // the subchannel. - static const char* keys_to_remove[] = {GRPC_ARG_LB_ADDRESSES}; + static const char* keys_to_remove[] = {GRPC_ARG_SERVER_ADDRESS_LIST}; const grpc_arg args_to_add[] = { - grpc_lb_addresses_create_channel_arg(addresses), + CreateServerAddressListChannelArg(addresses.get()), // A channel arg indicating if the target is a backend inferred from a // grpclb load balancer. grpc_channel_arg_integer_create( @@ -1660,70 +1508,71 @@ grpc_channel_args* XdsLb::CreateRoundRobinPolicyArgsLocked() { grpc_channel_args* args = grpc_channel_args_copy_and_add_and_remove( args_, keys_to_remove, GPR_ARRAY_SIZE(keys_to_remove), args_to_add, GPR_ARRAY_SIZE(args_to_add)); - grpc_lb_addresses_destroy(addresses); return args; } -void XdsLb::CreateOrUpdateRoundRobinPolicyLocked() { +void XdsLb::CreateOrUpdateChildPolicyLocked() { if (shutting_down_) return; - grpc_channel_args* args = CreateRoundRobinPolicyArgsLocked(); + grpc_channel_args* args = CreateChildPolicyArgsLocked(); GPR_ASSERT(args != nullptr); - if (rr_policy_ != nullptr) { + if (child_policy_ != nullptr) { if (grpc_lb_xds_trace.enabled()) { - gpr_log(GPR_INFO, "[xdslb %p] Updating RR policy %p", this, - rr_policy_.get()); + gpr_log(GPR_INFO, "[xdslb %p] Updating the child policy %p", this, + child_policy_.get()); } - rr_policy_->UpdateLocked(*args); + // TODO(vishalpowar): Pass the correct LB config. + child_policy_->UpdateLocked(*args, nullptr); } else { LoadBalancingPolicy::Args lb_policy_args; lb_policy_args.combiner = combiner(); lb_policy_args.client_channel_factory = client_channel_factory(); lb_policy_args.args = args; - CreateRoundRobinPolicyLocked(lb_policy_args); + CreateChildPolicyLocked(lb_policy_args); if (grpc_lb_xds_trace.enabled()) { - gpr_log(GPR_INFO, "[xdslb %p] Created new RR policy %p", this, - rr_policy_.get()); + gpr_log(GPR_INFO, "[xdslb %p] Created a new child policy %p", this, + child_policy_.get()); } } grpc_channel_args_destroy(args); } -void XdsLb::OnRoundRobinRequestReresolutionLocked(void* arg, - grpc_error* error) { +void XdsLb::OnChildPolicyRequestReresolutionLocked(void* arg, + grpc_error* error) { XdsLb* xdslb_policy = static_cast<XdsLb*>(arg); if (xdslb_policy->shutting_down_ || error != GRPC_ERROR_NONE) { - xdslb_policy->Unref(DEBUG_LOCATION, "on_rr_reresolution_requested"); + xdslb_policy->Unref(DEBUG_LOCATION, "on_child_reresolution_requested"); return; } if (grpc_lb_xds_trace.enabled()) { - gpr_log( - GPR_INFO, - "[xdslb %p] Re-resolution requested from the internal RR policy (%p).", - xdslb_policy, xdslb_policy->rr_policy_.get()); + gpr_log(GPR_INFO, + "[xdslb %p] Re-resolution requested from child policy " + "(%p).", + xdslb_policy, xdslb_policy->child_policy_.get()); } // If we are talking to a balancer, we expect to get updated addresses form - // the balancer, so we can ignore the re-resolution request from the RR - // policy. Otherwise, handle the re-resolution request using the - // grpclb policy's original re-resolution closure. + // the balancer, so we can ignore the re-resolution request from the child + // policy. + // Otherwise, handle the re-resolution request using the xds policy's + // original re-resolution closure. if (xdslb_policy->lb_calld_ == nullptr || !xdslb_policy->lb_calld_->seen_initial_response()) { xdslb_policy->TryReresolutionLocked(&grpc_lb_xds_trace, GRPC_ERROR_NONE); } - // Give back the wrapper closure to the RR policy. - xdslb_policy->rr_policy_->SetReresolutionClosureLocked( - &xdslb_policy->on_rr_request_reresolution_); + // Give back the wrapper closure to the child policy. + xdslb_policy->child_policy_->SetReresolutionClosureLocked( + &xdslb_policy->on_child_request_reresolution_); } -void XdsLb::UpdateConnectivityStateFromRoundRobinPolicyLocked( - grpc_error* rr_state_error) { +void XdsLb::UpdateConnectivityStateFromChildPolicyLocked( + grpc_error* child_state_error) { const grpc_connectivity_state curr_glb_state = grpc_connectivity_state_check(&state_tracker_); /* The new connectivity status is a function of the previous one and the new - * input coming from the status of the RR policy. + * input coming from the status of the child policy. * - * current state (grpclb's) + * current state (xds's) * | - * v || I | C | R | TF | SD | <- new state (RR's) + * v || I | C | R | TF | SD | <- new state (child policy's) * ===++====+=====+=====+======+======+ * I || I | C | R | [I] | [I] | * ---++----+-----+-----+------+------+ @@ -1736,52 +1585,51 @@ void XdsLb::UpdateConnectivityStateFromRoundRobinPolicyLocked( * SD || NA | NA | NA | NA | NA | (*) * ---++----+-----+-----+------+------+ * - * A [STATE] indicates that the old RR policy is kept. In those cases, STATE - * is the current state of grpclb, which is left untouched. + * A [STATE] indicates that the old child policy is kept. In those cases, + * STATE is the current state of xds, which is left untouched. * * In summary, if the new state is TRANSIENT_FAILURE or SHUTDOWN, stick to - * the previous RR instance. + * the previous child policy instance. * * Note that the status is never updated to SHUTDOWN as a result of calling * this function. Only glb_shutdown() has the power to set that state. * * (*) This function mustn't be called during shutting down. */ GPR_ASSERT(curr_glb_state != GRPC_CHANNEL_SHUTDOWN); - switch (rr_connectivity_state_) { + switch (child_connectivity_state_) { case GRPC_CHANNEL_TRANSIENT_FAILURE: case GRPC_CHANNEL_SHUTDOWN: - GPR_ASSERT(rr_state_error != GRPC_ERROR_NONE); + GPR_ASSERT(child_state_error != GRPC_ERROR_NONE); break; case GRPC_CHANNEL_IDLE: case GRPC_CHANNEL_CONNECTING: case GRPC_CHANNEL_READY: - GPR_ASSERT(rr_state_error == GRPC_ERROR_NONE); + GPR_ASSERT(child_state_error == GRPC_ERROR_NONE); } if (grpc_lb_xds_trace.enabled()) { - gpr_log( - GPR_INFO, - "[xdslb %p] Setting grpclb's state to %s from new RR policy %p state.", - this, grpc_connectivity_state_name(rr_connectivity_state_), - rr_policy_.get()); + gpr_log(GPR_INFO, + "[xdslb %p] Setting xds's state to %s from child policy %p state.", + this, grpc_connectivity_state_name(child_connectivity_state_), + child_policy_.get()); } - grpc_connectivity_state_set(&state_tracker_, rr_connectivity_state_, - rr_state_error, + grpc_connectivity_state_set(&state_tracker_, child_connectivity_state_, + child_state_error, "update_lb_connectivity_status_locked"); } -void XdsLb::OnRoundRobinConnectivityChangedLocked(void* arg, - grpc_error* error) { +void XdsLb::OnChildPolicyConnectivityChangedLocked(void* arg, + grpc_error* error) { XdsLb* xdslb_policy = static_cast<XdsLb*>(arg); if (xdslb_policy->shutting_down_) { - xdslb_policy->Unref(DEBUG_LOCATION, "on_rr_connectivity_changed"); + xdslb_policy->Unref(DEBUG_LOCATION, "on_child_connectivity_changed"); return; } - xdslb_policy->UpdateConnectivityStateFromRoundRobinPolicyLocked( + xdslb_policy->UpdateConnectivityStateFromChildPolicyLocked( GRPC_ERROR_REF(error)); - // Resubscribe. Reuse the "on_rr_connectivity_changed" ref. - xdslb_policy->rr_policy_->NotifyOnStateChangeLocked( - &xdslb_policy->rr_connectivity_state_, - &xdslb_policy->on_rr_connectivity_changed_); + // Resubscribe. Reuse the "on_child_connectivity_changed" ref. + xdslb_policy->child_policy_->NotifyOnStateChangeLocked( + &xdslb_policy->child_connectivity_state_, + &xdslb_policy->on_child_connectivity_changed_); } // @@ -1793,22 +1641,21 @@ class XdsFactory : public LoadBalancingPolicyFactory { OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy( const LoadBalancingPolicy::Args& args) const override { /* Count the number of gRPC-LB addresses. There must be at least one. */ - const grpc_arg* arg = - grpc_channel_args_find(args.args, GRPC_ARG_LB_ADDRESSES); - if (arg == nullptr || arg->type != GRPC_ARG_POINTER) { - return nullptr; - } - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(arg->value.pointer.p); - size_t num_grpclb_addrs = 0; - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (addresses->addresses[i].is_balancer) ++num_grpclb_addrs; + const ServerAddressList* addresses = + FindServerAddressListChannelArg(args.args); + if (addresses == nullptr) return nullptr; + bool found_balancer_address = false; + for (size_t i = 0; i < addresses->size(); ++i) { + if ((*addresses)[i].IsBalancer()) { + found_balancer_address = true; + break; + } } - if (num_grpclb_addrs == 0) return nullptr; - return OrphanablePtr<LoadBalancingPolicy>(New<XdsLb>(addresses, args)); + if (!found_balancer_address) return nullptr; + return OrphanablePtr<LoadBalancingPolicy>(New<XdsLb>(args)); } - const char* name() const override { return "xds"; } + const char* name() const override { return kXds; } }; } // namespace diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel.h b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel.h index 32c4acc8a3..f713b7f563 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel.h +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel.h @@ -21,7 +21,7 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include <grpc/impl/codegen/grpc_types.h> /// Makes any necessary modifications to \a args for use in the xds /// balancer channel. diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc index 5ab72efce4..55c646e6ee 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc @@ -25,6 +25,7 @@ #include <string.h> #include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/iomgr/sockaddr_utils.h" @@ -41,22 +42,23 @@ int BalancerNameCmp(const grpc_core::UniquePtr<char>& a, } RefCountedPtr<TargetAuthorityTable> CreateTargetAuthorityTable( - grpc_lb_addresses* addresses) { + const ServerAddressList& addresses) { TargetAuthorityTable::Entry* target_authority_entries = - static_cast<TargetAuthorityTable::Entry*>(gpr_zalloc( - sizeof(*target_authority_entries) * addresses->num_addresses)); - for (size_t i = 0; i < addresses->num_addresses; ++i) { + static_cast<TargetAuthorityTable::Entry*>( + gpr_zalloc(sizeof(*target_authority_entries) * addresses.size())); + for (size_t i = 0; i < addresses.size(); ++i) { char* addr_str; - GPR_ASSERT(grpc_sockaddr_to_string( - &addr_str, &addresses->addresses[i].address, true) > 0); + GPR_ASSERT( + grpc_sockaddr_to_string(&addr_str, &addresses[i].address(), true) > 0); target_authority_entries[i].key = grpc_slice_from_copied_string(addr_str); - target_authority_entries[i].value.reset( - gpr_strdup(addresses->addresses[i].balancer_name)); gpr_free(addr_str); + char* balancer_name = grpc_channel_arg_get_string(grpc_channel_args_find( + addresses[i].args(), GRPC_ARG_ADDRESS_BALANCER_NAME)); + target_authority_entries[i].value.reset(gpr_strdup(balancer_name)); } RefCountedPtr<TargetAuthorityTable> target_authority_table = - TargetAuthorityTable::Create(addresses->num_addresses, - target_authority_entries, BalancerNameCmp); + TargetAuthorityTable::Create(addresses.size(), target_authority_entries, + BalancerNameCmp); gpr_free(target_authority_entries); return target_authority_table; } @@ -71,13 +73,12 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args( grpc_arg args_to_add[2]; size_t num_args_to_add = 0; // Add arg for targets info table. - const grpc_arg* arg = grpc_channel_args_find(args, GRPC_ARG_LB_ADDRESSES); - GPR_ASSERT(arg != nullptr); - GPR_ASSERT(arg->type == GRPC_ARG_POINTER); - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(arg->value.pointer.p); + grpc_core::ServerAddressList* addresses = + grpc_core::FindServerAddressListChannelArg(args); + GPR_ASSERT(addresses != nullptr); grpc_core::RefCountedPtr<grpc_core::TargetAuthorityTable> - target_authority_table = grpc_core::CreateTargetAuthorityTable(addresses); + target_authority_table = + grpc_core::CreateTargetAuthorityTable(*addresses); args_to_add[num_args_to_add++] = grpc_core::CreateTargetAuthorityTableChannelArg( target_authority_table.get()); @@ -86,22 +87,18 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_load_balancer_api.h b/src/core/ext/filters/client_channel/lb_policy/xds/xds_load_balancer_api.h index 9d08defa7e..6704995641 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_load_balancer_api.h +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_load_balancer_api.h @@ -25,7 +25,7 @@ #include "src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.h" #include "src/core/ext/filters/client_channel/lb_policy/xds/xds_client_stats.h" -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/lib/iomgr/exec_ctx.h" #define XDS_SERVICE_NAME_MAX_LENGTH 128 diff --git a/src/core/ext/filters/client_channel/lb_policy_factory.cc b/src/core/ext/filters/client_channel/lb_policy_factory.cc deleted file mode 100644 index 5c6363d295..0000000000 --- a/src/core/ext/filters/client_channel/lb_policy_factory.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* - * - * Copyright 2015 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#include <grpc/support/port_platform.h> - -#include <string.h> - -#include <grpc/support/alloc.h> -#include <grpc/support/string_util.h> - -#include "src/core/lib/channel/channel_args.h" - -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" -#include "src/core/ext/filters/client_channel/parse_address.h" - -grpc_lb_addresses* grpc_lb_addresses_create( - size_t num_addresses, const grpc_lb_user_data_vtable* user_data_vtable) { - grpc_lb_addresses* addresses = - static_cast<grpc_lb_addresses*>(gpr_zalloc(sizeof(grpc_lb_addresses))); - addresses->num_addresses = num_addresses; - addresses->user_data_vtable = user_data_vtable; - const size_t addresses_size = sizeof(grpc_lb_address) * num_addresses; - addresses->addresses = - static_cast<grpc_lb_address*>(gpr_zalloc(addresses_size)); - return addresses; -} - -grpc_lb_addresses* grpc_lb_addresses_copy(const grpc_lb_addresses* addresses) { - grpc_lb_addresses* new_addresses = grpc_lb_addresses_create( - addresses->num_addresses, addresses->user_data_vtable); - memcpy(new_addresses->addresses, addresses->addresses, - sizeof(grpc_lb_address) * addresses->num_addresses); - for (size_t i = 0; i < addresses->num_addresses; ++i) { - if (new_addresses->addresses[i].balancer_name != nullptr) { - new_addresses->addresses[i].balancer_name = - gpr_strdup(new_addresses->addresses[i].balancer_name); - } - if (new_addresses->addresses[i].user_data != nullptr) { - new_addresses->addresses[i].user_data = addresses->user_data_vtable->copy( - new_addresses->addresses[i].user_data); - } - } - return new_addresses; -} - -void grpc_lb_addresses_set_address(grpc_lb_addresses* addresses, size_t index, - const void* address, size_t address_len, - bool is_balancer, const char* balancer_name, - void* user_data) { - GPR_ASSERT(index < addresses->num_addresses); - if (user_data != nullptr) GPR_ASSERT(addresses->user_data_vtable != nullptr); - grpc_lb_address* target = &addresses->addresses[index]; - memcpy(target->address.addr, address, address_len); - target->address.len = static_cast<socklen_t>(address_len); - target->is_balancer = is_balancer; - target->balancer_name = gpr_strdup(balancer_name); - target->user_data = user_data; -} - -bool grpc_lb_addresses_set_address_from_uri(grpc_lb_addresses* addresses, - size_t index, const grpc_uri* uri, - bool is_balancer, - const char* balancer_name, - void* user_data) { - grpc_resolved_address address; - if (!grpc_parse_uri(uri, &address)) return false; - grpc_lb_addresses_set_address(addresses, index, address.addr, address.len, - is_balancer, balancer_name, user_data); - return true; -} - -int grpc_lb_addresses_cmp(const grpc_lb_addresses* addresses1, - const grpc_lb_addresses* addresses2) { - if (addresses1->num_addresses > addresses2->num_addresses) return 1; - if (addresses1->num_addresses < addresses2->num_addresses) return -1; - if (addresses1->user_data_vtable > addresses2->user_data_vtable) return 1; - if (addresses1->user_data_vtable < addresses2->user_data_vtable) return -1; - for (size_t i = 0; i < addresses1->num_addresses; ++i) { - const grpc_lb_address* target1 = &addresses1->addresses[i]; - const grpc_lb_address* target2 = &addresses2->addresses[i]; - if (target1->address.len > target2->address.len) return 1; - if (target1->address.len < target2->address.len) return -1; - int retval = memcmp(target1->address.addr, target2->address.addr, - target1->address.len); - if (retval != 0) return retval; - if (target1->is_balancer > target2->is_balancer) return 1; - if (target1->is_balancer < target2->is_balancer) return -1; - const char* balancer_name1 = - target1->balancer_name != nullptr ? target1->balancer_name : ""; - const char* balancer_name2 = - target2->balancer_name != nullptr ? target2->balancer_name : ""; - retval = strcmp(balancer_name1, balancer_name2); - if (retval != 0) return retval; - if (addresses1->user_data_vtable != nullptr) { - retval = addresses1->user_data_vtable->cmp(target1->user_data, - target2->user_data); - if (retval != 0) return retval; - } - } - return 0; -} - -void grpc_lb_addresses_destroy(grpc_lb_addresses* addresses) { - for (size_t i = 0; i < addresses->num_addresses; ++i) { - gpr_free(addresses->addresses[i].balancer_name); - if (addresses->addresses[i].user_data != nullptr) { - addresses->user_data_vtable->destroy(addresses->addresses[i].user_data); - } - } - gpr_free(addresses->addresses); - gpr_free(addresses); -} - -static void* lb_addresses_copy(void* addresses) { - return grpc_lb_addresses_copy(static_cast<grpc_lb_addresses*>(addresses)); -} -static void lb_addresses_destroy(void* addresses) { - grpc_lb_addresses_destroy(static_cast<grpc_lb_addresses*>(addresses)); -} -static int lb_addresses_cmp(void* addresses1, void* addresses2) { - return grpc_lb_addresses_cmp(static_cast<grpc_lb_addresses*>(addresses1), - static_cast<grpc_lb_addresses*>(addresses2)); -} -static const grpc_arg_pointer_vtable lb_addresses_arg_vtable = { - lb_addresses_copy, lb_addresses_destroy, lb_addresses_cmp}; - -grpc_arg grpc_lb_addresses_create_channel_arg( - const grpc_lb_addresses* addresses) { - return grpc_channel_arg_pointer_create( - (char*)GRPC_ARG_LB_ADDRESSES, (void*)addresses, &lb_addresses_arg_vtable); -} - -grpc_lb_addresses* grpc_lb_addresses_find_channel_arg( - const grpc_channel_args* channel_args) { - const grpc_arg* lb_addresses_arg = - grpc_channel_args_find(channel_args, GRPC_ARG_LB_ADDRESSES); - if (lb_addresses_arg == nullptr || lb_addresses_arg->type != GRPC_ARG_POINTER) - return nullptr; - return static_cast<grpc_lb_addresses*>(lb_addresses_arg->value.pointer.p); -} - -bool grpc_lb_addresses_contains_balancer_address( - const grpc_lb_addresses& addresses) { - for (size_t i = 0; i < addresses.num_addresses; ++i) { - if (addresses.addresses[i].is_balancer) return true; - } - return false; -} diff --git a/src/core/ext/filters/client_channel/lb_policy_factory.h b/src/core/ext/filters/client_channel/lb_policy_factory.h index a59deadb26..a165ebafab 100644 --- a/src/core/ext/filters/client_channel/lb_policy_factory.h +++ b/src/core/ext/filters/client_channel/lb_policy_factory.h @@ -21,91 +21,9 @@ #include <grpc/support/port_platform.h> -#include "src/core/lib/iomgr/resolve_address.h" - -#include "src/core/ext/filters/client_channel/client_channel_factory.h" #include "src/core/ext/filters/client_channel/lb_policy.h" -#include "src/core/lib/uri/uri_parser.h" - -// -// representation of an LB address -// - -// Channel arg key for grpc_lb_addresses. -#define GRPC_ARG_LB_ADDRESSES "grpc.lb_addresses" - -/** A resolved address alongside any LB related information associated with it. - * \a user_data, if not NULL, contains opaque data meant to be consumed by the - * gRPC LB policy. Note that no all LB policies support \a user_data as input. - * Those who don't will simply ignore it and will correspondingly return NULL in - * their namesake pick() output argument. */ -// TODO(roth): Once we figure out a better way of handling user_data in -// LB policies, convert these structs to C++ classes. -typedef struct grpc_lb_address { - grpc_resolved_address address; - bool is_balancer; - char* balancer_name; /* For secure naming. */ - void* user_data; -} grpc_lb_address; - -typedef struct grpc_lb_user_data_vtable { - void* (*copy)(void*); - void (*destroy)(void*); - int (*cmp)(void*, void*); -} grpc_lb_user_data_vtable; - -typedef struct grpc_lb_addresses { - size_t num_addresses; - grpc_lb_address* addresses; - const grpc_lb_user_data_vtable* user_data_vtable; -} grpc_lb_addresses; - -/** Returns a grpc_addresses struct with enough space for - \a num_addresses addresses. The \a user_data_vtable argument may be - NULL if no user data will be added. */ -grpc_lb_addresses* grpc_lb_addresses_create( - size_t num_addresses, const grpc_lb_user_data_vtable* user_data_vtable); - -/** Creates a copy of \a addresses. */ -grpc_lb_addresses* grpc_lb_addresses_copy(const grpc_lb_addresses* addresses); - -/** Sets the value of the address at index \a index of \a addresses. - * \a address is a socket address of length \a address_len. */ -void grpc_lb_addresses_set_address(grpc_lb_addresses* addresses, size_t index, - const void* address, size_t address_len, - bool is_balancer, const char* balancer_name, - void* user_data); - -/** Sets the value of the address at index \a index of \a addresses from \a uri. - * Returns true upon success, false otherwise. */ -bool grpc_lb_addresses_set_address_from_uri(grpc_lb_addresses* addresses, - size_t index, const grpc_uri* uri, - bool is_balancer, - const char* balancer_name, - void* user_data); - -/** Compares \a addresses1 and \a addresses2. */ -int grpc_lb_addresses_cmp(const grpc_lb_addresses* addresses1, - const grpc_lb_addresses* addresses2); - -/** Destroys \a addresses. */ -void grpc_lb_addresses_destroy(grpc_lb_addresses* addresses); - -/** Returns a channel arg containing \a addresses. */ -grpc_arg grpc_lb_addresses_create_channel_arg( - const grpc_lb_addresses* addresses); - -/** Returns the \a grpc_lb_addresses instance in \a channel_args or NULL */ -grpc_lb_addresses* grpc_lb_addresses_find_channel_arg( - const grpc_channel_args* channel_args); - -// Returns true if addresses contains at least one balancer address. -bool grpc_lb_addresses_contains_balancer_address( - const grpc_lb_addresses& addresses); - -// -// LB policy factory -// +#include "src/core/lib/gprpp/abstract.h" +#include "src/core/lib/gprpp/orphanable.h" namespace grpc_core { diff --git a/src/core/ext/filters/client_channel/lb_policy_registry.cc b/src/core/ext/filters/client_channel/lb_policy_registry.cc index d651b1120d..ad459c9c8c 100644 --- a/src/core/ext/filters/client_channel/lb_policy_registry.cc +++ b/src/core/ext/filters/client_channel/lb_policy_registry.cc @@ -94,4 +94,9 @@ LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy( return factory->CreateLoadBalancingPolicy(args); } +bool LoadBalancingPolicyRegistry::LoadBalancingPolicyExists(const char* name) { + GPR_ASSERT(g_state != nullptr); + return g_state->GetLoadBalancingPolicyFactory(name) != nullptr; +} + } // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/lb_policy_registry.h b/src/core/ext/filters/client_channel/lb_policy_registry.h index 2e9bb061ed..338f7c9f69 100644 --- a/src/core/ext/filters/client_channel/lb_policy_registry.h +++ b/src/core/ext/filters/client_channel/lb_policy_registry.h @@ -47,6 +47,10 @@ class LoadBalancingPolicyRegistry { /// Creates an LB policy of the type specified by \a name. static OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy( const char* name, const LoadBalancingPolicy::Args& args); + + /// Returns true if the LB policy factory specified by \a name exists in this + /// registry. + static bool LoadBalancingPolicyExists(const char* name); }; } // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/method_params.cc b/src/core/ext/filters/client_channel/method_params.cc deleted file mode 100644 index 1f116bb67d..0000000000 --- a/src/core/ext/filters/client_channel/method_params.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* - * - * Copyright 2015 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#include <grpc/support/port_platform.h> - -#include <stdio.h> -#include <string.h> - -#include <grpc/support/alloc.h> -#include <grpc/support/log.h> -#include <grpc/support/string_util.h> - -#include "src/core/ext/filters/client_channel/method_params.h" -#include "src/core/lib/channel/status_util.h" -#include "src/core/lib/gpr/string.h" -#include "src/core/lib/gprpp/memory.h" - -// As per the retry design, we do not allow more than 5 retry attempts. -#define MAX_MAX_RETRY_ATTEMPTS 5 - -namespace grpc_core { -namespace internal { - -namespace { - -bool ParseWaitForReady( - grpc_json* field, ClientChannelMethodParams::WaitForReady* wait_for_ready) { - if (field->type != GRPC_JSON_TRUE && field->type != GRPC_JSON_FALSE) { - return false; - } - *wait_for_ready = field->type == GRPC_JSON_TRUE - ? ClientChannelMethodParams::WAIT_FOR_READY_TRUE - : ClientChannelMethodParams::WAIT_FOR_READY_FALSE; - return true; -} - -// Parses a JSON field of the form generated for a google.proto.Duration -// proto message, as per: -// https://developers.google.com/protocol-buffers/docs/proto3#json -bool ParseDuration(grpc_json* field, grpc_millis* duration) { - if (field->type != GRPC_JSON_STRING) return false; - size_t len = strlen(field->value); - if (field->value[len - 1] != 's') return false; - UniquePtr<char> buf(gpr_strdup(field->value)); - *(buf.get() + len - 1) = '\0'; // Remove trailing 's'. - char* decimal_point = strchr(buf.get(), '.'); - int nanos = 0; - if (decimal_point != nullptr) { - *decimal_point = '\0'; - nanos = gpr_parse_nonnegative_int(decimal_point + 1); - if (nanos == -1) { - return false; - } - int num_digits = static_cast<int>(strlen(decimal_point + 1)); - if (num_digits > 9) { // We don't accept greater precision than nanos. - return false; - } - for (int i = 0; i < (9 - num_digits); ++i) { - nanos *= 10; - } - } - int seconds = - decimal_point == buf.get() ? 0 : gpr_parse_nonnegative_int(buf.get()); - if (seconds == -1) return false; - *duration = seconds * GPR_MS_PER_SEC + nanos / GPR_NS_PER_MS; - return true; -} - -UniquePtr<ClientChannelMethodParams::RetryPolicy> ParseRetryPolicy( - grpc_json* field) { - auto retry_policy = MakeUnique<ClientChannelMethodParams::RetryPolicy>(); - if (field->type != GRPC_JSON_OBJECT) return nullptr; - for (grpc_json* sub_field = field->child; sub_field != nullptr; - sub_field = sub_field->next) { - if (sub_field->key == nullptr) return nullptr; - if (strcmp(sub_field->key, "maxAttempts") == 0) { - if (retry_policy->max_attempts != 0) return nullptr; // Duplicate. - if (sub_field->type != GRPC_JSON_NUMBER) return nullptr; - retry_policy->max_attempts = gpr_parse_nonnegative_int(sub_field->value); - if (retry_policy->max_attempts <= 1) return nullptr; - if (retry_policy->max_attempts > MAX_MAX_RETRY_ATTEMPTS) { - gpr_log(GPR_ERROR, - "service config: clamped retryPolicy.maxAttempts at %d", - MAX_MAX_RETRY_ATTEMPTS); - retry_policy->max_attempts = MAX_MAX_RETRY_ATTEMPTS; - } - } else if (strcmp(sub_field->key, "initialBackoff") == 0) { - if (retry_policy->initial_backoff > 0) return nullptr; // Duplicate. - if (!ParseDuration(sub_field, &retry_policy->initial_backoff)) { - return nullptr; - } - if (retry_policy->initial_backoff == 0) return nullptr; - } else if (strcmp(sub_field->key, "maxBackoff") == 0) { - if (retry_policy->max_backoff > 0) return nullptr; // Duplicate. - if (!ParseDuration(sub_field, &retry_policy->max_backoff)) { - return nullptr; - } - if (retry_policy->max_backoff == 0) return nullptr; - } else if (strcmp(sub_field->key, "backoffMultiplier") == 0) { - if (retry_policy->backoff_multiplier != 0) return nullptr; // Duplicate. - if (sub_field->type != GRPC_JSON_NUMBER) return nullptr; - if (sscanf(sub_field->value, "%f", &retry_policy->backoff_multiplier) != - 1) { - return nullptr; - } - if (retry_policy->backoff_multiplier <= 0) return nullptr; - } else if (strcmp(sub_field->key, "retryableStatusCodes") == 0) { - if (!retry_policy->retryable_status_codes.Empty()) { - return nullptr; // Duplicate. - } - if (sub_field->type != GRPC_JSON_ARRAY) return nullptr; - for (grpc_json* element = sub_field->child; element != nullptr; - element = element->next) { - if (element->type != GRPC_JSON_STRING) return nullptr; - grpc_status_code status; - if (!grpc_status_code_from_string(element->value, &status)) { - return nullptr; - } - retry_policy->retryable_status_codes.Add(status); - } - if (retry_policy->retryable_status_codes.Empty()) return nullptr; - } - } - // Make sure required fields are set. - if (retry_policy->max_attempts == 0 || retry_policy->initial_backoff == 0 || - retry_policy->max_backoff == 0 || retry_policy->backoff_multiplier == 0 || - retry_policy->retryable_status_codes.Empty()) { - return nullptr; - } - return retry_policy; -} - -} // namespace - -RefCountedPtr<ClientChannelMethodParams> -ClientChannelMethodParams::CreateFromJson(const grpc_json* json) { - RefCountedPtr<ClientChannelMethodParams> method_params = - MakeRefCounted<ClientChannelMethodParams>(); - for (grpc_json* field = json->child; field != nullptr; field = field->next) { - if (field->key == nullptr) continue; - if (strcmp(field->key, "waitForReady") == 0) { - if (method_params->wait_for_ready_ != WAIT_FOR_READY_UNSET) { - return nullptr; // Duplicate. - } - if (!ParseWaitForReady(field, &method_params->wait_for_ready_)) { - return nullptr; - } - } else if (strcmp(field->key, "timeout") == 0) { - if (method_params->timeout_ > 0) return nullptr; // Duplicate. - if (!ParseDuration(field, &method_params->timeout_)) return nullptr; - } else if (strcmp(field->key, "retryPolicy") == 0) { - if (method_params->retry_policy_ != nullptr) { - return nullptr; // Duplicate. - } - method_params->retry_policy_ = ParseRetryPolicy(field); - if (method_params->retry_policy_ == nullptr) return nullptr; - } - } - return method_params; -} - -} // namespace internal -} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/method_params.h b/src/core/ext/filters/client_channel/method_params.h deleted file mode 100644 index a31d360f17..0000000000 --- a/src/core/ext/filters/client_channel/method_params.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * - * Copyright 2015 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_METHOD_PARAMS_H -#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_METHOD_PARAMS_H - -#include <grpc/support/port_platform.h> - -#include "src/core/lib/channel/status_util.h" -#include "src/core/lib/gprpp/ref_counted.h" -#include "src/core/lib/gprpp/ref_counted_ptr.h" -#include "src/core/lib/iomgr/exec_ctx.h" // for grpc_millis -#include "src/core/lib/json/json.h" - -namespace grpc_core { -namespace internal { - -class ClientChannelMethodParams : public RefCounted<ClientChannelMethodParams> { - public: - enum WaitForReady { - WAIT_FOR_READY_UNSET = 0, - WAIT_FOR_READY_FALSE, - WAIT_FOR_READY_TRUE - }; - - struct RetryPolicy { - int max_attempts = 0; - grpc_millis initial_backoff = 0; - grpc_millis max_backoff = 0; - float backoff_multiplier = 0; - StatusCodeSet retryable_status_codes; - }; - - /// Creates a method_parameters object from \a json. - /// Intended for use with ServiceConfig::CreateMethodConfigTable(). - static RefCountedPtr<ClientChannelMethodParams> CreateFromJson( - const grpc_json* json); - - grpc_millis timeout() const { return timeout_; } - WaitForReady wait_for_ready() const { return wait_for_ready_; } - const RetryPolicy* retry_policy() const { return retry_policy_.get(); } - - private: - // So New() can call our private ctor. - template <typename T, typename... Args> - friend T* grpc_core::New(Args&&... args); - - // So Delete() can call our private dtor. - template <typename T> - friend void grpc_core::Delete(T*); - - ClientChannelMethodParams() {} - virtual ~ClientChannelMethodParams() {} - - grpc_millis timeout_ = 0; - WaitForReady wait_for_ready_ = WAIT_FOR_READY_UNSET; - UniquePtr<RetryPolicy> retry_policy_; -}; - -} // namespace internal -} // namespace grpc_core - -#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_METHOD_PARAMS_H */ diff --git a/src/core/ext/filters/client_channel/request_routing.cc b/src/core/ext/filters/client_channel/request_routing.cc new file mode 100644 index 0000000000..f9a7e164e7 --- /dev/null +++ b/src/core/ext/filters/client_channel/request_routing.cc @@ -0,0 +1,936 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/request_routing.h" + +#include <inttypes.h> +#include <limits.h> +#include <stdbool.h> +#include <stdio.h> +#include <string.h> + +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> +#include <grpc/support/sync.h> + +#include "src/core/ext/filters/client_channel/backup_poller.h" +#include "src/core/ext/filters/client_channel/http_connect_handshaker.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h" +#include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/retry_throttle.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/ext/filters/client_channel/subchannel.h" +#include "src/core/ext/filters/deadline/deadline_filter.h" +#include "src/core/lib/backoff/backoff.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/connected_channel.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/manual_constructor.h" +#include "src/core/lib/iomgr/combiner.h" +#include "src/core/lib/iomgr/iomgr.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/profiling/timers.h" +#include "src/core/lib/slice/slice_internal.h" +#include "src/core/lib/slice/slice_string_helpers.h" +#include "src/core/lib/surface/channel.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/error_utils.h" +#include "src/core/lib/transport/metadata.h" +#include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/transport/service_config.h" +#include "src/core/lib/transport/static_metadata.h" +#include "src/core/lib/transport/status_metadata.h" + +namespace grpc_core { + +// +// RequestRouter::Request::ResolverResultWaiter +// + +// Handles waiting for a resolver result. +// Used only for the first call on an idle channel. +class RequestRouter::Request::ResolverResultWaiter { + public: + explicit ResolverResultWaiter(Request* request) + : request_router_(request->request_router_), + request_(request), + tracer_enabled_(request_router_->tracer_->enabled()) { + if (tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: deferring pick pending resolver " + "result", + request_router_, request); + } + // Add closure to be run when a resolver result is available. + GRPC_CLOSURE_INIT(&done_closure_, &DoneLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + AddToWaitingList(); + // Set cancellation closure, so that we abort if the call is cancelled. + GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + grpc_call_combiner_set_notify_on_cancel(request->call_combiner_, + &cancel_closure_); + } + + private: + // Adds done_closure_ to + // request_router_->waiting_for_resolver_result_closures_. + void AddToWaitingList() { + grpc_closure_list_append( + &request_router_->waiting_for_resolver_result_closures_, &done_closure_, + GRPC_ERROR_NONE); + } + + // Invoked when a resolver result is available. + static void DoneLocked(void* arg, grpc_error* error) { + ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); + RequestRouter* request_router = self->request_router_; + // If CancelLocked() has already run, delete ourselves without doing + // anything. Note that the call stack may have already been destroyed, + // so it's not safe to access anything in state_. + if (GPR_UNLIKELY(self->finished_)) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p: call cancelled before resolver result", + request_router); + } + Delete(self); + return; + } + // Otherwise, process the resolver result. + Request* request = self->request_; + if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver failed to return data", + request_router, request); + } + GRPC_CLOSURE_RUN(request->on_route_done_, GRPC_ERROR_REF(error)); + } else if (GPR_UNLIKELY(request_router->resolver_ == nullptr)) { + // Shutting down. + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, "request_router=%p request=%p: resolver disconnected", + request_router, request); + } + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); + } else if (GPR_UNLIKELY(request_router->lb_policy_ == nullptr)) { + // Transient resolver failure. + // If call has wait_for_ready=true, try again; otherwise, fail. + if (*request->pick_.initial_metadata_flags & + GRPC_INITIAL_METADATA_WAIT_FOR_READY) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned but no LB " + "policy; wait_for_ready=true; trying again", + request_router, request); + } + // Re-add ourselves to the waiting list. + self->AddToWaitingList(); + // Return early so that we don't set finished_ to true below. + return; + } else { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned but no LB " + "policy; wait_for_ready=false; failing", + request_router, request); + } + GRPC_CLOSURE_RUN( + request->on_route_done_, + grpc_error_set_int( + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"), + GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE)); + } + } else { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: resolver returned, doing LB " + "pick", + request_router, request); + } + request->ProcessServiceConfigAndStartLbPickLocked(); + } + self->finished_ = true; + } + + // Invoked when the call is cancelled. + // Note: This runs under the client_channel combiner, but will NOT be + // holding the call combiner. + static void CancelLocked(void* arg, grpc_error* error) { + ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg); + RequestRouter* request_router = self->request_router_; + // If DoneLocked() has already run, delete ourselves without doing anything. + if (self->finished_) { + Delete(self); + return; + } + Request* request = self->request_; + // If we are being cancelled, immediately invoke on_route_done_ + // to propagate the error back to the caller. + if (error != GRPC_ERROR_NONE) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: cancelling call waiting for " + "name resolution", + request_router, request); + } + // Note: Although we are not in the call combiner here, we are + // basically stealing the call combiner from the pending pick, so + // it's safe to run on_route_done_ here -- we are essentially + // calling it here instead of calling it in DoneLocked(). + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Pick cancelled", &error, 1)); + } + self->finished_ = true; + } + + RequestRouter* request_router_; + Request* request_; + const bool tracer_enabled_; + grpc_closure done_closure_; + grpc_closure cancel_closure_; + bool finished_ = false; +}; + +// +// RequestRouter::Request::AsyncPickCanceller +// + +// Handles the call combiner cancellation callback for an async LB pick. +class RequestRouter::Request::AsyncPickCanceller { + public: + explicit AsyncPickCanceller(Request* request) + : request_router_(request->request_router_), + request_(request), + tracer_enabled_(request_router_->tracer_->enabled()) { + GRPC_CALL_STACK_REF(request->owning_call_, "pick_callback_cancel"); + // Set cancellation closure, so that we abort if the call is cancelled. + GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + grpc_call_combiner_set_notify_on_cancel(request->call_combiner_, + &cancel_closure_); + } + + void MarkFinishedLocked() { + finished_ = true; + GRPC_CALL_STACK_UNREF(request_->owning_call_, "pick_callback_cancel"); + } + + private: + // Invoked when the call is cancelled. + // Note: This runs under the client_channel combiner, but will NOT be + // holding the call combiner. + static void CancelLocked(void* arg, grpc_error* error) { + AsyncPickCanceller* self = static_cast<AsyncPickCanceller*>(arg); + Request* request = self->request_; + RequestRouter* request_router = self->request_router_; + if (!self->finished_) { + // Note: request_router->lb_policy_ may have changed since we started our + // pick, in which case we will be cancelling the pick on a policy other + // than the one we started it on. However, this will just be a no-op. + if (error != GRPC_ERROR_NONE && request_router->lb_policy_ != nullptr) { + if (self->tracer_enabled_) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: cancelling pick from LB " + "policy %p", + request_router, request, request_router->lb_policy_.get()); + } + request_router->lb_policy_->CancelPickLocked(&request->pick_, + GRPC_ERROR_REF(error)); + } + request->pick_canceller_ = nullptr; + GRPC_CALL_STACK_UNREF(request->owning_call_, "pick_callback_cancel"); + } + Delete(self); + } + + RequestRouter* request_router_; + Request* request_; + const bool tracer_enabled_; + grpc_closure cancel_closure_; + bool finished_ = false; +}; + +// +// RequestRouter::Request +// + +RequestRouter::Request::Request(grpc_call_stack* owning_call, + grpc_call_combiner* call_combiner, + grpc_polling_entity* pollent, + grpc_metadata_batch* send_initial_metadata, + uint32_t* send_initial_metadata_flags, + ApplyServiceConfigCallback apply_service_config, + void* apply_service_config_user_data, + grpc_closure* on_route_done) + : owning_call_(owning_call), + call_combiner_(call_combiner), + pollent_(pollent), + apply_service_config_(apply_service_config), + apply_service_config_user_data_(apply_service_config_user_data), + on_route_done_(on_route_done) { + pick_.initial_metadata = send_initial_metadata; + pick_.initial_metadata_flags = send_initial_metadata_flags; +} + +RequestRouter::Request::~Request() { + if (pick_.connected_subchannel != nullptr) { + pick_.connected_subchannel.reset(); + } + for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) { + if (pick_.subchannel_call_context[i].destroy != nullptr) { + pick_.subchannel_call_context[i].destroy( + pick_.subchannel_call_context[i].value); + } + } +} + +// Invoked once resolver results are available. +void RequestRouter::Request::ProcessServiceConfigAndStartLbPickLocked() { + // Get service config data if needed. + if (!apply_service_config_(apply_service_config_user_data_)) return; + // Start LB pick. + StartLbPickLocked(); +} + +void RequestRouter::Request::MaybeAddCallToInterestedPartiesLocked() { + if (!pollent_added_to_interested_parties_) { + pollent_added_to_interested_parties_ = true; + grpc_polling_entity_add_to_pollset_set( + pollent_, request_router_->interested_parties_); + } +} + +void RequestRouter::Request::MaybeRemoveCallFromInterestedPartiesLocked() { + if (pollent_added_to_interested_parties_) { + pollent_added_to_interested_parties_ = false; + grpc_polling_entity_del_from_pollset_set( + pollent_, request_router_->interested_parties_); + } +} + +// Starts a pick on the LB policy. +void RequestRouter::Request::StartLbPickLocked() { + if (request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: starting pick on lb_policy=%p", + request_router_, this, request_router_->lb_policy_.get()); + } + GRPC_CLOSURE_INIT(&on_pick_done_, &LbPickDoneLocked, this, + grpc_combiner_scheduler(request_router_->combiner_)); + pick_.on_complete = &on_pick_done_; + GRPC_CALL_STACK_REF(owning_call_, "pick_callback"); + grpc_error* error = GRPC_ERROR_NONE; + const bool pick_done = + request_router_->lb_policy_->PickLocked(&pick_, &error); + if (pick_done) { + // Pick completed synchronously. + if (request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: pick completed synchronously", + request_router_, this); + } + GRPC_CLOSURE_RUN(on_route_done_, error); + GRPC_CALL_STACK_UNREF(owning_call_, "pick_callback"); + } else { + // Pick will be returned asynchronously. + // Add the request's polling entity to the request_router's + // interested_parties, so that the I/O of the LB policy can be done + // under it. It will be removed in LbPickDoneLocked(). + MaybeAddCallToInterestedPartiesLocked(); + // Request notification on call cancellation. + // We allocate a separate object to track cancellation, since the + // cancellation closure might still be pending when we need to reuse + // the memory in which this Request object is stored for a subsequent + // retry attempt. + pick_canceller_ = New<AsyncPickCanceller>(this); + } +} + +// Callback invoked by LoadBalancingPolicy::PickLocked() for async picks. +// Unrefs the LB policy and invokes on_route_done_. +void RequestRouter::Request::LbPickDoneLocked(void* arg, grpc_error* error) { + Request* self = static_cast<Request*>(arg); + RequestRouter* request_router = self->request_router_; + if (request_router->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p request=%p: pick completed asynchronously", + request_router, self); + } + self->MaybeRemoveCallFromInterestedPartiesLocked(); + if (self->pick_canceller_ != nullptr) { + self->pick_canceller_->MarkFinishedLocked(); + } + GRPC_CLOSURE_RUN(self->on_route_done_, GRPC_ERROR_REF(error)); + GRPC_CALL_STACK_UNREF(self->owning_call_, "pick_callback"); +} + +// +// RequestRouter::LbConnectivityWatcher +// + +class RequestRouter::LbConnectivityWatcher { + public: + LbConnectivityWatcher(RequestRouter* request_router, + grpc_connectivity_state state, + LoadBalancingPolicy* lb_policy, + grpc_channel_stack* owning_stack, + grpc_combiner* combiner) + : request_router_(request_router), + state_(state), + lb_policy_(lb_policy), + owning_stack_(owning_stack) { + GRPC_CHANNEL_STACK_REF(owning_stack_, "LbConnectivityWatcher"); + GRPC_CLOSURE_INIT(&on_changed_, &OnLbPolicyStateChangedLocked, this, + grpc_combiner_scheduler(combiner)); + lb_policy_->NotifyOnStateChangeLocked(&state_, &on_changed_); + } + + ~LbConnectivityWatcher() { + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "LbConnectivityWatcher"); + } + + private: + static void OnLbPolicyStateChangedLocked(void* arg, grpc_error* error) { + LbConnectivityWatcher* self = static_cast<LbConnectivityWatcher*>(arg); + // If the notification is not for the current policy, we're stale, + // so delete ourselves. + if (self->lb_policy_ != self->request_router_->lb_policy_.get()) { + Delete(self); + return; + } + // Otherwise, process notification. + if (self->request_router_->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: lb_policy=%p state changed to %s", + self->request_router_, self->lb_policy_, + grpc_connectivity_state_name(self->state_)); + } + self->request_router_->SetConnectivityStateLocked( + self->state_, GRPC_ERROR_REF(error), "lb_changed"); + // If shutting down, terminate watch. + if (self->state_ == GRPC_CHANNEL_SHUTDOWN) { + Delete(self); + return; + } + // Renew watch. + self->lb_policy_->NotifyOnStateChangeLocked(&self->state_, + &self->on_changed_); + } + + RequestRouter* request_router_; + grpc_connectivity_state state_; + // LB policy address. No ref held, so not safe to dereference unless + // it happens to match request_router->lb_policy_. + LoadBalancingPolicy* lb_policy_; + grpc_channel_stack* owning_stack_; + grpc_closure on_changed_; +}; + +// +// RequestRounter::ReresolutionRequestHandler +// + +class RequestRouter::ReresolutionRequestHandler { + public: + ReresolutionRequestHandler(RequestRouter* request_router, + LoadBalancingPolicy* lb_policy, + grpc_channel_stack* owning_stack, + grpc_combiner* combiner) + : request_router_(request_router), + lb_policy_(lb_policy), + owning_stack_(owning_stack) { + GRPC_CHANNEL_STACK_REF(owning_stack_, "ReresolutionRequestHandler"); + GRPC_CLOSURE_INIT(&closure_, &OnRequestReresolutionLocked, this, + grpc_combiner_scheduler(combiner)); + lb_policy_->SetReresolutionClosureLocked(&closure_); + } + + private: + static void OnRequestReresolutionLocked(void* arg, grpc_error* error) { + ReresolutionRequestHandler* self = + static_cast<ReresolutionRequestHandler*>(arg); + RequestRouter* request_router = self->request_router_; + // If this invocation is for a stale LB policy, treat it as an LB shutdown + // signal. + if (self->lb_policy_ != request_router->lb_policy_.get() || + error != GRPC_ERROR_NONE || request_router->resolver_ == nullptr) { + GRPC_CHANNEL_STACK_UNREF(request_router->owning_stack_, + "ReresolutionRequestHandler"); + Delete(self); + return; + } + if (request_router->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: started name re-resolving", + request_router); + } + request_router->resolver_->RequestReresolutionLocked(); + // Give back the closure to the LB policy. + self->lb_policy_->SetReresolutionClosureLocked(&self->closure_); + } + + RequestRouter* request_router_; + // LB policy address. No ref held, so not safe to dereference unless + // it happens to match request_router->lb_policy_. + LoadBalancingPolicy* lb_policy_; + grpc_channel_stack* owning_stack_; + grpc_closure closure_; +}; + +// +// RequestRouter +// + +RequestRouter::RequestRouter( + grpc_channel_stack* owning_stack, grpc_combiner* combiner, + grpc_client_channel_factory* client_channel_factory, + grpc_pollset_set* interested_parties, TraceFlag* tracer, + ProcessResolverResultCallback process_resolver_result, + void* process_resolver_result_user_data, const char* target_uri, + const grpc_channel_args* args, grpc_error** error) + : owning_stack_(owning_stack), + combiner_(combiner), + client_channel_factory_(client_channel_factory), + interested_parties_(interested_parties), + tracer_(tracer), + process_resolver_result_(process_resolver_result), + process_resolver_result_user_data_(process_resolver_result_user_data) { + GRPC_CLOSURE_INIT(&on_resolver_result_changed_, + &RequestRouter::OnResolverResultChangedLocked, this, + grpc_combiner_scheduler(combiner)); + grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE, + "request_router"); + grpc_channel_args* new_args = nullptr; + if (process_resolver_result == nullptr) { + grpc_arg arg = grpc_channel_arg_integer_create( + const_cast<char*>(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION), 0); + new_args = grpc_channel_args_copy_and_add(args, &arg, 1); + } + resolver_ = ResolverRegistry::CreateResolver( + target_uri, (new_args == nullptr ? args : new_args), interested_parties_, + combiner_); + grpc_channel_args_destroy(new_args); + if (resolver_ == nullptr) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed"); + } +} + +RequestRouter::~RequestRouter() { + if (resolver_ != nullptr) { + // The only way we can get here is if we never started resolving, + // because we take a ref to the channel stack when we start + // resolving and do not release it until the resolver callback is + // invoked after the resolver shuts down. + resolver_.reset(); + } + if (lb_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + if (client_channel_factory_ != nullptr) { + grpc_client_channel_factory_unref(client_channel_factory_); + } + grpc_connectivity_state_destroy(&state_tracker_); +} + +namespace { + +const char* GetChannelConnectivityStateChangeString( + grpc_connectivity_state state) { + switch (state) { + case GRPC_CHANNEL_IDLE: + return "Channel state change to IDLE"; + case GRPC_CHANNEL_CONNECTING: + return "Channel state change to CONNECTING"; + case GRPC_CHANNEL_READY: + return "Channel state change to READY"; + case GRPC_CHANNEL_TRANSIENT_FAILURE: + return "Channel state change to TRANSIENT_FAILURE"; + case GRPC_CHANNEL_SHUTDOWN: + return "Channel state change to SHUTDOWN"; + } + GPR_UNREACHABLE_CODE(return "UNKNOWN"); +} + +} // namespace + +void RequestRouter::SetConnectivityStateLocked(grpc_connectivity_state state, + grpc_error* error, + const char* reason) { + if (lb_policy_ != nullptr) { + if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { + // Cancel picks with wait_for_ready=false. + lb_policy_->CancelMatchingPicksLocked( + /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY, + /* check= */ 0, GRPC_ERROR_REF(error)); + } else if (state == GRPC_CHANNEL_SHUTDOWN) { + // Cancel all picks. + lb_policy_->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0, + GRPC_ERROR_REF(error)); + } + } + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: setting connectivity state to %s", + this, grpc_connectivity_state_name(state)); + } + if (channelz_node_ != nullptr) { + channelz_node_->AddTraceEvent( + channelz::ChannelTrace::Severity::Info, + grpc_slice_from_static_string( + GetChannelConnectivityStateChangeString(state))); + } + grpc_connectivity_state_set(&state_tracker_, state, error, reason); +} + +void RequestRouter::StartResolvingLocked() { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: starting name resolution", this); + } + GPR_ASSERT(!started_resolving_); + started_resolving_ = true; + GRPC_CHANNEL_STACK_REF(owning_stack_, "resolver"); + resolver_->NextLocked(&resolver_result_, &on_resolver_result_changed_); +} + +// Invoked from the resolver NextLocked() callback when the resolver +// is shutting down. +void RequestRouter::OnResolverShutdownLocked(grpc_error* error) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down", this); + } + if (lb_policy_ != nullptr) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this, + lb_policy_.get()); + } + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + if (resolver_ != nullptr) { + // This should never happen; it can only be triggered by a resolver + // implementation spotaneously deciding to report shutdown without + // being orphaned. This code is included just to be defensive. + if (tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p: spontaneous shutdown from resolver %p", this, + resolver_.get()); + } + resolver_.reset(); + SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Resolver spontaneous shutdown", &error, 1), + "resolver_spontaneous_shutdown"); + } + grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_, + GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( + "Channel disconnected", &error, 1)); + GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_); + GRPC_CHANNEL_STACK_UNREF(owning_stack_, "resolver"); + grpc_channel_args_destroy(resolver_result_); + resolver_result_ = nullptr; + GRPC_ERROR_UNREF(error); +} + +// Creates a new LB policy, replacing any previous one. +// If the new policy is created successfully, sets *connectivity_state and +// *connectivity_error to its initial connectivity state; otherwise, +// leaves them unchanged. +void RequestRouter::CreateNewLbPolicyLocked( + const char* lb_policy_name, grpc_json* lb_config, + grpc_connectivity_state* connectivity_state, + grpc_error** connectivity_error, TraceStringVector* trace_strings) { + LoadBalancingPolicy::Args lb_policy_args; + lb_policy_args.combiner = combiner_; + lb_policy_args.client_channel_factory = client_channel_factory_; + lb_policy_args.args = resolver_result_; + lb_policy_args.lb_config = lb_config; + OrphanablePtr<LoadBalancingPolicy> new_lb_policy = + LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(lb_policy_name, + lb_policy_args); + if (GPR_UNLIKELY(new_lb_policy == nullptr)) { + gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name); + if (channelz_node_ != nullptr) { + char* str; + gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name); + trace_strings->push_back(str); + } + } else { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: created new LB policy \"%s\" (%p)", + this, lb_policy_name, new_lb_policy.get()); + } + if (channelz_node_ != nullptr) { + char* str; + gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name); + trace_strings->push_back(str); + } + // Swap out the LB policy and update the fds in interested_parties_. + if (lb_policy_ != nullptr) { + if (tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this, + lb_policy_.get()); + } + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_->HandOffPendingPicksLocked(new_lb_policy.get()); + } + lb_policy_ = std::move(new_lb_policy); + grpc_pollset_set_add_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + // Create re-resolution request handler for the new LB policy. It + // will delete itself when no longer needed. + New<ReresolutionRequestHandler>(this, lb_policy_.get(), owning_stack_, + combiner_); + // Get the new LB policy's initial connectivity state and start a + // connectivity watch. + GRPC_ERROR_UNREF(*connectivity_error); + *connectivity_state = + lb_policy_->CheckConnectivityLocked(connectivity_error); + if (exit_idle_when_lb_policy_arrives_) { + lb_policy_->ExitIdleLocked(); + exit_idle_when_lb_policy_arrives_ = false; + } + // Create new watcher. It will delete itself when done. + New<LbConnectivityWatcher>(this, *connectivity_state, lb_policy_.get(), + owning_stack_, combiner_); + } +} + +void RequestRouter::MaybeAddTraceMessagesForAddressChangesLocked( + TraceStringVector* trace_strings) { + const ServerAddressList* addresses = + FindServerAddressListChannelArg(resolver_result_); + const bool resolution_contains_addresses = + addresses != nullptr && addresses->size() > 0; + if (!resolution_contains_addresses && + previous_resolution_contained_addresses_) { + trace_strings->push_back(gpr_strdup("Address list became empty")); + } else if (resolution_contains_addresses && + !previous_resolution_contained_addresses_) { + trace_strings->push_back(gpr_strdup("Address list became non-empty")); + } + previous_resolution_contained_addresses_ = resolution_contains_addresses; +} + +void RequestRouter::ConcatenateAndAddChannelTraceLocked( + TraceStringVector* trace_strings) const { + if (!trace_strings->empty()) { + gpr_strvec v; + gpr_strvec_init(&v); + gpr_strvec_add(&v, gpr_strdup("Resolution event: ")); + bool is_first = 1; + for (size_t i = 0; i < trace_strings->size(); ++i) { + if (!is_first) gpr_strvec_add(&v, gpr_strdup(", ")); + is_first = false; + gpr_strvec_add(&v, (*trace_strings)[i]); + } + char* flat; + size_t flat_len = 0; + flat = gpr_strvec_flatten(&v, &flat_len); + channelz_node_->AddTraceEvent( + grpc_core::channelz::ChannelTrace::Severity::Info, + grpc_slice_new(flat, flat_len, gpr_free)); + gpr_strvec_destroy(&v); + } +} + +// Callback invoked when a resolver result is available. +void RequestRouter::OnResolverResultChangedLocked(void* arg, + grpc_error* error) { + RequestRouter* self = static_cast<RequestRouter*>(arg); + if (self->tracer_->enabled()) { + const char* disposition = + self->resolver_result_ != nullptr + ? "" + : (error == GRPC_ERROR_NONE ? " (transient error)" + : " (resolver shutdown)"); + gpr_log(GPR_INFO, + "request_router=%p: got resolver result: resolver_result=%p " + "error=%s%s", + self, self->resolver_result_, grpc_error_string(error), + disposition); + } + // Handle shutdown. + if (error != GRPC_ERROR_NONE || self->resolver_ == nullptr) { + self->OnResolverShutdownLocked(GRPC_ERROR_REF(error)); + return; + } + // Data used to set the channel's connectivity state. + bool set_connectivity_state = true; + // We only want to trace the address resolution in the follow cases: + // (a) Address resolution resulted in service config change. + // (b) Address resolution that causes number of backends to go from + // zero to non-zero. + // (c) Address resolution that causes number of backends to go from + // non-zero to zero. + // (d) Address resolution that causes a new LB policy to be created. + // + // we track a list of strings to eventually be concatenated and traced. + TraceStringVector trace_strings; + grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE; + grpc_error* connectivity_error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy"); + // resolver_result_ will be null in the case of a transient + // resolution error. In that case, we don't have any new result to + // process, which means that we keep using the previous result (if any). + if (self->resolver_result_ == nullptr) { + if (self->tracer_->enabled()) { + gpr_log(GPR_INFO, "request_router=%p: resolver transient failure", self); + } + // Don't override connectivity state if we already have an LB policy. + if (self->lb_policy_ != nullptr) set_connectivity_state = false; + } else { + // Parse the resolver result. + const char* lb_policy_name = nullptr; + grpc_json* lb_policy_config = nullptr; + const bool service_config_changed = self->process_resolver_result_( + self->process_resolver_result_user_data_, *self->resolver_result_, + &lb_policy_name, &lb_policy_config); + GPR_ASSERT(lb_policy_name != nullptr); + // Check to see if we're already using the right LB policy. + const bool lb_policy_name_changed = + self->lb_policy_ == nullptr || + strcmp(self->lb_policy_->name(), lb_policy_name) != 0; + if (self->lb_policy_ != nullptr && !lb_policy_name_changed) { + // Continue using the same LB policy. Update with new addresses. + if (self->tracer_->enabled()) { + gpr_log(GPR_INFO, + "request_router=%p: updating existing LB policy \"%s\" (%p)", + self, lb_policy_name, self->lb_policy_.get()); + } + self->lb_policy_->UpdateLocked(*self->resolver_result_, lb_policy_config); + // No need to set the channel's connectivity state; the existing + // watch on the LB policy will take care of that. + set_connectivity_state = false; + } else { + // Instantiate new LB policy. + self->CreateNewLbPolicyLocked(lb_policy_name, lb_policy_config, + &connectivity_state, &connectivity_error, + &trace_strings); + } + // Add channel trace event. + if (self->channelz_node_ != nullptr) { + if (service_config_changed) { + // TODO(ncteisen): might be worth somehow including a snippet of the + // config in the trace, at the risk of bloating the trace logs. + trace_strings.push_back(gpr_strdup("Service config changed")); + } + self->MaybeAddTraceMessagesForAddressChangesLocked(&trace_strings); + self->ConcatenateAndAddChannelTraceLocked(&trace_strings); + } + // Clean up. + grpc_channel_args_destroy(self->resolver_result_); + self->resolver_result_ = nullptr; + } + // Set the channel's connectivity state if needed. + if (set_connectivity_state) { + self->SetConnectivityStateLocked(connectivity_state, connectivity_error, + "resolver_result"); + } else { + GRPC_ERROR_UNREF(connectivity_error); + } + // Invoke closures that were waiting for results and renew the watch. + GRPC_CLOSURE_LIST_SCHED(&self->waiting_for_resolver_result_closures_); + self->resolver_->NextLocked(&self->resolver_result_, + &self->on_resolver_result_changed_); +} + +void RequestRouter::RouteCallLocked(Request* request) { + GPR_ASSERT(request->pick_.connected_subchannel == nullptr); + request->request_router_ = this; + if (lb_policy_ != nullptr) { + // We already have resolver results, so process the service config + // and start an LB pick. + request->ProcessServiceConfigAndStartLbPickLocked(); + } else if (resolver_ == nullptr) { + GRPC_CLOSURE_RUN(request->on_route_done_, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected")); + } else { + // We do not yet have an LB policy, so wait for a resolver result. + if (!started_resolving_) { + StartResolvingLocked(); + } + // Create a new waiter, which will delete itself when done. + New<Request::ResolverResultWaiter>(request); + // Add the request's polling entity to the request_router's + // interested_parties, so that the I/O of the resolver can be done + // under it. It will be removed in LbPickDoneLocked(). + request->MaybeAddCallToInterestedPartiesLocked(); + } +} + +void RequestRouter::ShutdownLocked(grpc_error* error) { + if (resolver_ != nullptr) { + SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN, GRPC_ERROR_REF(error), + "disconnect"); + resolver_.reset(); + if (!started_resolving_) { + grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_, + GRPC_ERROR_REF(error)); + GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_); + } + if (lb_policy_ != nullptr) { + grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(), + interested_parties_); + lb_policy_.reset(); + } + } + GRPC_ERROR_UNREF(error); +} + +grpc_connectivity_state RequestRouter::GetConnectivityState() { + return grpc_connectivity_state_check(&state_tracker_); +} + +void RequestRouter::NotifyOnConnectivityStateChange( + grpc_connectivity_state* state, grpc_closure* closure) { + grpc_connectivity_state_notify_on_state_change(&state_tracker_, state, + closure); +} + +void RequestRouter::ExitIdleLocked() { + if (lb_policy_ != nullptr) { + lb_policy_->ExitIdleLocked(); + } else { + exit_idle_when_lb_policy_arrives_ = true; + if (!started_resolving_ && resolver_ != nullptr) { + StartResolvingLocked(); + } + } +} + +void RequestRouter::ResetConnectionBackoffLocked() { + if (resolver_ != nullptr) { + resolver_->ResetBackoffLocked(); + resolver_->RequestReresolutionLocked(); + } + if (lb_policy_ != nullptr) { + lb_policy_->ResetBackoffLocked(); + } +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/request_routing.h b/src/core/ext/filters/client_channel/request_routing.h new file mode 100644 index 0000000000..0c671229c8 --- /dev/null +++ b/src/core/ext/filters/client_channel/request_routing.h @@ -0,0 +1,177 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H +#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/client_channel_channelz.h" +#include "src/core/ext/filters/client_channel/client_channel_factory.h" +#include "src/core/ext/filters/client_channel/lb_policy.h" +#include "src/core/ext/filters/client_channel/resolver.h" +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/channel/channel_stack.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/orphanable.h" +#include "src/core/lib/iomgr/call_combiner.h" +#include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/polling_entity.h" +#include "src/core/lib/iomgr/pollset_set.h" +#include "src/core/lib/transport/connectivity_state.h" +#include "src/core/lib/transport/metadata_batch.h" + +namespace grpc_core { + +class RequestRouter { + public: + class Request { + public: + // Synchronous callback that applies the service config to a call. + // Returns false if the call should be failed. + typedef bool (*ApplyServiceConfigCallback)(void* user_data); + + Request(grpc_call_stack* owning_call, grpc_call_combiner* call_combiner, + grpc_polling_entity* pollent, + grpc_metadata_batch* send_initial_metadata, + uint32_t* send_initial_metadata_flags, + ApplyServiceConfigCallback apply_service_config, + void* apply_service_config_user_data, grpc_closure* on_route_done); + + ~Request(); + + // TODO(roth): It seems a bit ugly to expose this member in a + // non-const way. Find a better API to avoid this. + LoadBalancingPolicy::PickState* pick() { return &pick_; } + + private: + friend class RequestRouter; + + class ResolverResultWaiter; + class AsyncPickCanceller; + + void ProcessServiceConfigAndStartLbPickLocked(); + void StartLbPickLocked(); + static void LbPickDoneLocked(void* arg, grpc_error* error); + + void MaybeAddCallToInterestedPartiesLocked(); + void MaybeRemoveCallFromInterestedPartiesLocked(); + + // Populated by caller. + grpc_call_stack* owning_call_; + grpc_call_combiner* call_combiner_; + grpc_polling_entity* pollent_; + ApplyServiceConfigCallback apply_service_config_; + void* apply_service_config_user_data_; + grpc_closure* on_route_done_; + LoadBalancingPolicy::PickState pick_; + + // Internal state. + RequestRouter* request_router_ = nullptr; + bool pollent_added_to_interested_parties_ = false; + grpc_closure on_pick_done_; + AsyncPickCanceller* pick_canceller_ = nullptr; + }; + + // Synchronous callback that takes the service config JSON string and + // LB policy name. + // Returns true if the service config has changed since the last result. + typedef bool (*ProcessResolverResultCallback)(void* user_data, + const grpc_channel_args& args, + const char** lb_policy_name, + grpc_json** lb_policy_config); + + RequestRouter(grpc_channel_stack* owning_stack, grpc_combiner* combiner, + grpc_client_channel_factory* client_channel_factory, + grpc_pollset_set* interested_parties, TraceFlag* tracer, + ProcessResolverResultCallback process_resolver_result, + void* process_resolver_result_user_data, const char* target_uri, + const grpc_channel_args* args, grpc_error** error); + + ~RequestRouter(); + + void set_channelz_node(channelz::ClientChannelNode* channelz_node) { + channelz_node_ = channelz_node; + } + + void RouteCallLocked(Request* request); + + // TODO(roth): Add methods to cancel picks. + + void ShutdownLocked(grpc_error* error); + + void ExitIdleLocked(); + void ResetConnectionBackoffLocked(); + + grpc_connectivity_state GetConnectivityState(); + void NotifyOnConnectivityStateChange(grpc_connectivity_state* state, + grpc_closure* closure); + + LoadBalancingPolicy* lb_policy() const { return lb_policy_.get(); } + + private: + using TraceStringVector = grpc_core::InlinedVector<char*, 3>; + + class ReresolutionRequestHandler; + class LbConnectivityWatcher; + + void StartResolvingLocked(); + void OnResolverShutdownLocked(grpc_error* error); + void CreateNewLbPolicyLocked(const char* lb_policy_name, grpc_json* lb_config, + grpc_connectivity_state* connectivity_state, + grpc_error** connectivity_error, + TraceStringVector* trace_strings); + void MaybeAddTraceMessagesForAddressChangesLocked( + TraceStringVector* trace_strings); + void ConcatenateAndAddChannelTraceLocked( + TraceStringVector* trace_strings) const; + static void OnResolverResultChangedLocked(void* arg, grpc_error* error); + + void SetConnectivityStateLocked(grpc_connectivity_state state, + grpc_error* error, const char* reason); + + // Passed in from caller at construction time. + grpc_channel_stack* owning_stack_; + grpc_combiner* combiner_; + grpc_client_channel_factory* client_channel_factory_; + grpc_pollset_set* interested_parties_; + TraceFlag* tracer_; + + channelz::ClientChannelNode* channelz_node_ = nullptr; + + // Resolver and associated state. + OrphanablePtr<Resolver> resolver_; + ProcessResolverResultCallback process_resolver_result_; + void* process_resolver_result_user_data_; + bool started_resolving_ = false; + grpc_channel_args* resolver_result_ = nullptr; + bool previous_resolution_contained_addresses_ = false; + grpc_closure_list waiting_for_resolver_result_closures_; + grpc_closure on_resolver_result_changed_; + + // LB policy and associated state. + OrphanablePtr<LoadBalancingPolicy> lb_policy_; + bool exit_idle_when_lb_policy_arrives_ = false; + + grpc_connectivity_state_tracker state_tracker_; +}; + +} // namespace grpc_core + +#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H */ diff --git a/src/core/ext/filters/client_channel/resolver.cc b/src/core/ext/filters/client_channel/resolver.cc index cd11eeb9e4..601b08be24 100644 --- a/src/core/ext/filters/client_channel/resolver.cc +++ b/src/core/ext/filters/client_channel/resolver.cc @@ -27,7 +27,7 @@ grpc_core::DebugOnlyTraceFlag grpc_trace_resolver_refcount(false, namespace grpc_core { Resolver::Resolver(grpc_combiner* combiner) - : InternallyRefCountedWithTracing(&grpc_trace_resolver_refcount), + : InternallyRefCounted(&grpc_trace_resolver_refcount), combiner_(GRPC_COMBINER_REF(combiner, "resolver")) {} Resolver::~Resolver() { GRPC_COMBINER_UNREF(combiner_, "resolver"); } diff --git a/src/core/ext/filters/client_channel/resolver.h b/src/core/ext/filters/client_channel/resolver.h index e9acbb7c41..9da849a101 100644 --- a/src/core/ext/filters/client_channel/resolver.h +++ b/src/core/ext/filters/client_channel/resolver.h @@ -44,7 +44,7 @@ namespace grpc_core { /// /// Note: All methods with a "Locked" suffix must be called from the /// combiner passed to the constructor. -class Resolver : public InternallyRefCountedWithTracing<Resolver> { +class Resolver : public InternallyRefCounted<Resolver> { public: // Not copyable nor movable. Resolver(const Resolver&) = delete; diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc index 01796ca08f..abacd0c960 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc @@ -33,6 +33,7 @@ #include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/env.h" @@ -117,9 +118,13 @@ class AresDnsResolver : public Resolver { /// retry backoff state BackOff backoff_; /// currently resolving addresses - grpc_lb_addresses* lb_addresses_ = nullptr; + UniquePtr<ServerAddressList> addresses_; /// currently resolving service config char* service_config_json_ = nullptr; + // has shutdown been initiated + bool shutdown_initiated_ = false; + // timeout in milliseconds for active DNS queries + int query_timeout_ms_; }; AresDnsResolver::AresDnsResolver(const ResolverArgs& args) @@ -157,10 +162,15 @@ AresDnsResolver::AresDnsResolver(const ResolverArgs& args) grpc_combiner_scheduler(combiner())); GRPC_CLOSURE_INIT(&on_resolved_, OnResolvedLocked, this, grpc_combiner_scheduler(combiner())); + const grpc_arg* query_timeout_ms_arg = + grpc_channel_args_find(channel_args_, GRPC_ARG_DNS_ARES_QUERY_TIMEOUT_MS); + query_timeout_ms_ = grpc_channel_arg_get_integer( + query_timeout_ms_arg, + {GRPC_DNS_ARES_DEFAULT_QUERY_TIMEOUT_MS, 0, INT_MAX}); } AresDnsResolver::~AresDnsResolver() { - gpr_log(GPR_DEBUG, "destroying AresDnsResolver"); + GRPC_CARES_TRACE_LOG("resolver:%p destroying AresDnsResolver", this); if (resolved_result_ != nullptr) { grpc_channel_args_destroy(resolved_result_); } @@ -172,7 +182,8 @@ AresDnsResolver::~AresDnsResolver() { void AresDnsResolver::NextLocked(grpc_channel_args** target_result, grpc_closure* on_complete) { - gpr_log(GPR_DEBUG, "AresDnsResolver::NextLocked() is called."); + GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::NextLocked() is called.", + this); GPR_ASSERT(next_completion_ == nullptr); next_completion_ = on_complete; target_result_ = target_result; @@ -197,6 +208,7 @@ void AresDnsResolver::ResetBackoffLocked() { } void AresDnsResolver::ShutdownLocked() { + shutdown_initiated_ = true; if (have_next_resolution_timer_) { grpc_timer_cancel(&next_resolution_timer_); } @@ -213,9 +225,15 @@ void AresDnsResolver::ShutdownLocked() { void AresDnsResolver::OnNextResolutionLocked(void* arg, grpc_error* error) { AresDnsResolver* r = static_cast<AresDnsResolver*>(arg); + GRPC_CARES_TRACE_LOG( + "resolver:%p re-resolution timer fired. error: %s. shutdown_initiated_: " + "%d", + r, grpc_error_string(error), r->shutdown_initiated_); r->have_next_resolution_timer_ = false; - if (error == GRPC_ERROR_NONE) { + if (error == GRPC_ERROR_NONE && !r->shutdown_initiated_) { if (!r->resolving_) { + GRPC_CARES_TRACE_LOG( + "resolver:%p start resolving due to re-resolution timer", r); r->StartResolvingLocked(); } } @@ -300,53 +318,40 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { r->resolving_ = false; gpr_free(r->pending_request_); r->pending_request_ = nullptr; - if (r->lb_addresses_ != nullptr) { - static const char* args_to_remove[2]; + if (r->addresses_ != nullptr) { + static const char* args_to_remove[1]; size_t num_args_to_remove = 0; - grpc_arg new_args[3]; + grpc_arg args_to_add[2]; size_t num_args_to_add = 0; - new_args[num_args_to_add++] = - grpc_lb_addresses_create_channel_arg(r->lb_addresses_); - grpc_core::UniquePtr<grpc_core::ServiceConfig> service_config; + args_to_add[num_args_to_add++] = + CreateServerAddressListChannelArg(r->addresses_.get()); char* service_config_string = nullptr; if (r->service_config_json_ != nullptr) { service_config_string = ChooseServiceConfig(r->service_config_json_); gpr_free(r->service_config_json_); if (service_config_string != nullptr) { - gpr_log(GPR_INFO, "selected service config choice: %s", - service_config_string); + GRPC_CARES_TRACE_LOG("resolver:%p selected service config choice: %s", + r, service_config_string); args_to_remove[num_args_to_remove++] = GRPC_ARG_SERVICE_CONFIG; - new_args[num_args_to_add++] = grpc_channel_arg_string_create( + args_to_add[num_args_to_add++] = grpc_channel_arg_string_create( (char*)GRPC_ARG_SERVICE_CONFIG, service_config_string); - service_config = - grpc_core::ServiceConfig::Create(service_config_string); - if (service_config != nullptr) { - const char* lb_policy_name = - service_config->GetLoadBalancingPolicyName(); - if (lb_policy_name != nullptr) { - args_to_remove[num_args_to_remove++] = GRPC_ARG_LB_POLICY_NAME; - new_args[num_args_to_add++] = grpc_channel_arg_string_create( - (char*)GRPC_ARG_LB_POLICY_NAME, - const_cast<char*>(lb_policy_name)); - } - } } } result = grpc_channel_args_copy_and_add_and_remove( - r->channel_args_, args_to_remove, num_args_to_remove, new_args, + r->channel_args_, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); gpr_free(service_config_string); - grpc_lb_addresses_destroy(r->lb_addresses_); + r->addresses_.reset(); // Reset backoff state so that we start from the beginning when the // next request gets triggered. r->backoff_.Reset(); - } else { + } else if (!r->shutdown_initiated_) { const char* msg = grpc_error_string(error); - gpr_log(GPR_DEBUG, "dns resolution failed: %s", msg); + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed: %s", r, msg); grpc_millis next_try = r->backoff_.NextAttemptTime(); grpc_millis timeout = next_try - ExecCtx::Get()->Now(); - gpr_log(GPR_INFO, "dns resolution failed (will retry): %s", - grpc_error_string(error)); + GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed (will retry): %s", + r, grpc_error_string(error)); GPR_ASSERT(!r->have_next_resolution_timer_); r->have_next_resolution_timer_ = true; // TODO(roth): We currently deal with this ref manually. Once the @@ -355,9 +360,10 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { RefCountedPtr<Resolver> self = r->Ref(DEBUG_LOCATION, "retry-timer"); self.release(); if (timeout > 0) { - gpr_log(GPR_DEBUG, "retrying in %" PRId64 " milliseconds", timeout); + GRPC_CARES_TRACE_LOG("resolver:%p retrying in %" PRId64 " milliseconds", + r, timeout); } else { - gpr_log(GPR_DEBUG, "retrying immediately"); + GRPC_CARES_TRACE_LOG("resolver:%p retrying immediately", r); } grpc_timer_init(&r->next_resolution_timer_, next_try, &r->on_next_resolution_); @@ -383,10 +389,10 @@ void AresDnsResolver::MaybeStartResolvingLocked() { if (ms_until_next_resolution > 0) { const grpc_millis last_resolution_ago = grpc_core::ExecCtx::Get()->Now() - last_resolution_timestamp_; - gpr_log(GPR_DEBUG, - "In cooldown from last resolution (from %" PRId64 - " ms ago). Will resolve again in %" PRId64 " ms", - last_resolution_ago, ms_until_next_resolution); + GRPC_CARES_TRACE_LOG( + "resolver:%p In cooldown from last resolution (from %" PRId64 + " ms ago). Will resolve again in %" PRId64 " ms", + this, last_resolution_ago, ms_until_next_resolution); have_next_resolution_timer_ = true; // TODO(roth): We currently deal with this ref manually. Once the // new closure API is done, find a way to track this ref with the timer @@ -403,7 +409,6 @@ void AresDnsResolver::MaybeStartResolvingLocked() { } void AresDnsResolver::StartResolvingLocked() { - gpr_log(GPR_DEBUG, "Start resolving."); // TODO(roth): We currently deal with this ref manually. Once the // new closure API is done, find a way to track this ref with the timer // callback as part of the type system. @@ -411,13 +416,15 @@ void AresDnsResolver::StartResolvingLocked() { self.release(); GPR_ASSERT(!resolving_); resolving_ = true; - lb_addresses_ = nullptr; service_config_json_ = nullptr; pending_request_ = grpc_dns_lookup_ares_locked( dns_server_, name_to_resolve_, kDefaultPort, interested_parties_, - &on_resolved_, &lb_addresses_, true /* check_grpclb */, - request_service_config_ ? &service_config_json_ : nullptr, combiner()); + &on_resolved_, &addresses_, true /* check_grpclb */, + request_service_config_ ? &service_config_json_ : nullptr, + query_timeout_ms_, combiner()); last_resolution_timestamp_ = grpc_core::ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG("resolver:%p Started resolving. pending_request_:%p", + this, pending_request_); } void AresDnsResolver::MaybeFinishNextLocked() { @@ -425,7 +432,8 @@ void AresDnsResolver::MaybeFinishNextLocked() { *target_result_ = resolved_result_ == nullptr ? nullptr : grpc_channel_args_copy(resolved_result_); - gpr_log(GPR_DEBUG, "AresDnsResolver::MaybeFinishNextLocked()"); + GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::MaybeFinishNextLocked()", + this); GRPC_CLOSURE_SCHED(next_completion_, GRPC_ERROR_NONE); next_completion_ = nullptr; published_version_ = resolved_version_; @@ -463,11 +471,16 @@ static grpc_error* blocking_resolve_address_ares( static grpc_address_resolver_vtable ares_resolver = { grpc_resolve_address_ares, blocking_resolve_address_ares}; +static bool should_use_ares(const char* resolver_env) { + return resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0; +} + void grpc_resolver_dns_ares_init() { char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER"); /* TODO(zyc): Turn on c-ares based resolver by default after the address sorter and the CNAME support are added. */ - if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + if (should_use_ares(resolver_env)) { + gpr_log(GPR_DEBUG, "Using ares dns resolver"); address_sorting_init(); grpc_error* error = grpc_ares_init(); if (error != GRPC_ERROR_NONE) { @@ -487,7 +500,7 @@ void grpc_resolver_dns_ares_init() { void grpc_resolver_dns_ares_shutdown() { char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER"); - if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) { + if (should_use_ares(resolver_env)) { address_sorting_shutdown(); grpc_ares_cleanup(); } diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc index fdbd07ebf5..d99c2e3004 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc @@ -31,8 +31,10 @@ #include <grpc/support/time.h> #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/combiner.h" #include "src/core/lib/iomgr/iomgr_internal.h" #include "src/core/lib/iomgr/sockaddr_utils.h" +#include "src/core/lib/iomgr/timer.h" typedef struct fd_node { /** the owner of this fd node */ @@ -76,21 +78,30 @@ struct grpc_ares_ev_driver { grpc_ares_request* request; /** Owned by the ev_driver. Creates new GrpcPolledFd's */ grpc_core::UniquePtr<grpc_core::GrpcPolledFdFactory> polled_fd_factory; + /** query timeout in milliseconds */ + int query_timeout_ms; + /** alarm to cancel active queries */ + grpc_timer query_timeout; + /** cancels queries on a timeout */ + grpc_closure on_timeout_locked; }; static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver); static grpc_ares_ev_driver* grpc_ares_ev_driver_ref( grpc_ares_ev_driver* ev_driver) { - gpr_log(GPR_DEBUG, "Ref ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p Ref ev_driver %p", ev_driver->request, + ev_driver); gpr_ref(&ev_driver->refs); return ev_driver; } static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) { - gpr_log(GPR_DEBUG, "Unref ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p Unref ev_driver %p", ev_driver->request, + ev_driver); if (gpr_unref(&ev_driver->refs)) { - gpr_log(GPR_DEBUG, "destroy ev_driver %" PRIuPTR, (uintptr_t)ev_driver); + GRPC_CARES_TRACE_LOG("request:%p destroy ev_driver %p", ev_driver->request, + ev_driver); GPR_ASSERT(ev_driver->fds == nullptr); GRPC_COMBINER_UNREF(ev_driver->combiner, "free ares event driver"); ares_destroy(ev_driver->channel); @@ -100,7 +111,8 @@ static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) { } static void fd_node_destroy_locked(fd_node* fdn) { - gpr_log(GPR_DEBUG, "delete fd: %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p delete fd: %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); GPR_ASSERT(!fdn->readable_registered); GPR_ASSERT(!fdn->writable_registered); GPR_ASSERT(fdn->already_shutdown); @@ -116,8 +128,11 @@ static void fd_node_shutdown_locked(fd_node* fdn, const char* reason) { } } +static void on_timeout_locked(void* arg, grpc_error* error); + grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver, grpc_pollset_set* pollset_set, + int query_timeout_ms, grpc_combiner* combiner, grpc_ares_request* request) { *ev_driver = grpc_core::New<grpc_ares_ev_driver>(); @@ -125,7 +140,7 @@ grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver, memset(&opts, 0, sizeof(opts)); opts.flags |= ARES_FLAG_STAYOPEN; int status = ares_init_options(&(*ev_driver)->channel, &opts, ARES_OPT_FLAGS); - gpr_log(GPR_DEBUG, "grpc_ares_ev_driver_create_locked"); + GRPC_CARES_TRACE_LOG("request:%p grpc_ares_ev_driver_create_locked", request); if (status != ARES_SUCCESS) { char* err_msg; gpr_asprintf(&err_msg, "Failed to init ares channel. C-ares error: %s", @@ -146,6 +161,9 @@ grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver, grpc_core::NewGrpcPolledFdFactory((*ev_driver)->combiner); (*ev_driver) ->polled_fd_factory->ConfigureAresChannelLocked((*ev_driver)->channel); + GRPC_CLOSURE_INIT(&(*ev_driver)->on_timeout_locked, on_timeout_locked, + *ev_driver, grpc_combiner_scheduler(combiner)); + (*ev_driver)->query_timeout_ms = query_timeout_ms; return GRPC_ERROR_NONE; } @@ -155,6 +173,7 @@ void grpc_ares_ev_driver_on_queries_complete_locked( // is working, grpc_ares_notify_on_event_locked will shut down the // fds; if it's not working, there are no fds to shut down. ev_driver->shutting_down = true; + grpc_timer_cancel(&ev_driver->query_timeout); grpc_ares_ev_driver_unref(ev_driver); } @@ -185,12 +204,25 @@ static fd_node* pop_fd_node_locked(fd_node** head, ares_socket_t as) { return nullptr; } +static void on_timeout_locked(void* arg, grpc_error* error) { + grpc_ares_ev_driver* driver = static_cast<grpc_ares_ev_driver*>(arg); + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p on_timeout_locked. driver->shutting_down=%d. " + "err=%s", + driver->request, driver, driver->shutting_down, grpc_error_string(error)); + if (!driver->shutting_down && error == GRPC_ERROR_NONE) { + grpc_ares_ev_driver_shutdown_locked(driver); + } + grpc_ares_ev_driver_unref(driver); +} + static void on_readable_locked(void* arg, grpc_error* error) { fd_node* fdn = static_cast<fd_node*>(arg); grpc_ares_ev_driver* ev_driver = fdn->ev_driver; const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); fdn->readable_registered = false; - gpr_log(GPR_DEBUG, "readable on %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p readable on %s", fdn->ev_driver->request, + fdn->grpc_polled_fd->GetName()); if (error == GRPC_ERROR_NONE) { do { ares_process_fd(ev_driver->channel, as, ARES_SOCKET_BAD); @@ -213,7 +245,8 @@ static void on_writable_locked(void* arg, grpc_error* error) { grpc_ares_ev_driver* ev_driver = fdn->ev_driver; const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked(); fdn->writable_registered = false; - gpr_log(GPR_DEBUG, "writable on %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p writable on %s", ev_driver->request, + fdn->grpc_polled_fd->GetName()); if (error == GRPC_ERROR_NONE) { ares_process_fd(ev_driver->channel, ARES_SOCKET_BAD, as); } else { @@ -252,7 +285,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { fdn->grpc_polled_fd = ev_driver->polled_fd_factory->NewGrpcPolledFdLocked( socks[i], ev_driver->pollset_set, ev_driver->combiner); - gpr_log(GPR_DEBUG, "new fd: %s", fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p new fd: %s", ev_driver->request, + fdn->grpc_polled_fd->GetName()); fdn->ev_driver = ev_driver; fdn->readable_registered = false; fdn->writable_registered = false; @@ -269,8 +303,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { if (ARES_GETSOCK_READABLE(socks_bitmask, i) && !fdn->readable_registered) { grpc_ares_ev_driver_ref(ev_driver); - gpr_log(GPR_DEBUG, "notify read on: %s", - fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p notify read on: %s", + ev_driver->request, + fdn->grpc_polled_fd->GetName()); fdn->grpc_polled_fd->RegisterForOnReadableLocked(&fdn->read_closure); fdn->readable_registered = true; } @@ -278,8 +313,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { // has not been registered with this socket. if (ARES_GETSOCK_WRITABLE(socks_bitmask, i) && !fdn->writable_registered) { - gpr_log(GPR_DEBUG, "notify write on: %s", - fdn->grpc_polled_fd->GetName()); + GRPC_CARES_TRACE_LOG("request:%p notify write on: %s", + ev_driver->request, + fdn->grpc_polled_fd->GetName()); grpc_ares_ev_driver_ref(ev_driver); fdn->grpc_polled_fd->RegisterForOnWriteableLocked( &fdn->write_closure); @@ -306,7 +342,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) { // If the ev driver has no working fd, all the tasks are done. if (new_list == nullptr) { ev_driver->working = false; - gpr_log(GPR_DEBUG, "ev driver stop working"); + GRPC_CARES_TRACE_LOG("request:%p ev driver stop working", + ev_driver->request); } } @@ -314,6 +351,17 @@ void grpc_ares_ev_driver_start_locked(grpc_ares_ev_driver* ev_driver) { if (!ev_driver->working) { ev_driver->working = true; grpc_ares_notify_on_event_locked(ev_driver); + grpc_millis timeout = + ev_driver->query_timeout_ms == 0 + ? GRPC_MILLIS_INF_FUTURE + : ev_driver->query_timeout_ms + grpc_core::ExecCtx::Get()->Now(); + GRPC_CARES_TRACE_LOG( + "request:%p ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in " + "%" PRId64 " ms", + ev_driver->request, ev_driver, timeout); + grpc_ares_ev_driver_ref(ev_driver); + grpc_timer_init(&ev_driver->query_timeout, timeout, + &ev_driver->on_timeout_locked); } } diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h index 671c537fe7..b8cefd9470 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h @@ -43,6 +43,7 @@ ares_channel* grpc_ares_ev_driver_get_channel_locked( created successfully. */ grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver, grpc_pollset_set* pollset_set, + int query_timeout_ms, grpc_combiner* combiner, grpc_ares_request* request); diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc index 582e2203fc..1a7e5d0626 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc @@ -37,12 +37,16 @@ #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/iomgr/combiner.h" #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/executor.h" #include "src/core/lib/iomgr/iomgr_internal.h" #include "src/core/lib/iomgr/nameser.h" #include "src/core/lib/iomgr/sockaddr_utils.h" +using grpc_core::ServerAddress; +using grpc_core::ServerAddressList; + static gpr_once g_basic_init = GPR_ONCE_INIT; static gpr_mu g_init_mu; @@ -58,7 +62,7 @@ struct grpc_ares_request { /** closure to call when the request completes */ grpc_closure* on_done; /** the pointer to receive the resolved addresses */ - grpc_lb_addresses** lb_addrs_out; + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addresses_out; /** the pointer to receive the service config in JSON */ char** service_config_json_out; /** the evernt driver used by this request */ @@ -87,46 +91,44 @@ typedef struct grpc_ares_hostbyname_request { static void do_basic_init(void) { gpr_mu_init(&g_init_mu); } -static void log_address_sorting_list(grpc_lb_addresses* lb_addrs, +static void log_address_sorting_list(const ServerAddressList& addresses, const char* input_output_str) { - for (size_t i = 0; i < lb_addrs->num_addresses; i++) { + for (size_t i = 0; i < addresses.size(); i++) { char* addr_str; - if (grpc_sockaddr_to_string(&addr_str, &lb_addrs->addresses[i].address, - true)) { - gpr_log(GPR_DEBUG, "c-ares address sorting: %s[%" PRIuPTR "]=%s", + if (grpc_sockaddr_to_string(&addr_str, &addresses[i].address(), true)) { + gpr_log(GPR_INFO, "c-ares address sorting: %s[%" PRIuPTR "]=%s", input_output_str, i, addr_str); gpr_free(addr_str); } else { - gpr_log(GPR_DEBUG, + gpr_log(GPR_INFO, "c-ares address sorting: %s[%" PRIuPTR "]=<unprintable>", input_output_str, i); } } } -void grpc_cares_wrapper_address_sorting_sort(grpc_lb_addresses* lb_addrs) { +void grpc_cares_wrapper_address_sorting_sort(ServerAddressList* addresses) { if (grpc_trace_cares_address_sorting.enabled()) { - log_address_sorting_list(lb_addrs, "input"); + log_address_sorting_list(*addresses, "input"); } address_sorting_sortable* sortables = (address_sorting_sortable*)gpr_zalloc( - sizeof(address_sorting_sortable) * lb_addrs->num_addresses); - for (size_t i = 0; i < lb_addrs->num_addresses; i++) { - sortables[i].user_data = &lb_addrs->addresses[i]; - memcpy(&sortables[i].dest_addr.addr, &lb_addrs->addresses[i].address.addr, - lb_addrs->addresses[i].address.len); - sortables[i].dest_addr.len = lb_addrs->addresses[i].address.len; + sizeof(address_sorting_sortable) * addresses->size()); + for (size_t i = 0; i < addresses->size(); ++i) { + sortables[i].user_data = &(*addresses)[i]; + memcpy(&sortables[i].dest_addr.addr, &(*addresses)[i].address().addr, + (*addresses)[i].address().len); + sortables[i].dest_addr.len = (*addresses)[i].address().len; } - address_sorting_rfc_6724_sort(sortables, lb_addrs->num_addresses); - grpc_lb_address* sorted_lb_addrs = (grpc_lb_address*)gpr_zalloc( - sizeof(grpc_lb_address) * lb_addrs->num_addresses); - for (size_t i = 0; i < lb_addrs->num_addresses; i++) { - sorted_lb_addrs[i] = *(grpc_lb_address*)sortables[i].user_data; + address_sorting_rfc_6724_sort(sortables, addresses->size()); + ServerAddressList sorted; + sorted.reserve(addresses->size()); + for (size_t i = 0; i < addresses->size(); ++i) { + sorted.emplace_back(*static_cast<ServerAddress*>(sortables[i].user_data)); } gpr_free(sortables); - gpr_free(lb_addrs->addresses); - lb_addrs->addresses = sorted_lb_addrs; + *addresses = std::move(sorted); if (grpc_trace_cares_address_sorting.enabled()) { - log_address_sorting_list(lb_addrs, "output"); + log_address_sorting_list(*addresses, "output"); } } @@ -145,9 +147,9 @@ void grpc_ares_complete_request_locked(grpc_ares_request* r) { /* Invoke on_done callback and destroy the request */ r->ev_driver = nullptr; - grpc_lb_addresses* lb_addrs = *(r->lb_addrs_out); - if (lb_addrs != nullptr) { - grpc_cares_wrapper_address_sorting_sort(lb_addrs); + ServerAddressList* addresses = r->addresses_out->get(); + if (addresses != nullptr) { + grpc_cares_wrapper_address_sorting_sort(addresses); } GRPC_CLOSURE_SCHED(r->on_done, r->error); } @@ -181,60 +183,53 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts, GRPC_ERROR_UNREF(r->error); r->error = GRPC_ERROR_NONE; r->success = true; - grpc_lb_addresses** lb_addresses = r->lb_addrs_out; - if (*lb_addresses == nullptr) { - *lb_addresses = grpc_lb_addresses_create(0, nullptr); - } - size_t prev_naddr = (*lb_addresses)->num_addresses; - size_t i; - for (i = 0; hostent->h_addr_list[i] != nullptr; i++) { + if (*r->addresses_out == nullptr) { + *r->addresses_out = grpc_core::MakeUnique<ServerAddressList>(); } - (*lb_addresses)->num_addresses += i; - (*lb_addresses)->addresses = static_cast<grpc_lb_address*>( - gpr_realloc((*lb_addresses)->addresses, - sizeof(grpc_lb_address) * (*lb_addresses)->num_addresses)); - for (i = prev_naddr; i < (*lb_addresses)->num_addresses; i++) { + ServerAddressList& addresses = **r->addresses_out; + for (size_t i = 0; hostent->h_addr_list[i] != nullptr; ++i) { + grpc_core::InlinedVector<grpc_arg, 2> args_to_add; + if (hr->is_balancer) { + args_to_add.emplace_back(grpc_channel_arg_integer_create( + const_cast<char*>(GRPC_ARG_ADDRESS_IS_BALANCER), 1)); + args_to_add.emplace_back(grpc_channel_arg_string_create( + const_cast<char*>(GRPC_ARG_ADDRESS_BALANCER_NAME), hr->host)); + } + grpc_channel_args* args = grpc_channel_args_copy_and_add( + nullptr, args_to_add.data(), args_to_add.size()); switch (hostent->h_addrtype) { case AF_INET6: { size_t addr_len = sizeof(struct sockaddr_in6); struct sockaddr_in6 addr; memset(&addr, 0, addr_len); - memcpy(&addr.sin6_addr, hostent->h_addr_list[i - prev_naddr], + memcpy(&addr.sin6_addr, hostent->h_addr_list[i], sizeof(struct in6_addr)); addr.sin6_family = static_cast<unsigned char>(hostent->h_addrtype); addr.sin6_port = hr->port; - grpc_lb_addresses_set_address( - *lb_addresses, i, &addr, addr_len, - hr->is_balancer /* is_balancer */, - hr->is_balancer ? hr->host : nullptr /* balancer_name */, - nullptr /* user_data */); + addresses.emplace_back(&addr, addr_len, args); char output[INET6_ADDRSTRLEN]; ares_inet_ntop(AF_INET6, &addr.sin6_addr, output, INET6_ADDRSTRLEN); - gpr_log(GPR_DEBUG, - "c-ares resolver gets a AF_INET6 result: \n" - " addr: %s\n port: %d\n sin6_scope_id: %d\n", - output, ntohs(hr->port), addr.sin6_scope_id); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET6 result: \n" + " addr: %s\n port: %d\n sin6_scope_id: %d\n", + r, output, ntohs(hr->port), addr.sin6_scope_id); break; } case AF_INET: { size_t addr_len = sizeof(struct sockaddr_in); struct sockaddr_in addr; memset(&addr, 0, addr_len); - memcpy(&addr.sin_addr, hostent->h_addr_list[i - prev_naddr], + memcpy(&addr.sin_addr, hostent->h_addr_list[i], sizeof(struct in_addr)); addr.sin_family = static_cast<unsigned char>(hostent->h_addrtype); addr.sin_port = hr->port; - grpc_lb_addresses_set_address( - *lb_addresses, i, &addr, addr_len, - hr->is_balancer /* is_balancer */, - hr->is_balancer ? hr->host : nullptr /* balancer_name */, - nullptr /* user_data */); + addresses.emplace_back(&addr, addr_len, args); char output[INET_ADDRSTRLEN]; ares_inet_ntop(AF_INET, &addr.sin_addr, output, INET_ADDRSTRLEN); - gpr_log(GPR_DEBUG, - "c-ares resolver gets a AF_INET result: \n" - " addr: %s\n port: %d\n", - output, ntohs(hr->port)); + GRPC_CARES_TRACE_LOG( + "request:%p c-ares resolver gets a AF_INET result: \n" + " addr: %s\n port: %d\n", + r, output, ntohs(hr->port)); break; } } @@ -257,9 +252,9 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts, static void on_srv_query_done_locked(void* arg, int status, int timeouts, unsigned char* abuf, int alen) { grpc_ares_request* r = static_cast<grpc_ares_request*>(arg); - gpr_log(GPR_DEBUG, "on_query_srv_done_locked"); + GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked", r); if (status == ARES_SUCCESS) { - gpr_log(GPR_DEBUG, "on_query_srv_done_locked ARES_SUCCESS"); + GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked ARES_SUCCESS", r); struct ares_srv_reply* reply; const int parse_status = ares_parse_srv_reply(abuf, alen, &reply); if (parse_status == ARES_SUCCESS) { @@ -302,9 +297,9 @@ static const char g_service_config_attribute_prefix[] = "grpc_config="; static void on_txt_done_locked(void* arg, int status, int timeouts, unsigned char* buf, int len) { - gpr_log(GPR_DEBUG, "on_txt_done_locked"); char* error_msg; grpc_ares_request* r = static_cast<grpc_ares_request*>(arg); + GRPC_CARES_TRACE_LOG("request:%p on_txt_done_locked", r); const size_t prefix_len = sizeof(g_service_config_attribute_prefix) - 1; struct ares_txt_ext* result = nullptr; struct ares_txt_ext* reply = nullptr; @@ -337,7 +332,8 @@ static void on_txt_done_locked(void* arg, int status, int timeouts, service_config_len += result->length; } (*r->service_config_json_out)[service_config_len] = '\0'; - gpr_log(GPR_INFO, "found service config: %s", *r->service_config_json_out); + GRPC_CARES_TRACE_LOG("request:%p found service config: %s", r, + *r->service_config_json_out); } // Clean up. ares_free_data(reply); @@ -359,16 +355,10 @@ done: void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( grpc_ares_request* r, const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, - bool check_grpclb, grpc_combiner* combiner) { + bool check_grpclb, int query_timeout_ms, grpc_combiner* combiner) { grpc_error* error = GRPC_ERROR_NONE; grpc_ares_hostbyname_request* hr = nullptr; ares_channel* channel = nullptr; - /* TODO(zyc): Enable tracing after #9603 is checked in */ - /* if (grpc_dns_trace) { - gpr_log(GPR_DEBUG, "resolve_address (blocking): name=%s, default_port=%s", - name, default_port); - } */ - /* parse name, splitting it into host and port parts */ char* host; char* port; @@ -388,12 +378,12 @@ void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( port = gpr_strdup(default_port); } error = grpc_ares_ev_driver_create_locked(&r->ev_driver, interested_parties, - combiner, r); + query_timeout_ms, combiner, r); if (error != GRPC_ERROR_NONE) goto error_cleanup; channel = grpc_ares_ev_driver_get_channel_locked(r->ev_driver); // If dns_server is specified, use it. if (dns_server != nullptr) { - gpr_log(GPR_INFO, "Using DNS server %s", dns_server); + GRPC_CARES_TRACE_LOG("request:%p Using DNS server %s", r, dns_server); grpc_resolved_address addr; if (grpc_parse_ipv4_hostport(dns_server, &addr, false /* log_errors */)) { r->dns_server_addr.family = AF_INET; @@ -467,11 +457,10 @@ error_cleanup: gpr_free(port); } -static bool inner_resolve_as_ip_literal_locked(const char* name, - const char* default_port, - grpc_lb_addresses** addrs, - char** host, char** port, - char** hostport) { +static bool inner_resolve_as_ip_literal_locked( + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, char** host, + char** port, char** hostport) { gpr_split_host_port(name, host, port); if (*host == nullptr) { gpr_log(GPR_ERROR, @@ -495,18 +484,16 @@ static bool inner_resolve_as_ip_literal_locked(const char* name, if (grpc_parse_ipv4_hostport(*hostport, &addr, false /* log errors */) || grpc_parse_ipv6_hostport(*hostport, &addr, false /* log errors */)) { GPR_ASSERT(*addrs == nullptr); - *addrs = grpc_lb_addresses_create(1, nullptr); - grpc_lb_addresses_set_address( - *addrs, 0, addr.addr, addr.len, false /* is_balancer */, - nullptr /* balancer_name */, nullptr /* user_data */); + *addrs = grpc_core::MakeUnique<ServerAddressList>(); + (*addrs)->emplace_back(addr.addr, addr.len, nullptr /* args */); return true; } return false; } -static bool resolve_as_ip_literal_locked(const char* name, - const char* default_port, - grpc_lb_addresses** addrs) { +static bool resolve_as_ip_literal_locked( + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs) { char* host = nullptr; char* port = nullptr; char* hostport = nullptr; @@ -518,20 +505,47 @@ static bool resolve_as_ip_literal_locked(const char* name, return out; } +static bool target_matches_localhost_inner(const char* name, char** host, + char** port) { + if (!gpr_split_host_port(name, host, port)) { + gpr_log(GPR_ERROR, "Unable to split host and port for name: %s", name); + return false; + } + if (gpr_stricmp(*host, "localhost") == 0) { + return true; + } else { + return false; + } +} + +static bool target_matches_localhost(const char* name) { + char* host = nullptr; + char* port = nullptr; + bool out = target_matches_localhost_inner(name, &host, &port); + gpr_free(host); + gpr_free(port); + return out; +} + static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, - grpc_lb_addresses** addrs, bool check_grpclb, char** service_config_json, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, + bool check_grpclb, char** service_config_json, int query_timeout_ms, grpc_combiner* combiner) { grpc_ares_request* r = static_cast<grpc_ares_request*>(gpr_zalloc(sizeof(grpc_ares_request))); r->ev_driver = nullptr; r->on_done = on_done; - r->lb_addrs_out = addrs; + r->addresses_out = addrs; r->service_config_json_out = service_config_json; r->success = false; r->error = GRPC_ERROR_NONE; r->pending_queries = 0; + GRPC_CARES_TRACE_LOG( + "request:%p c-ares grpc_dns_lookup_ares_locked_impl name=%s, " + "default_port=%s", + r, name, default_port); // Early out if the target is an ipv4 or ipv6 literal. if (resolve_as_ip_literal_locked(name, default_port, addrs)) { GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE); @@ -543,17 +557,25 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE); return r; } + // Don't query for SRV and TXT records if the target is "localhost", so + // as to cut down on lookups over the network, especially in tests: + // https://github.com/grpc/proposal/pull/79 + if (target_matches_localhost(name)) { + check_grpclb = false; + r->service_config_json_out = nullptr; + } // Look up name using c-ares lib. grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked( r, dns_server, name, default_port, interested_parties, check_grpclb, - combiner); + query_timeout_ms, combiner); return r; } grpc_ares_request* (*grpc_dns_lookup_ares_locked)( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, - grpc_lb_addresses** addrs, bool check_grpclb, char** service_config_json, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, + bool check_grpclb, char** service_config_json, int query_timeout_ms, grpc_combiner* combiner) = grpc_dns_lookup_ares_locked_impl; static void grpc_cancel_ares_request_locked_impl(grpc_ares_request* r) { @@ -598,8 +620,8 @@ typedef struct grpc_resolve_address_ares_request { grpc_combiner* combiner; /** the pointer to receive the resolved addresses */ grpc_resolved_addresses** addrs_out; - /** currently resolving lb addresses */ - grpc_lb_addresses* lb_addrs; + /** currently resolving addresses */ + grpc_core::UniquePtr<ServerAddressList> addresses; /** closure to call when the resolve_address_ares request completes */ grpc_closure* on_resolve_address_done; /** a closure wrapping on_resolve_address_done, which should be invoked when @@ -612,7 +634,7 @@ typedef struct grpc_resolve_address_ares_request { /* pollset_set to be driven by */ grpc_pollset_set* interested_parties; /* underlying ares_request that the query is performed on */ - grpc_ares_request* ares_request; + grpc_ares_request* ares_request = nullptr; } grpc_resolve_address_ares_request; static void on_dns_lookup_done_locked(void* arg, grpc_error* error) { @@ -620,25 +642,24 @@ static void on_dns_lookup_done_locked(void* arg, grpc_error* error) { static_cast<grpc_resolve_address_ares_request*>(arg); gpr_free(r->ares_request); grpc_resolved_addresses** resolved_addresses = r->addrs_out; - if (r->lb_addrs == nullptr || r->lb_addrs->num_addresses == 0) { + if (r->addresses == nullptr || r->addresses->empty()) { *resolved_addresses = nullptr; } else { *resolved_addresses = static_cast<grpc_resolved_addresses*>( gpr_zalloc(sizeof(grpc_resolved_addresses))); - (*resolved_addresses)->naddrs = r->lb_addrs->num_addresses; + (*resolved_addresses)->naddrs = r->addresses->size(); (*resolved_addresses)->addrs = static_cast<grpc_resolved_address*>(gpr_zalloc( sizeof(grpc_resolved_address) * (*resolved_addresses)->naddrs)); - for (size_t i = 0; i < (*resolved_addresses)->naddrs; i++) { - GPR_ASSERT(!r->lb_addrs->addresses[i].is_balancer); - memcpy(&(*resolved_addresses)->addrs[i], - &r->lb_addrs->addresses[i].address, sizeof(grpc_resolved_address)); + for (size_t i = 0; i < (*resolved_addresses)->naddrs; ++i) { + GPR_ASSERT(!(*r->addresses)[i].IsBalancer()); + memcpy(&(*resolved_addresses)->addrs[i], &(*r->addresses)[i].address(), + sizeof(grpc_resolved_address)); } } GRPC_CLOSURE_SCHED(r->on_resolve_address_done, GRPC_ERROR_REF(error)); - if (r->lb_addrs != nullptr) grpc_lb_addresses_destroy(r->lb_addrs); GRPC_COMBINER_UNREF(r->combiner, "on_dns_lookup_done_cb"); - gpr_free(r); + grpc_core::Delete(r); } static void grpc_resolve_address_invoke_dns_lookup_ares_locked( @@ -647,8 +668,9 @@ static void grpc_resolve_address_invoke_dns_lookup_ares_locked( static_cast<grpc_resolve_address_ares_request*>(arg); r->ares_request = grpc_dns_lookup_ares_locked( nullptr /* dns_server */, r->name, r->default_port, r->interested_parties, - &r->on_dns_lookup_done_locked, &r->lb_addrs, false /* check_grpclb */, - nullptr /* service_config_json */, r->combiner); + &r->on_dns_lookup_done_locked, &r->addresses, false /* check_grpclb */, + nullptr /* service_config_json */, GRPC_DNS_ARES_DEFAULT_QUERY_TIMEOUT_MS, + r->combiner); } static void grpc_resolve_address_ares_impl(const char* name, @@ -657,8 +679,7 @@ static void grpc_resolve_address_ares_impl(const char* name, grpc_closure* on_done, grpc_resolved_addresses** addrs) { grpc_resolve_address_ares_request* r = - static_cast<grpc_resolve_address_ares_request*>( - gpr_zalloc(sizeof(grpc_resolve_address_ares_request))); + grpc_core::New<grpc_resolve_address_ares_request>(); r->combiner = grpc_combiner_create(); r->addrs_out = addrs; r->on_resolve_address_done = on_done; diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h index a1231cc4e0..2808250456 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h @@ -21,11 +21,13 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/iomgr/iomgr.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/iomgr/resolve_address.h" +#define GRPC_DNS_ARES_DEFAULT_QUERY_TIMEOUT_MS 10000 + extern grpc_core::TraceFlag grpc_trace_cares_address_sorting; extern grpc_core::TraceFlag grpc_trace_cares_resolver; @@ -59,8 +61,9 @@ extern void (*grpc_resolve_address_ares)(const char* name, extern grpc_ares_request* (*grpc_dns_lookup_ares_locked)( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, - grpc_lb_addresses** addresses, bool check_grpclb, - char** service_config_json, grpc_combiner* combiner); + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addresses, + bool check_grpclb, char** service_config_json, int query_timeout_ms, + grpc_combiner* combiner); /* Cancel the pending grpc_ares_request \a request */ extern void (*grpc_cancel_ares_request_locked)(grpc_ares_request* request); @@ -87,10 +90,12 @@ bool grpc_ares_query_ipv6(); * Returns a bool indicating whether or not such an action was performed. * See https://github.com/grpc/grpc/issues/15158. */ bool grpc_ares_maybe_resolve_localhost_manually_locked( - const char* name, const char* default_port, grpc_lb_addresses** addrs); + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs); /* Sorts destinations in lb_addrs according to RFC 6724. */ -void grpc_cares_wrapper_address_sorting_sort(grpc_lb_addresses* lb_addrs); +void grpc_cares_wrapper_address_sorting_sort( + grpc_core::ServerAddressList* addresses); #endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_DNS_C_ARES_GRPC_ARES_WRAPPER_H \ */ diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.cc index 9f293c1ac0..1f4701c999 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.cc @@ -29,7 +29,8 @@ struct grpc_ares_request { static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, - grpc_lb_addresses** addrs, bool check_grpclb, char** service_config_json, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, + bool check_grpclb, char** service_config_json, int query_timeout_ms, grpc_combiner* combiner) { return NULL; } @@ -37,7 +38,8 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl( grpc_ares_request* (*grpc_dns_lookup_ares_locked)( const char* dns_server, const char* name, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, - grpc_lb_addresses** addrs, bool check_grpclb, char** service_config_json, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, + bool check_grpclb, char** service_config_json, int query_timeout_ms, grpc_combiner* combiner) = grpc_dns_lookup_ares_locked_impl; static void grpc_cancel_ares_request_locked_impl(grpc_ares_request* r) {} diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc index 639eec2323..028d844216 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_posix.cc @@ -27,7 +27,8 @@ bool grpc_ares_query_ipv6() { return grpc_ipv6_loopback_available(); } bool grpc_ares_maybe_resolve_localhost_manually_locked( - const char* name, const char* default_port, grpc_lb_addresses** addrs) { + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs) { return false; } diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc index 7e34784691..202452f1b2 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_windows.cc @@ -23,9 +23,9 @@ #include <grpc/support/string_util.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/iomgr/socket_windows.h" @@ -33,8 +33,9 @@ bool grpc_ares_query_ipv6() { return grpc_ipv6_loopback_available(); } static bool inner_maybe_resolve_localhost_manually_locked( - const char* name, const char* default_port, grpc_lb_addresses** addrs, - char** host, char** port) { + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs, char** host, + char** port) { gpr_split_host_port(name, host, port); if (*host == nullptr) { gpr_log(GPR_ERROR, @@ -55,7 +56,7 @@ static bool inner_maybe_resolve_localhost_manually_locked( } if (gpr_stricmp(*host, "localhost") == 0) { GPR_ASSERT(*addrs == nullptr); - *addrs = grpc_lb_addresses_create(2, nullptr); + *addrs = grpc_core::MakeUnique<grpc_core::ServerAddressList>(); uint16_t numeric_port = grpc_strhtons(*port); // Append the ipv6 loopback address. struct sockaddr_in6 ipv6_loopback_addr; @@ -63,10 +64,8 @@ static bool inner_maybe_resolve_localhost_manually_locked( ((char*)&ipv6_loopback_addr.sin6_addr)[15] = 1; ipv6_loopback_addr.sin6_family = AF_INET6; ipv6_loopback_addr.sin6_port = numeric_port; - grpc_lb_addresses_set_address( - *addrs, 0, &ipv6_loopback_addr, sizeof(ipv6_loopback_addr), - false /* is_balancer */, nullptr /* balancer_name */, - nullptr /* user_data */); + (*addrs)->emplace_back(&ipv6_loopback_addr, sizeof(ipv6_loopback_addr), + nullptr /* args */); // Append the ipv4 loopback address. struct sockaddr_in ipv4_loopback_addr; memset(&ipv4_loopback_addr, 0, sizeof(ipv4_loopback_addr)); @@ -74,19 +73,18 @@ static bool inner_maybe_resolve_localhost_manually_locked( ((char*)&ipv4_loopback_addr.sin_addr)[3] = 0x01; ipv4_loopback_addr.sin_family = AF_INET; ipv4_loopback_addr.sin_port = numeric_port; - grpc_lb_addresses_set_address( - *addrs, 1, &ipv4_loopback_addr, sizeof(ipv4_loopback_addr), - false /* is_balancer */, nullptr /* balancer_name */, - nullptr /* user_data */); + (*addrs)->emplace_back(&ipv4_loopback_addr, sizeof(ipv4_loopback_addr), + nullptr /* args */); // Let the address sorter figure out which one should be tried first. - grpc_cares_wrapper_address_sorting_sort(*addrs); + grpc_cares_wrapper_address_sorting_sort(addrs->get()); return true; } return false; } bool grpc_ares_maybe_resolve_localhost_manually_locked( - const char* name, const char* default_port, grpc_lb_addresses** addrs) { + const char* name, const char* default_port, + grpc_core::UniquePtr<grpc_core::ServerAddressList>* addrs) { char* host = nullptr; char* port = nullptr; bool out = inner_maybe_resolve_localhost_manually_locked(name, default_port, diff --git a/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc b/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc index 65ff1ec1a5..c365f1abfd 100644 --- a/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.cc @@ -26,8 +26,8 @@ #include <grpc/support/string_util.h> #include <grpc/support/time.h> -#include "src/core/ext/filters/client_channel/lb_policy_registry.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/backoff/backoff.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/env.h" @@ -198,18 +198,14 @@ void NativeDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) { grpc_error_set_str(error, GRPC_ERROR_STR_TARGET_ADDRESS, grpc_slice_from_copied_string(r->name_to_resolve_)); if (r->addresses_ != nullptr) { - grpc_lb_addresses* addresses = grpc_lb_addresses_create( - r->addresses_->naddrs, nullptr /* user_data_vtable */); + ServerAddressList addresses; for (size_t i = 0; i < r->addresses_->naddrs; ++i) { - grpc_lb_addresses_set_address( - addresses, i, &r->addresses_->addrs[i].addr, - r->addresses_->addrs[i].len, false /* is_balancer */, - nullptr /* balancer_name */, nullptr /* user_data */); + addresses.emplace_back(&r->addresses_->addrs[i].addr, + r->addresses_->addrs[i].len, nullptr /* args */); } - grpc_arg new_arg = grpc_lb_addresses_create_channel_arg(addresses); + grpc_arg new_arg = CreateServerAddressListChannelArg(&addresses); result = grpc_channel_args_copy_and_add(r->channel_args_, &new_arg, 1); grpc_resolved_addresses_destroy(r->addresses_); - grpc_lb_addresses_destroy(addresses); // Reset backoff state so that we start from the beginning when the // next request gets triggered. r->backoff_.Reset(); diff --git a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc index 144ac24a56..258339491c 100644 --- a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.cc @@ -28,12 +28,13 @@ #include <grpc/support/alloc.h> #include <grpc/support/string_util.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gpr/useful.h" #include "src/core/lib/iomgr/closure.h" #include "src/core/lib/iomgr/combiner.h" #include "src/core/lib/iomgr/resolve_address.h" @@ -103,7 +104,7 @@ void FakeResolver::NextLocked(grpc_channel_args** target_result, } void FakeResolver::RequestReresolutionLocked() { - if (reresolution_results_ != nullptr) { + if (reresolution_results_ != nullptr || return_failure_) { grpc_channel_args_destroy(next_results_); next_results_ = grpc_channel_args_copy(reresolution_results_); MaybeFinishNextLocked(); @@ -141,6 +142,7 @@ struct SetResponseClosureArg { grpc_closure set_response_closure; FakeResolverResponseGenerator* generator; grpc_channel_args* response; + bool immediate = true; }; void FakeResolverResponseGenerator::SetResponseLocked(void* arg, @@ -194,7 +196,7 @@ void FakeResolverResponseGenerator::SetFailureLocked(void* arg, SetResponseClosureArg* closure_arg = static_cast<SetResponseClosureArg*>(arg); FakeResolver* resolver = closure_arg->generator->resolver_; resolver->return_failure_ = true; - resolver->MaybeFinishNextLocked(); + if (closure_arg->immediate) resolver->MaybeFinishNextLocked(); Delete(closure_arg); } @@ -209,6 +211,18 @@ void FakeResolverResponseGenerator::SetFailure() { GRPC_ERROR_NONE); } +void FakeResolverResponseGenerator::SetFailureOnReresolution() { + GPR_ASSERT(resolver_ != nullptr); + SetResponseClosureArg* closure_arg = New<SetResponseClosureArg>(); + closure_arg->generator = this; + closure_arg->immediate = false; + GRPC_CLOSURE_SCHED( + GRPC_CLOSURE_INIT(&closure_arg->set_response_closure, SetFailureLocked, + closure_arg, + grpc_combiner_scheduler(resolver_->combiner())), + GRPC_ERROR_NONE); +} + namespace { static void* response_generator_arg_copy(void* p) { diff --git a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h index 74a3062e7f..d86111c382 100644 --- a/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h +++ b/src/core/ext/filters/client_channel/resolver/fake/fake_resolver.h @@ -19,10 +19,9 @@ #include <grpc/support/port_platform.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gprpp/ref_counted.h" -#include "src/core/lib/uri/uri_parser.h" +#include "src/core/lib/iomgr/error.h" #define GRPC_ARG_FAKE_RESOLVER_RESPONSE_GENERATOR \ "grpc.fake_resolver.response_generator" @@ -61,6 +60,10 @@ class FakeResolverResponseGenerator // returning a null result with no error). void SetFailure(); + // Same as SetFailure(), but instead of returning the error + // immediately, waits for the next call to RequestReresolutionLocked(). + void SetFailureOnReresolution(); + // Returns a channel arg containing \a generator. static grpc_arg MakeChannelArg(FakeResolverResponseGenerator* generator); diff --git a/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc b/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc index 801734764b..1654747a79 100644 --- a/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc +++ b/src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.cc @@ -26,9 +26,9 @@ #include <grpc/support/alloc.h> #include <grpc/support/string_util.h> -#include "src/core/ext/filters/client_channel/lb_policy_factory.h" #include "src/core/ext/filters/client_channel/parse_address.h" #include "src/core/ext/filters/client_channel/resolver_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" @@ -45,7 +45,8 @@ namespace { class SockaddrResolver : public Resolver { public: /// Takes ownership of \a addresses. - SockaddrResolver(const ResolverArgs& args, grpc_lb_addresses* addresses); + SockaddrResolver(const ResolverArgs& args, + UniquePtr<ServerAddressList> addresses); void NextLocked(grpc_channel_args** result, grpc_closure* on_complete) override; @@ -58,7 +59,7 @@ class SockaddrResolver : public Resolver { void MaybeFinishNextLocked(); /// the addresses that we've "resolved" - grpc_lb_addresses* addresses_ = nullptr; + UniquePtr<ServerAddressList> addresses_; /// channel args grpc_channel_args* channel_args_ = nullptr; /// have we published? @@ -70,13 +71,12 @@ class SockaddrResolver : public Resolver { }; SockaddrResolver::SockaddrResolver(const ResolverArgs& args, - grpc_lb_addresses* addresses) + UniquePtr<ServerAddressList> addresses) : Resolver(args.combiner), - addresses_(addresses), + addresses_(std::move(addresses)), channel_args_(grpc_channel_args_copy(args.args)) {} SockaddrResolver::~SockaddrResolver() { - grpc_lb_addresses_destroy(addresses_); grpc_channel_args_destroy(channel_args_); } @@ -100,7 +100,7 @@ void SockaddrResolver::ShutdownLocked() { void SockaddrResolver::MaybeFinishNextLocked() { if (next_completion_ != nullptr && !published_) { published_ = true; - grpc_arg arg = grpc_lb_addresses_create_channel_arg(addresses_); + grpc_arg arg = CreateServerAddressListChannelArg(addresses_.get()); *target_result_ = grpc_channel_args_copy_and_add(channel_args_, &arg, 1); GRPC_CLOSURE_SCHED(next_completion_, GRPC_ERROR_NONE); next_completion_ = nullptr; @@ -127,27 +127,27 @@ OrphanablePtr<Resolver> CreateSockaddrResolver( grpc_slice_buffer path_parts; grpc_slice_buffer_init(&path_parts); grpc_slice_split(path_slice, ",", &path_parts); - grpc_lb_addresses* addresses = grpc_lb_addresses_create( - path_parts.count, nullptr /* user_data_vtable */); + auto addresses = MakeUnique<ServerAddressList>(); bool errors_found = false; - for (size_t i = 0; i < addresses->num_addresses; i++) { + for (size_t i = 0; i < path_parts.count; i++) { grpc_uri ith_uri = *args.uri; - char* part_str = grpc_slice_to_c_string(path_parts.slices[i]); - ith_uri.path = part_str; - if (!parse(&ith_uri, &addresses->addresses[i].address)) { + UniquePtr<char> part_str(grpc_slice_to_c_string(path_parts.slices[i])); + ith_uri.path = part_str.get(); + grpc_resolved_address addr; + if (!parse(&ith_uri, &addr)) { errors_found = true; /* GPR_TRUE */ + break; } - gpr_free(part_str); - if (errors_found) break; + addresses->emplace_back(addr, nullptr /* args */); } grpc_slice_buffer_destroy_internal(&path_parts); grpc_slice_unref_internal(path_slice); if (errors_found) { - grpc_lb_addresses_destroy(addresses); return OrphanablePtr<Resolver>(nullptr); } // Instantiate resolver. - return OrphanablePtr<Resolver>(New<SockaddrResolver>(args, addresses)); + return OrphanablePtr<Resolver>( + New<SockaddrResolver>(args, std::move(addresses))); } class IPv4ResolverFactory : public ResolverFactory { diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.cc b/src/core/ext/filters/client_channel/resolver_result_parsing.cc new file mode 100644 index 0000000000..9a0122e8ec --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver_result_parsing.cc @@ -0,0 +1,377 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/resolver_result_parsing.h" + +#include <ctype.h> +#include <stdio.h> +#include <string.h> + +#include <grpc/support/alloc.h> +#include <grpc/support/log.h> +#include <grpc/support/string_util.h> + +#include "src/core/ext/filters/client_channel/client_channel.h" +#include "src/core/ext/filters/client_channel/lb_policy_registry.h" +#include "src/core/ext/filters/client_channel/server_address.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/uri/uri_parser.h" + +// As per the retry design, we do not allow more than 5 retry attempts. +#define MAX_MAX_RETRY_ATTEMPTS 5 + +namespace grpc_core { +namespace internal { + +ProcessedResolverResult::ProcessedResolverResult( + const grpc_channel_args& resolver_result, bool parse_retry) { + ProcessServiceConfig(resolver_result, parse_retry); + // If no LB config was found above, just find the LB policy name then. + if (lb_policy_name_ == nullptr) ProcessLbPolicyName(resolver_result); +} + +void ProcessedResolverResult::ProcessServiceConfig( + const grpc_channel_args& resolver_result, bool parse_retry) { + const grpc_arg* channel_arg = + grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVICE_CONFIG); + const char* service_config_json = grpc_channel_arg_get_string(channel_arg); + if (service_config_json != nullptr) { + service_config_json_.reset(gpr_strdup(service_config_json)); + service_config_ = grpc_core::ServiceConfig::Create(service_config_json); + if (service_config_ != nullptr) { + if (parse_retry) { + channel_arg = + grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVER_URI); + const char* server_uri = grpc_channel_arg_get_string(channel_arg); + GPR_ASSERT(server_uri != nullptr); + grpc_uri* uri = grpc_uri_parse(server_uri, true); + GPR_ASSERT(uri->path[0] != '\0'); + server_name_ = uri->path[0] == '/' ? uri->path + 1 : uri->path; + service_config_->ParseGlobalParams(ParseServiceConfig, this); + grpc_uri_destroy(uri); + } else { + service_config_->ParseGlobalParams(ParseServiceConfig, this); + } + method_params_table_ = service_config_->CreateMethodConfigTable( + ClientChannelMethodParams::CreateFromJson); + } + } +} + +void ProcessedResolverResult::ProcessLbPolicyName( + const grpc_channel_args& resolver_result) { + // Prefer the LB policy name found in the service config. Note that this is + // checking the deprecated loadBalancingPolicy field, rather than the new + // loadBalancingConfig field. + if (service_config_ != nullptr) { + lb_policy_name_.reset( + gpr_strdup(service_config_->GetLoadBalancingPolicyName())); + // Convert to lower-case. + if (lb_policy_name_ != nullptr) { + char* lb_policy_name = lb_policy_name_.get(); + for (size_t i = 0; i < strlen(lb_policy_name); ++i) { + lb_policy_name[i] = tolower(lb_policy_name[i]); + } + } + } + // Otherwise, find the LB policy name set by the client API. + if (lb_policy_name_ == nullptr) { + const grpc_arg* channel_arg = + grpc_channel_args_find(&resolver_result, GRPC_ARG_LB_POLICY_NAME); + lb_policy_name_.reset(gpr_strdup(grpc_channel_arg_get_string(channel_arg))); + } + // Special case: If at least one balancer address is present, we use + // the grpclb policy, regardless of what the resolver has returned. + const ServerAddressList* addresses = + FindServerAddressListChannelArg(&resolver_result); + if (addresses != nullptr) { + bool found_balancer_address = false; + for (size_t i = 0; i < addresses->size(); ++i) { + const ServerAddress& address = (*addresses)[i]; + if (address.IsBalancer()) { + found_balancer_address = true; + break; + } + } + if (found_balancer_address) { + if (lb_policy_name_ != nullptr && + strcmp(lb_policy_name_.get(), "grpclb") != 0) { + gpr_log(GPR_INFO, + "resolver requested LB policy %s but provided at least one " + "balancer address -- forcing use of grpclb LB policy", + lb_policy_name_.get()); + } + lb_policy_name_.reset(gpr_strdup("grpclb")); + } + } + // Use pick_first if nothing was specified and we didn't select grpclb + // above. + if (lb_policy_name_ == nullptr) { + lb_policy_name_.reset(gpr_strdup("pick_first")); + } +} + +void ProcessedResolverResult::ParseServiceConfig( + const grpc_json* field, ProcessedResolverResult* parsing_state) { + parsing_state->ParseLbConfigFromServiceConfig(field); + if (parsing_state->server_name_ != nullptr) { + parsing_state->ParseRetryThrottleParamsFromServiceConfig(field); + } +} + +void ProcessedResolverResult::ParseLbConfigFromServiceConfig( + const grpc_json* field) { + if (lb_policy_config_ != nullptr) return; // Already found. + // Find the LB config global parameter. + if (field->key == nullptr || strcmp(field->key, "loadBalancingConfig") != 0 || + field->type != GRPC_JSON_ARRAY) { + return; // Not valid lb config array. + } + // Find the first LB policy that this client supports. + for (grpc_json* lb_config = field->child; lb_config != nullptr; + lb_config = lb_config->next) { + if (lb_config->type != GRPC_JSON_OBJECT) return; + // Find the policy object. + grpc_json* policy = nullptr; + for (grpc_json* field = lb_config->child; field != nullptr; + field = field->next) { + if (field->key == nullptr || strcmp(field->key, "policy") != 0 || + field->type != GRPC_JSON_OBJECT) { + return; + } + if (policy != nullptr) return; // Duplicate. + policy = field; + } + // Find the specific policy content since the policy object is of type + // "oneof". + grpc_json* policy_content = nullptr; + for (grpc_json* field = policy->child; field != nullptr; + field = field->next) { + if (field->key == nullptr || field->type != GRPC_JSON_OBJECT) return; + if (policy_content != nullptr) return; // Violate "oneof" type. + policy_content = field; + } + // If we support this policy, then select it. + if (grpc_core::LoadBalancingPolicyRegistry::LoadBalancingPolicyExists( + policy_content->key)) { + lb_policy_name_.reset(gpr_strdup(policy_content->key)); + lb_policy_config_ = policy_content->child; + return; + } + } +} + +void ProcessedResolverResult::ParseRetryThrottleParamsFromServiceConfig( + const grpc_json* field) { + if (strcmp(field->key, "retryThrottling") == 0) { + if (retry_throttle_data_ != nullptr) return; // Duplicate. + if (field->type != GRPC_JSON_OBJECT) return; + int max_milli_tokens = 0; + int milli_token_ratio = 0; + for (grpc_json* sub_field = field->child; sub_field != nullptr; + sub_field = sub_field->next) { + if (sub_field->key == nullptr) return; + if (strcmp(sub_field->key, "maxTokens") == 0) { + if (max_milli_tokens != 0) return; // Duplicate. + if (sub_field->type != GRPC_JSON_NUMBER) return; + max_milli_tokens = gpr_parse_nonnegative_int(sub_field->value); + if (max_milli_tokens == -1) return; + max_milli_tokens *= 1000; + } else if (strcmp(sub_field->key, "tokenRatio") == 0) { + if (milli_token_ratio != 0) return; // Duplicate. + if (sub_field->type != GRPC_JSON_NUMBER) return; + // We support up to 3 decimal digits. + size_t whole_len = strlen(sub_field->value); + uint32_t multiplier = 1; + uint32_t decimal_value = 0; + const char* decimal_point = strchr(sub_field->value, '.'); + if (decimal_point != nullptr) { + whole_len = static_cast<size_t>(decimal_point - sub_field->value); + multiplier = 1000; + size_t decimal_len = strlen(decimal_point + 1); + if (decimal_len > 3) decimal_len = 3; + if (!gpr_parse_bytes_to_uint32(decimal_point + 1, decimal_len, + &decimal_value)) { + return; + } + uint32_t decimal_multiplier = 1; + for (size_t i = 0; i < (3 - decimal_len); ++i) { + decimal_multiplier *= 10; + } + decimal_value *= decimal_multiplier; + } + uint32_t whole_value; + if (!gpr_parse_bytes_to_uint32(sub_field->value, whole_len, + &whole_value)) { + return; + } + milli_token_ratio = + static_cast<int>((whole_value * multiplier) + decimal_value); + if (milli_token_ratio <= 0) return; + } + } + retry_throttle_data_ = + grpc_core::internal::ServerRetryThrottleMap::GetDataForServer( + server_name_, max_milli_tokens, milli_token_ratio); + } +} + +namespace { + +bool ParseWaitForReady( + grpc_json* field, ClientChannelMethodParams::WaitForReady* wait_for_ready) { + if (field->type != GRPC_JSON_TRUE && field->type != GRPC_JSON_FALSE) { + return false; + } + *wait_for_ready = field->type == GRPC_JSON_TRUE + ? ClientChannelMethodParams::WAIT_FOR_READY_TRUE + : ClientChannelMethodParams::WAIT_FOR_READY_FALSE; + return true; +} + +// Parses a JSON field of the form generated for a google.proto.Duration +// proto message, as per: +// https://developers.google.com/protocol-buffers/docs/proto3#json +bool ParseDuration(grpc_json* field, grpc_millis* duration) { + if (field->type != GRPC_JSON_STRING) return false; + size_t len = strlen(field->value); + if (field->value[len - 1] != 's') return false; + UniquePtr<char> buf(gpr_strdup(field->value)); + *(buf.get() + len - 1) = '\0'; // Remove trailing 's'. + char* decimal_point = strchr(buf.get(), '.'); + int nanos = 0; + if (decimal_point != nullptr) { + *decimal_point = '\0'; + nanos = gpr_parse_nonnegative_int(decimal_point + 1); + if (nanos == -1) { + return false; + } + int num_digits = static_cast<int>(strlen(decimal_point + 1)); + if (num_digits > 9) { // We don't accept greater precision than nanos. + return false; + } + for (int i = 0; i < (9 - num_digits); ++i) { + nanos *= 10; + } + } + int seconds = + decimal_point == buf.get() ? 0 : gpr_parse_nonnegative_int(buf.get()); + if (seconds == -1) return false; + *duration = seconds * GPR_MS_PER_SEC + nanos / GPR_NS_PER_MS; + return true; +} + +UniquePtr<ClientChannelMethodParams::RetryPolicy> ParseRetryPolicy( + grpc_json* field) { + auto retry_policy = MakeUnique<ClientChannelMethodParams::RetryPolicy>(); + if (field->type != GRPC_JSON_OBJECT) return nullptr; + for (grpc_json* sub_field = field->child; sub_field != nullptr; + sub_field = sub_field->next) { + if (sub_field->key == nullptr) return nullptr; + if (strcmp(sub_field->key, "maxAttempts") == 0) { + if (retry_policy->max_attempts != 0) return nullptr; // Duplicate. + if (sub_field->type != GRPC_JSON_NUMBER) return nullptr; + retry_policy->max_attempts = gpr_parse_nonnegative_int(sub_field->value); + if (retry_policy->max_attempts <= 1) return nullptr; + if (retry_policy->max_attempts > MAX_MAX_RETRY_ATTEMPTS) { + gpr_log(GPR_ERROR, + "service config: clamped retryPolicy.maxAttempts at %d", + MAX_MAX_RETRY_ATTEMPTS); + retry_policy->max_attempts = MAX_MAX_RETRY_ATTEMPTS; + } + } else if (strcmp(sub_field->key, "initialBackoff") == 0) { + if (retry_policy->initial_backoff > 0) return nullptr; // Duplicate. + if (!ParseDuration(sub_field, &retry_policy->initial_backoff)) { + return nullptr; + } + if (retry_policy->initial_backoff == 0) return nullptr; + } else if (strcmp(sub_field->key, "maxBackoff") == 0) { + if (retry_policy->max_backoff > 0) return nullptr; // Duplicate. + if (!ParseDuration(sub_field, &retry_policy->max_backoff)) { + return nullptr; + } + if (retry_policy->max_backoff == 0) return nullptr; + } else if (strcmp(sub_field->key, "backoffMultiplier") == 0) { + if (retry_policy->backoff_multiplier != 0) return nullptr; // Duplicate. + if (sub_field->type != GRPC_JSON_NUMBER) return nullptr; + if (sscanf(sub_field->value, "%f", &retry_policy->backoff_multiplier) != + 1) { + return nullptr; + } + if (retry_policy->backoff_multiplier <= 0) return nullptr; + } else if (strcmp(sub_field->key, "retryableStatusCodes") == 0) { + if (!retry_policy->retryable_status_codes.Empty()) { + return nullptr; // Duplicate. + } + if (sub_field->type != GRPC_JSON_ARRAY) return nullptr; + for (grpc_json* element = sub_field->child; element != nullptr; + element = element->next) { + if (element->type != GRPC_JSON_STRING) return nullptr; + grpc_status_code status; + if (!grpc_status_code_from_string(element->value, &status)) { + return nullptr; + } + retry_policy->retryable_status_codes.Add(status); + } + if (retry_policy->retryable_status_codes.Empty()) return nullptr; + } + } + // Make sure required fields are set. + if (retry_policy->max_attempts == 0 || retry_policy->initial_backoff == 0 || + retry_policy->max_backoff == 0 || retry_policy->backoff_multiplier == 0 || + retry_policy->retryable_status_codes.Empty()) { + return nullptr; + } + return retry_policy; +} + +} // namespace + +RefCountedPtr<ClientChannelMethodParams> +ClientChannelMethodParams::CreateFromJson(const grpc_json* json) { + RefCountedPtr<ClientChannelMethodParams> method_params = + MakeRefCounted<ClientChannelMethodParams>(); + for (grpc_json* field = json->child; field != nullptr; field = field->next) { + if (field->key == nullptr) continue; + if (strcmp(field->key, "waitForReady") == 0) { + if (method_params->wait_for_ready_ != WAIT_FOR_READY_UNSET) { + return nullptr; // Duplicate. + } + if (!ParseWaitForReady(field, &method_params->wait_for_ready_)) { + return nullptr; + } + } else if (strcmp(field->key, "timeout") == 0) { + if (method_params->timeout_ > 0) return nullptr; // Duplicate. + if (!ParseDuration(field, &method_params->timeout_)) return nullptr; + } else if (strcmp(field->key, "retryPolicy") == 0) { + if (method_params->retry_policy_ != nullptr) { + return nullptr; // Duplicate. + } + method_params->retry_policy_ = ParseRetryPolicy(field); + if (method_params->retry_policy_ == nullptr) return nullptr; + } + } + return method_params; +} + +} // namespace internal +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.h b/src/core/ext/filters/client_channel/resolver_result_parsing.h new file mode 100644 index 0000000000..98a9d26c46 --- /dev/null +++ b/src/core/ext/filters/client_channel/resolver_result_parsing.h @@ -0,0 +1,142 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_RESULT_PARSING_H +#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_RESULT_PARSING_H + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/retry_throttle.h" +#include "src/core/lib/channel/status_util.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/iomgr/exec_ctx.h" // for grpc_millis +#include "src/core/lib/json/json.h" +#include "src/core/lib/slice/slice_hash_table.h" +#include "src/core/lib/transport/service_config.h" + +namespace grpc_core { +namespace internal { + +class ClientChannelMethodParams; + +// A table mapping from a method name to its method parameters. +typedef SliceHashTable<RefCountedPtr<ClientChannelMethodParams>> + ClientChannelMethodParamsTable; + +// A container of processed fields from the resolver result. Simplifies the +// usage of resolver result. +class ProcessedResolverResult { + public: + // Processes the resolver result and populates the relative members + // for later consumption. Tries to parse retry parameters only if parse_retry + // is true. + ProcessedResolverResult(const grpc_channel_args& resolver_result, + bool parse_retry); + + // Getters. Any managed object's ownership is transferred. + UniquePtr<char> service_config_json() { + return std::move(service_config_json_); + } + RefCountedPtr<ServerRetryThrottleData> retry_throttle_data() { + return std::move(retry_throttle_data_); + } + RefCountedPtr<ClientChannelMethodParamsTable> method_params_table() { + return std::move(method_params_table_); + } + UniquePtr<char> lb_policy_name() { return std::move(lb_policy_name_); } + grpc_json* lb_policy_config() { return lb_policy_config_; } + + private: + // Finds the service config; extracts LB config and (maybe) retry throttle + // params from it. + void ProcessServiceConfig(const grpc_channel_args& resolver_result, + bool parse_retry); + + // Finds the LB policy name (when no LB config was found). + void ProcessLbPolicyName(const grpc_channel_args& resolver_result); + + // Parses the service config. Intended to be used by + // ServiceConfig::ParseGlobalParams. + static void ParseServiceConfig(const grpc_json* field, + ProcessedResolverResult* parsing_state); + // Parses the LB config from service config. + void ParseLbConfigFromServiceConfig(const grpc_json* field); + // Parses the retry throttle parameters from service config. + void ParseRetryThrottleParamsFromServiceConfig(const grpc_json* field); + + // Service config. + UniquePtr<char> service_config_json_; + UniquePtr<grpc_core::ServiceConfig> service_config_; + // LB policy. + grpc_json* lb_policy_config_ = nullptr; + UniquePtr<char> lb_policy_name_; + // Retry throttle data. + char* server_name_ = nullptr; + RefCountedPtr<ServerRetryThrottleData> retry_throttle_data_; + // Method params table. + RefCountedPtr<ClientChannelMethodParamsTable> method_params_table_; +}; + +// The parameters of a method. +class ClientChannelMethodParams : public RefCounted<ClientChannelMethodParams> { + public: + enum WaitForReady { + WAIT_FOR_READY_UNSET = 0, + WAIT_FOR_READY_FALSE, + WAIT_FOR_READY_TRUE + }; + + struct RetryPolicy { + int max_attempts = 0; + grpc_millis initial_backoff = 0; + grpc_millis max_backoff = 0; + float backoff_multiplier = 0; + StatusCodeSet retryable_status_codes; + }; + + /// Creates a method_parameters object from \a json. + /// Intended for use with ServiceConfig::CreateMethodConfigTable(). + static RefCountedPtr<ClientChannelMethodParams> CreateFromJson( + const grpc_json* json); + + grpc_millis timeout() const { return timeout_; } + WaitForReady wait_for_ready() const { return wait_for_ready_; } + const RetryPolicy* retry_policy() const { return retry_policy_.get(); } + + private: + // So New() can call our private ctor. + template <typename T, typename... Args> + friend T* grpc_core::New(Args&&... args); + + // So Delete() can call our private dtor. + template <typename T> + friend void grpc_core::Delete(T*); + + ClientChannelMethodParams() {} + virtual ~ClientChannelMethodParams() {} + + grpc_millis timeout_ = 0; + WaitForReady wait_for_ready_ = WAIT_FOR_READY_UNSET; + UniquePtr<RetryPolicy> retry_policy_; +}; + +} // namespace internal +} // namespace grpc_core + +#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_RESOLVER_RESULT_PARSING_H */ diff --git a/src/core/ext/filters/client_channel/server_address.cc b/src/core/ext/filters/client_channel/server_address.cc new file mode 100644 index 0000000000..ec33cbbd95 --- /dev/null +++ b/src/core/ext/filters/client_channel/server_address.cc @@ -0,0 +1,103 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/filters/client_channel/server_address.h" + +#include <string.h> + +namespace grpc_core { + +// +// ServerAddress +// + +ServerAddress::ServerAddress(const grpc_resolved_address& address, + grpc_channel_args* args) + : address_(address), args_(args) {} + +ServerAddress::ServerAddress(const void* address, size_t address_len, + grpc_channel_args* args) + : args_(args) { + memcpy(address_.addr, address, address_len); + address_.len = static_cast<socklen_t>(address_len); +} + +int ServerAddress::Cmp(const ServerAddress& other) const { + if (address_.len > other.address_.len) return 1; + if (address_.len < other.address_.len) return -1; + int retval = memcmp(address_.addr, other.address_.addr, address_.len); + if (retval != 0) return retval; + return grpc_channel_args_compare(args_, other.args_); +} + +bool ServerAddress::IsBalancer() const { + return grpc_channel_arg_get_bool( + grpc_channel_args_find(args_, GRPC_ARG_ADDRESS_IS_BALANCER), false); +} + +// +// ServerAddressList +// + +namespace { + +void* ServerAddressListCopy(void* addresses) { + ServerAddressList* a = static_cast<ServerAddressList*>(addresses); + return New<ServerAddressList>(*a); +} + +void ServerAddressListDestroy(void* addresses) { + ServerAddressList* a = static_cast<ServerAddressList*>(addresses); + Delete(a); +} + +int ServerAddressListCompare(void* addresses1, void* addresses2) { + ServerAddressList* a1 = static_cast<ServerAddressList*>(addresses1); + ServerAddressList* a2 = static_cast<ServerAddressList*>(addresses2); + if (a1->size() > a2->size()) return 1; + if (a1->size() < a2->size()) return -1; + for (size_t i = 0; i < a1->size(); ++i) { + int retval = (*a1)[i].Cmp((*a2)[i]); + if (retval != 0) return retval; + } + return 0; +} + +const grpc_arg_pointer_vtable server_addresses_arg_vtable = { + ServerAddressListCopy, ServerAddressListDestroy, ServerAddressListCompare}; + +} // namespace + +grpc_arg CreateServerAddressListChannelArg(const ServerAddressList* addresses) { + return grpc_channel_arg_pointer_create( + const_cast<char*>(GRPC_ARG_SERVER_ADDRESS_LIST), + const_cast<ServerAddressList*>(addresses), &server_addresses_arg_vtable); +} + +ServerAddressList* FindServerAddressListChannelArg( + const grpc_channel_args* channel_args) { + const grpc_arg* lb_addresses_arg = + grpc_channel_args_find(channel_args, GRPC_ARG_SERVER_ADDRESS_LIST); + if (lb_addresses_arg == nullptr || lb_addresses_arg->type != GRPC_ARG_POINTER) + return nullptr; + return static_cast<ServerAddressList*>(lb_addresses_arg->value.pointer.p); +} + +} // namespace grpc_core diff --git a/src/core/ext/filters/client_channel/server_address.h b/src/core/ext/filters/client_channel/server_address.h new file mode 100644 index 0000000000..3a1bf1df67 --- /dev/null +++ b/src/core/ext/filters/client_channel/server_address.h @@ -0,0 +1,108 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_SERVER_ADDRESS_H +#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_SERVER_ADDRESS_H + +#include <grpc/support/port_platform.h> + +#include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/uri/uri_parser.h" + +// Channel arg key for ServerAddressList. +#define GRPC_ARG_SERVER_ADDRESS_LIST "grpc.server_address_list" + +// Channel arg key for a bool indicating whether an address is a grpclb +// load balancer (as opposed to a backend). +#define GRPC_ARG_ADDRESS_IS_BALANCER "grpc.address_is_balancer" + +// Channel arg key for a string indicating an address's balancer name. +#define GRPC_ARG_ADDRESS_BALANCER_NAME "grpc.address_balancer_name" + +namespace grpc_core { + +// +// ServerAddress +// + +// A server address is a grpc_resolved_address with an associated set of +// channel args. Any args present here will be merged into the channel +// args when a subchannel is created for this address. +class ServerAddress { + public: + // Takes ownership of args. + ServerAddress(const grpc_resolved_address& address, grpc_channel_args* args); + ServerAddress(const void* address, size_t address_len, + grpc_channel_args* args); + + ~ServerAddress() { grpc_channel_args_destroy(args_); } + + // Copyable. + ServerAddress(const ServerAddress& other) + : address_(other.address_), args_(grpc_channel_args_copy(other.args_)) {} + ServerAddress& operator=(const ServerAddress& other) { + address_ = other.address_; + grpc_channel_args_destroy(args_); + args_ = grpc_channel_args_copy(other.args_); + return *this; + } + + // Movable. + ServerAddress(ServerAddress&& other) + : address_(other.address_), args_(other.args_) { + other.args_ = nullptr; + } + ServerAddress& operator=(ServerAddress&& other) { + address_ = other.address_; + args_ = other.args_; + other.args_ = nullptr; + return *this; + } + + bool operator==(const ServerAddress& other) const { return Cmp(other) == 0; } + + int Cmp(const ServerAddress& other) const; + + const grpc_resolved_address& address() const { return address_; } + const grpc_channel_args* args() const { return args_; } + + bool IsBalancer() const; + + private: + grpc_resolved_address address_; + grpc_channel_args* args_; +}; + +// +// ServerAddressList +// + +typedef InlinedVector<ServerAddress, 1> ServerAddressList; + +// Returns a channel arg containing \a addresses. +grpc_arg CreateServerAddressListChannelArg(const ServerAddressList* addresses); + +// Returns the ServerListAddress instance in channel_args or NULL. +ServerAddressList* FindServerAddressListChannelArg( + const grpc_channel_args* channel_args); + +} // namespace grpc_core + +#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_SERVER_ADDRESS_H */ diff --git a/src/core/ext/filters/client_channel/subchannel.cc b/src/core/ext/filters/client_channel/subchannel.cc index a56db0201b..9077aa9753 100644 --- a/src/core/ext/filters/client_channel/subchannel.cc +++ b/src/core/ext/filters/client_channel/subchannel.cc @@ -153,7 +153,7 @@ struct grpc_subchannel { /** have we started the backoff loop */ bool backoff_begun; // reset_backoff() was called while alarm was pending - bool deferred_reset_backoff; + bool retry_immediately; /** our alarm */ grpc_timer alarm; @@ -709,8 +709,8 @@ static void on_alarm(void* arg, grpc_error* error) { if (c->disconnected) { error = GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING("Disconnected", &error, 1); - } else if (c->deferred_reset_backoff) { - c->deferred_reset_backoff = false; + } else if (c->retry_immediately) { + c->retry_immediately = false; error = GRPC_ERROR_NONE; } else { GRPC_ERROR_REF(error); @@ -837,7 +837,7 @@ static bool publish_transport_locked(grpc_subchannel* c) { /* publish */ c->connected_subchannel.reset(grpc_core::New<grpc_core::ConnectedSubchannel>( - stk, c->channelz_subchannel, socket_uuid)); + stk, c->args, c->channelz_subchannel, socket_uuid)); gpr_log(GPR_INFO, "New connected subchannel at %p for subchannel %p", c->connected_subchannel.get(), c); @@ -887,12 +887,12 @@ static void on_subchannel_connected(void* arg, grpc_error* error) { void grpc_subchannel_reset_backoff(grpc_subchannel* subchannel) { gpr_mu_lock(&subchannel->mu); + subchannel->backoff->Reset(); if (subchannel->have_alarm) { - subchannel->deferred_reset_backoff = true; + subchannel->retry_immediately = true; grpc_timer_cancel(&subchannel->alarm); } else { subchannel->backoff_begun = false; - subchannel->backoff->Reset(); maybe_start_connecting_locked(subchannel); } gpr_mu_unlock(&subchannel->mu); @@ -1068,16 +1068,18 @@ grpc_arg grpc_create_subchannel_address_arg(const grpc_resolved_address* addr) { namespace grpc_core { ConnectedSubchannel::ConnectedSubchannel( - grpc_channel_stack* channel_stack, + grpc_channel_stack* channel_stack, const grpc_channel_args* args, grpc_core::RefCountedPtr<grpc_core::channelz::SubchannelNode> channelz_subchannel, intptr_t socket_uuid) - : RefCountedWithTracing<ConnectedSubchannel>(&grpc_trace_stream_refcount), + : RefCounted<ConnectedSubchannel>(&grpc_trace_stream_refcount), channel_stack_(channel_stack), + args_(grpc_channel_args_copy(args)), channelz_subchannel_(std::move(channelz_subchannel)), socket_uuid_(socket_uuid) {} ConnectedSubchannel::~ConnectedSubchannel() { + grpc_channel_args_destroy(args_); GRPC_CHANNEL_STACK_UNREF(channel_stack_, "connected_subchannel_dtor"); } diff --git a/src/core/ext/filters/client_channel/subchannel.h b/src/core/ext/filters/client_channel/subchannel.h index ec3b4d86e4..14f87f2c68 100644 --- a/src/core/ext/filters/client_channel/subchannel.h +++ b/src/core/ext/filters/client_channel/subchannel.h @@ -72,7 +72,7 @@ typedef struct grpc_subchannel_key grpc_subchannel_key; namespace grpc_core { -class ConnectedSubchannel : public RefCountedWithTracing<ConnectedSubchannel> { +class ConnectedSubchannel : public RefCounted<ConnectedSubchannel> { public: struct CallArgs { grpc_polling_entity* pollent; @@ -85,28 +85,31 @@ class ConnectedSubchannel : public RefCountedWithTracing<ConnectedSubchannel> { size_t parent_data_size; }; - explicit ConnectedSubchannel( - grpc_channel_stack* channel_stack, + ConnectedSubchannel( + grpc_channel_stack* channel_stack, const grpc_channel_args* args, grpc_core::RefCountedPtr<grpc_core::channelz::SubchannelNode> channelz_subchannel, intptr_t socket_uuid); ~ConnectedSubchannel(); - grpc_channel_stack* channel_stack() { return channel_stack_; } void NotifyOnStateChange(grpc_pollset_set* interested_parties, grpc_connectivity_state* state, grpc_closure* closure); void Ping(grpc_closure* on_initiate, grpc_closure* on_ack); grpc_error* CreateCall(const CallArgs& args, grpc_subchannel_call** call); - channelz::SubchannelNode* channelz_subchannel() { + + grpc_channel_stack* channel_stack() const { return channel_stack_; } + const grpc_channel_args* args() const { return args_; } + channelz::SubchannelNode* channelz_subchannel() const { return channelz_subchannel_.get(); } - intptr_t socket_uuid() { return socket_uuid_; } + intptr_t socket_uuid() const { return socket_uuid_; } size_t GetInitialCallSizeEstimate(size_t parent_data_size) const; private: grpc_channel_stack* channel_stack_; + grpc_channel_args* args_; // ref counted pointer to the channelz node in this connected subchannel's // owning subchannel. grpc_core::RefCountedPtr<grpc_core::channelz::SubchannelNode> diff --git a/src/core/ext/filters/client_channel/subchannel_index.cc b/src/core/ext/filters/client_channel/subchannel_index.cc index 1c23a6c4be..aa8441f17b 100644 --- a/src/core/ext/filters/client_channel/subchannel_index.cc +++ b/src/core/ext/filters/client_channel/subchannel_index.cc @@ -91,7 +91,7 @@ void grpc_subchannel_key_destroy(grpc_subchannel_key* k) { gpr_free(k); } -static void sck_avl_destroy(void* p, void* user_data) { +static void sck_avl_destroy(void* p, void* unused) { grpc_subchannel_key_destroy(static_cast<grpc_subchannel_key*>(p)); } @@ -104,7 +104,7 @@ static long sck_avl_compare(void* a, void* b, void* unused) { static_cast<grpc_subchannel_key*>(b)); } -static void scv_avl_destroy(void* p, void* user_data) { +static void scv_avl_destroy(void* p, void* unused) { GRPC_SUBCHANNEL_WEAK_UNREF((grpc_subchannel*)p, "subchannel_index"); } @@ -137,7 +137,7 @@ void grpc_subchannel_index_shutdown(void) { void grpc_subchannel_index_unref(void) { if (gpr_unref(&g_refcount)) { gpr_mu_destroy(&g_mu); - grpc_avl_unref(g_subchannel_index, grpc_core::ExecCtx::Get()); + grpc_avl_unref(g_subchannel_index, nullptr); } } @@ -147,13 +147,12 @@ grpc_subchannel* grpc_subchannel_index_find(grpc_subchannel_key* key) { // Lock, and take a reference to the subchannel index. // We don't need to do the search under a lock as avl's are immutable. gpr_mu_lock(&g_mu); - grpc_avl index = grpc_avl_ref(g_subchannel_index, grpc_core::ExecCtx::Get()); + grpc_avl index = grpc_avl_ref(g_subchannel_index, nullptr); gpr_mu_unlock(&g_mu); grpc_subchannel* c = GRPC_SUBCHANNEL_REF_FROM_WEAK_REF( - (grpc_subchannel*)grpc_avl_get(index, key, grpc_core::ExecCtx::Get()), - "index_find"); - grpc_avl_unref(index, grpc_core::ExecCtx::Get()); + (grpc_subchannel*)grpc_avl_get(index, key, nullptr), "index_find"); + grpc_avl_unref(index, nullptr); return c; } @@ -169,13 +168,11 @@ grpc_subchannel* grpc_subchannel_index_register(grpc_subchannel_key* key, // Compare and swap loop: // - take a reference to the current index gpr_mu_lock(&g_mu); - grpc_avl index = - grpc_avl_ref(g_subchannel_index, grpc_core::ExecCtx::Get()); + grpc_avl index = grpc_avl_ref(g_subchannel_index, nullptr); gpr_mu_unlock(&g_mu); // - Check to see if a subchannel already exists - c = static_cast<grpc_subchannel*>( - grpc_avl_get(index, key, grpc_core::ExecCtx::Get())); + c = static_cast<grpc_subchannel*>(grpc_avl_get(index, key, nullptr)); if (c != nullptr) { c = GRPC_SUBCHANNEL_REF_FROM_WEAK_REF(c, "index_register"); } @@ -184,11 +181,9 @@ grpc_subchannel* grpc_subchannel_index_register(grpc_subchannel_key* key, need_to_unref_constructed = true; } else { // no -> update the avl and compare/swap - grpc_avl updated = - grpc_avl_add(grpc_avl_ref(index, grpc_core::ExecCtx::Get()), - subchannel_key_copy(key), - GRPC_SUBCHANNEL_WEAK_REF(constructed, "index_register"), - grpc_core::ExecCtx::Get()); + grpc_avl updated = grpc_avl_add( + grpc_avl_ref(index, nullptr), subchannel_key_copy(key), + GRPC_SUBCHANNEL_WEAK_REF(constructed, "index_register"), nullptr); // it may happen (but it's expected to be unlikely) // that some other thread has changed the index: @@ -200,9 +195,9 @@ grpc_subchannel* grpc_subchannel_index_register(grpc_subchannel_key* key, } gpr_mu_unlock(&g_mu); - grpc_avl_unref(updated, grpc_core::ExecCtx::Get()); + grpc_avl_unref(updated, nullptr); } - grpc_avl_unref(index, grpc_core::ExecCtx::Get()); + grpc_avl_unref(index, nullptr); } if (need_to_unref_constructed) { @@ -219,24 +214,22 @@ void grpc_subchannel_index_unregister(grpc_subchannel_key* key, // Compare and swap loop: // - take a reference to the current index gpr_mu_lock(&g_mu); - grpc_avl index = - grpc_avl_ref(g_subchannel_index, grpc_core::ExecCtx::Get()); + grpc_avl index = grpc_avl_ref(g_subchannel_index, nullptr); gpr_mu_unlock(&g_mu); // Check to see if this key still refers to the previously // registered subchannel - grpc_subchannel* c = static_cast<grpc_subchannel*>( - grpc_avl_get(index, key, grpc_core::ExecCtx::Get())); + grpc_subchannel* c = + static_cast<grpc_subchannel*>(grpc_avl_get(index, key, nullptr)); if (c != constructed) { - grpc_avl_unref(index, grpc_core::ExecCtx::Get()); + grpc_avl_unref(index, nullptr); break; } // compare and swap the update (some other thread may have // mutated the index behind us) grpc_avl updated = - grpc_avl_remove(grpc_avl_ref(index, grpc_core::ExecCtx::Get()), key, - grpc_core::ExecCtx::Get()); + grpc_avl_remove(grpc_avl_ref(index, nullptr), key, nullptr); gpr_mu_lock(&g_mu); if (index.root == g_subchannel_index.root) { @@ -245,8 +238,8 @@ void grpc_subchannel_index_unregister(grpc_subchannel_key* key, } gpr_mu_unlock(&g_mu); - grpc_avl_unref(updated, grpc_core::ExecCtx::Get()); - grpc_avl_unref(index, grpc_core::ExecCtx::Get()); + grpc_avl_unref(updated, nullptr); + grpc_avl_unref(index, nullptr); } } diff --git a/src/core/ext/transport/chttp2/client/chttp2_connector.cc b/src/core/ext/transport/chttp2/client/chttp2_connector.cc index 60a32022f5..42a2e2e896 100644 --- a/src/core/ext/transport/chttp2/client/chttp2_connector.cc +++ b/src/core/ext/transport/chttp2/client/chttp2_connector.cc @@ -117,8 +117,9 @@ static void on_handshake_done(void* arg, grpc_error* error) { c->args.interested_parties); c->result->transport = grpc_create_chttp2_transport(args->args, args->endpoint, true); - c->result->socket_uuid = - grpc_chttp2_transport_get_socket_uuid(c->result->transport); + grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> socket_node = + grpc_chttp2_transport_get_socket_node(c->result->transport); + c->result->socket_uuid = socket_node == nullptr ? 0 : socket_node->uuid(); GPR_ASSERT(c->result->transport); // TODO(roth): We ideally want to wait until we receive HTTP/2 // settings from the server before we consider the connection diff --git a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc index e73eee4353..9612698e96 100644 --- a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc +++ b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc @@ -110,14 +110,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( grpc_channel_args* args_with_authority = grpc_channel_args_copy_and_add(args->args, args_to_add, num_args_to_add); grpc_uri_destroy(server_uri); - grpc_channel_security_connector* subchannel_security_connector = nullptr; // Create the security connector using the credentials and target name. grpc_channel_args* new_args_from_connector = nullptr; - const grpc_security_status security_status = - grpc_channel_credentials_create_security_connector( - channel_credentials, authority.get(), args_with_authority, - &subchannel_security_connector, &new_args_from_connector); - if (security_status != GRPC_SECURITY_OK) { + grpc_core::RefCountedPtr<grpc_channel_security_connector> + subchannel_security_connector = + channel_credentials->create_security_connector( + /*call_creds=*/nullptr, authority.get(), args_with_authority, + &new_args_from_connector); + if (subchannel_security_connector == nullptr) { gpr_log(GPR_ERROR, "Failed to create secure subchannel for secure name '%s'", authority.get()); @@ -125,15 +125,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( return nullptr; } grpc_arg new_security_connector_arg = - grpc_security_connector_to_arg(&subchannel_security_connector->base); + grpc_security_connector_to_arg(subchannel_security_connector.get()); grpc_channel_args* new_args = grpc_channel_args_copy_and_add( new_args_from_connector != nullptr ? new_args_from_connector : args_with_authority, &new_security_connector_arg, 1); - GRPC_SECURITY_CONNECTOR_UNREF(&subchannel_security_connector->base, - "lb_channel_create"); + subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create"); if (new_args_from_connector != nullptr) { grpc_channel_args_destroy(new_args_from_connector); } diff --git a/src/core/ext/transport/chttp2/server/chttp2_server.cc b/src/core/ext/transport/chttp2/server/chttp2_server.cc index 33d2b22aa5..3d09187b9b 100644 --- a/src/core/ext/transport/chttp2/server/chttp2_server.cc +++ b/src/core/ext/transport/chttp2/server/chttp2_server.cc @@ -149,7 +149,7 @@ static void on_handshake_done(void* arg, grpc_error* error) { grpc_server_setup_transport( connection_state->svr_state->server, transport, connection_state->accepting_pollset, args->args, - grpc_chttp2_transport_get_socket_uuid(transport), resource_user); + grpc_chttp2_transport_get_socket_node(transport), resource_user); // Use notify_on_receive_settings callback to enforce the // handshake deadline. connection_state->transport = diff --git a/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc b/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc index b9024a87e2..c29c1e58cd 100644 --- a/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc +++ b/src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.cc @@ -61,7 +61,7 @@ void grpc_server_add_insecure_channel_from_fd(grpc_server* server, grpc_endpoint_add_to_pollset(server_endpoint, pollsets[i]); } - grpc_server_setup_transport(server, transport, nullptr, server_args, 0); + grpc_server_setup_transport(server, transport, nullptr, server_args, nullptr); grpc_chttp2_transport_start_reading(transport, nullptr, nullptr); } diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc index 6689a17da6..98fdb62070 100644 --- a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc +++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc @@ -31,6 +31,7 @@ #include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/surface/api_trace.h" @@ -40,9 +41,8 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, grpc_server_credentials* creds) { grpc_core::ExecCtx exec_ctx; grpc_error* err = GRPC_ERROR_NONE; - grpc_server_security_connector* sc = nullptr; + grpc_core::RefCountedPtr<grpc_server_security_connector> sc; int port_num = 0; - grpc_security_status status; grpc_channel_args* args = nullptr; GRPC_API_TRACE( "grpc_server_add_secure_http2_port(" @@ -54,30 +54,27 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, "No credentials specified for secure server port (creds==NULL)"); goto done; } - status = grpc_server_credentials_create_security_connector(creds, &sc); - if (status != GRPC_SECURITY_OK) { + sc = creds->create_security_connector(); + if (sc == nullptr) { char* msg; gpr_asprintf(&msg, "Unable to create secure server with credentials of type %s.", - creds->type); - err = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg), - GRPC_ERROR_INT_SECURITY_STATUS, status); + creds->type()); + err = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); gpr_free(msg); goto done; } // Create channel args. grpc_arg args_to_add[2]; args_to_add[0] = grpc_server_credentials_to_arg(creds); - args_to_add[1] = grpc_security_connector_to_arg(&sc->base); + args_to_add[1] = grpc_security_connector_to_arg(sc.get()); args = grpc_channel_args_copy_and_add(grpc_server_get_channel_args(server), args_to_add, GPR_ARRAY_SIZE(args_to_add)); // Add server port. err = grpc_chttp2_server_add_port(server, addr, args, &port_num); done: - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server"); - } + sc.reset(DEBUG_LOCATION, "server"); if (err != GRPC_ERROR_NONE) { const char* msg = grpc_error_string(err); diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc index 978ecd59e4..7f4627fa77 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc @@ -1,6 +1,6 @@ /* * - * Copyright 2015 gRPC authors. + * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> +#include "src/core/ext/transport/chttp2/transport/context_list.h" #include "src/core/ext/transport/chttp2/transport/frame_data.h" #include "src/core/ext/transport/chttp2/transport/internal.h" #include "src/core/ext/transport/chttp2/transport/varint.h" @@ -154,6 +155,7 @@ bool g_flow_control_enabled = true; /******************************************************************************* * CONSTRUCTION/DESTRUCTION/REFCOUNTING */ + grpc_chttp2_transport::~grpc_chttp2_transport() { size_t i; @@ -168,6 +170,14 @@ grpc_chttp2_transport::~grpc_chttp2_transport() { grpc_slice_buffer_destroy_internal(&outbuf); grpc_chttp2_hpack_compressor_destroy(&hpack_compressor); + grpc_error* error = + GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed"); + // ContextList::Execute follows semantics of a callback function and does not + // take a ref on error + grpc_core::ContextList::Execute(cl, nullptr, error); + GRPC_ERROR_UNREF(error); + cl = nullptr; + grpc_slice_buffer_destroy_internal(&read_buffer); grpc_chttp2_hpack_parser_destroy(&hpack_parser); grpc_chttp2_goaway_parser_destroy(&goaway_parser); @@ -202,38 +212,6 @@ grpc_chttp2_transport::~grpc_chttp2_transport() { gpr_free(peer_string); } -#ifndef NDEBUG -void grpc_chttp2_unref_transport(grpc_chttp2_transport* t, const char* reason, - const char* file, int line) { - if (grpc_trace_chttp2_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&t->refs.count); - gpr_log(GPR_DEBUG, "chttp2:unref:%p %" PRIdPTR "->%" PRIdPTR " %s [%s:%d]", - t, val, val - 1, reason, file, line); - } - if (!gpr_unref(&t->refs)) return; - t->~grpc_chttp2_transport(); - gpr_free(t); -} - -void grpc_chttp2_ref_transport(grpc_chttp2_transport* t, const char* reason, - const char* file, int line) { - if (grpc_trace_chttp2_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&t->refs.count); - gpr_log(GPR_DEBUG, "chttp2: ref:%p %" PRIdPTR "->%" PRIdPTR " %s [%s:%d]", - t, val, val + 1, reason, file, line); - } - gpr_ref(&t->refs); -} -#else -void grpc_chttp2_unref_transport(grpc_chttp2_transport* t) { - if (!gpr_unref(&t->refs)) return; - t->~grpc_chttp2_transport(); - gpr_free(t); -} - -void grpc_chttp2_ref_transport(grpc_chttp2_transport* t) { gpr_ref(&t->refs); } -#endif - static const grpc_transport_vtable* get_vtable(void); /* Returns whether bdp is enabled */ @@ -485,7 +463,8 @@ static void init_keepalive_pings_if_enabled(grpc_chttp2_transport* t) { grpc_chttp2_transport::grpc_chttp2_transport( const grpc_channel_args* channel_args, grpc_endpoint* ep, bool is_client, grpc_resource_user* resource_user) - : ep(ep), + : refs(1, &grpc_trace_chttp2_refcount), + ep(ep), peer_string(grpc_endpoint_get_peer(ep)), resource_user(resource_user), combiner(grpc_combiner_create()), @@ -495,8 +474,6 @@ grpc_chttp2_transport::grpc_chttp2_transport( GPR_ASSERT(strlen(GRPC_CHTTP2_CLIENT_CONNECT_STRING) == GRPC_CHTTP2_CLIENT_CONNECT_STRLEN); base.vtable = get_vtable(); - /* one ref is for destroy */ - gpr_ref_init(&refs, 1); /* 8 is a random stab in the dark as to a good initial size: it's small enough that it shouldn't waste memory for infrequently used connections, yet large enough that the exponential growth should happen nicely when it's @@ -1065,11 +1042,13 @@ static void write_action_begin_locked(void* gt, grpc_error* error_ignored) { static void write_action(void* gt, grpc_error* error) { GPR_TIMER_SCOPE("write_action", 0); grpc_chttp2_transport* t = static_cast<grpc_chttp2_transport*>(gt); + void* cl = t->cl; + t->cl = nullptr; grpc_endpoint_write( t->ep, &t->outbuf, GRPC_CLOSURE_INIT(&t->write_action_end_locked, write_action_end_locked, t, grpc_combiner_scheduler(t->combiner)), - nullptr); + cl); } /* Callback from the grpc_endpoint after bytes have been written by calling @@ -1393,6 +1372,8 @@ static void perform_stream_op_locked(void* stream_op, GRPC_STATS_INC_HTTP2_OP_BATCHES(); + s->context = op->payload->context; + s->traced = op->is_traced; if (grpc_http_trace.enabled()) { char* str = grpc_transport_stream_op_batch_string(op); gpr_log(GPR_INFO, "perform_stream_op_locked: %s; on_complete = %p", str, @@ -2837,8 +2818,8 @@ Chttp2IncomingByteStream::Chttp2IncomingByteStream( : ByteStream(frame_size, flags), transport_(transport), stream_(stream), + refs_(2), remaining_bytes_(frame_size) { - gpr_ref_init(&refs_, 2); GRPC_ERROR_UNREF(stream->byte_stream_error); stream->byte_stream_error = GRPC_ERROR_NONE; } @@ -2863,14 +2844,6 @@ void Chttp2IncomingByteStream::Orphan() { GRPC_ERROR_NONE); } -void Chttp2IncomingByteStream::Unref() { - if (gpr_unref(&refs_)) { - Delete(this); - } -} - -void Chttp2IncomingByteStream::Ref() { gpr_ref(&refs_); } - void Chttp2IncomingByteStream::NextLocked(void* arg, grpc_error* error_ignored) { Chttp2IncomingByteStream* bs = static_cast<Chttp2IncomingByteStream*>(arg); @@ -3177,21 +3150,18 @@ static const grpc_transport_vtable vtable = {sizeof(grpc_chttp2_stream), static const grpc_transport_vtable* get_vtable(void) { return &vtable; } -intptr_t grpc_chttp2_transport_get_socket_uuid(grpc_transport* transport) { +grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> +grpc_chttp2_transport_get_socket_node(grpc_transport* transport) { grpc_chttp2_transport* t = reinterpret_cast<grpc_chttp2_transport*>(transport); - if (t->channelz_socket != nullptr) { - return t->channelz_socket->uuid(); - } else { - return 0; - } + return t->channelz_socket; } grpc_transport* grpc_create_chttp2_transport( const grpc_channel_args* channel_args, grpc_endpoint* ep, bool is_client, grpc_resource_user* resource_user) { - auto t = new (gpr_malloc(sizeof(grpc_chttp2_transport))) - grpc_chttp2_transport(channel_args, ep, is_client, resource_user); + auto t = grpc_core::New<grpc_chttp2_transport>(channel_args, ep, is_client, + resource_user); return &t->base; } diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.h b/src/core/ext/transport/chttp2/transport/chttp2_transport.h index b3fe1c082e..c22cfb0ad7 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.h +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.h @@ -21,6 +21,7 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/channel/channelz.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/transport/transport.h" @@ -35,7 +36,8 @@ grpc_transport* grpc_create_chttp2_transport( const grpc_channel_args* channel_args, grpc_endpoint* ep, bool is_client, grpc_resource_user* resource_user = nullptr); -intptr_t grpc_chttp2_transport_get_socket_uuid(grpc_transport* transport); +grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> +grpc_chttp2_transport_get_socket_node(grpc_transport* transport); /// Takes ownership of \a read_buffer, which (if non-NULL) contains /// leftover bytes previously read from the endpoint (e.g., by handshakers). diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc new file mode 100644 index 0000000000..df09809067 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/context_list.cc @@ -0,0 +1,67 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/support/port_platform.h> + +#include "src/core/ext/transport/chttp2/transport/context_list.h" + +namespace { +void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*, + grpc_error* error) = nullptr; +void* (*get_copied_context_fn_g)(void*) = nullptr; +} // namespace + +namespace grpc_core { +void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) { + if (get_copied_context_fn_g == nullptr || + write_timestamps_callback_g == nullptr) { + return; + } + /* Create a new element in the list and add it at the front */ + ContextList* elem = grpc_core::New<ContextList>(); + elem->trace_context_ = get_copied_context_fn_g(s->context); + elem->byte_offset_ = s->byte_counter; + elem->next_ = *head; + *head = elem; +} + +void ContextList::Execute(void* arg, grpc_core::Timestamps* ts, + grpc_error* error) { + ContextList* head = static_cast<ContextList*>(arg); + ContextList* to_be_freed; + while (head != nullptr) { + if (write_timestamps_callback_g) { + ts->byte_offset = static_cast<uint32_t>(head->byte_offset_); + write_timestamps_callback_g(head->trace_context_, ts, error); + } + to_be_freed = head; + head = head->next_; + grpc_core::Delete(to_be_freed); + } +} + +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)) { + write_timestamps_callback_g = fn; +} + +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) { + get_copied_context_fn_g = fn; +} +} /* namespace grpc_core */ diff --git a/src/core/ext/transport/chttp2/transport/context_list.h b/src/core/ext/transport/chttp2/transport/context_list.h new file mode 100644 index 0000000000..5b9d2ab378 --- /dev/null +++ b/src/core/ext/transport/chttp2/transport/context_list.h @@ -0,0 +1,53 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H +#define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H + +#include <grpc/support/port_platform.h> + +#include "src/core/lib/iomgr/buffer_list.h" + +#include "src/core/ext/transport/chttp2/transport/internal.h" + +namespace grpc_core { +/** A list of RPC Contexts */ +class ContextList { + public: + /* Creates a new element with \a context as the value and appends it to the + * list. */ + static void Append(ContextList** head, grpc_chttp2_stream* s); + + /* Executes a function \a fn with each context in the list and \a ts. It also + * frees up the entire list after this operation. It is intended as a callback + * and hence does not take a ref on \a error */ + static void Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error); + + private: + void* trace_context_ = nullptr; + ContextList* next_ = nullptr; + size_t byte_offset_ = 0; +}; + +void grpc_http2_set_write_timestamps_callback(void (*fn)(void*, + grpc_core::Timestamps*, + grpc_error* error)); +void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)); +} /* namespace grpc_core */ + +#endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H */ diff --git a/src/core/ext/transport/chttp2/transport/internal.h b/src/core/ext/transport/chttp2/transport/internal.h index 3ee408c103..341f5b3977 100644 --- a/src/core/ext/transport/chttp2/transport/internal.h +++ b/src/core/ext/transport/chttp2/transport/internal.h @@ -45,6 +45,10 @@ #include "src/core/lib/transport/connectivity_state.h" #include "src/core/lib/transport/transport_impl.h" +namespace grpc_core { +class ContextList; +} + /* streams are kept in various linked lists depending on what things need to happen to them... this enum labels each list */ typedef enum { @@ -232,8 +236,12 @@ class Chttp2IncomingByteStream : public ByteStream { // alone for now. We can revisit this once we're able to link against // libc++, at which point we can eliminate New<> and Delete<> and // switch to std::shared_ptr<>. - void Ref(); - void Unref(); + void Ref() { refs_.Ref(); } + void Unref() { + if (refs_.Unref()) { + grpc_core::Delete(this); + } + } void PublishError(grpc_error* error); @@ -252,7 +260,7 @@ class Chttp2IncomingByteStream : public ByteStream { grpc_chttp2_transport* transport_; // Immutable. grpc_chttp2_stream* stream_; // Immutable. - gpr_refcount refs_; + grpc_core::RefCount refs_; /* Accessed only by transport thread when stream->pending_byte_stream == false * Accessed only by application thread when stream->pending_byte_stream == @@ -286,7 +294,7 @@ struct grpc_chttp2_transport { ~grpc_chttp2_transport(); grpc_transport base; /* must be first */ - gpr_refcount refs; + grpc_core::RefCount refs; grpc_endpoint* ep; char* peer_string; @@ -481,7 +489,7 @@ struct grpc_chttp2_transport { bool keepalive_permit_without_calls = false; /** keep-alive state machine state */ grpc_chttp2_keepalive_state keepalive_state; - + grpc_core::ContextList* cl = nullptr; grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> channelz_socket; uint32_t num_messages_in_next_write = 0; }; @@ -498,6 +506,7 @@ struct grpc_chttp2_stream { const void* server_data, gpr_arena* arena); ~grpc_chttp2_stream(); + void* context; grpc_chttp2_transport* t; grpc_stream_refcount* refcount; @@ -633,8 +642,12 @@ struct grpc_chttp2_stream { /** Whether bytes stored in unprocessed_incoming_byte_stream is decompressed */ bool unprocessed_incoming_frames_decompressed = false; + /** Whether the bytes needs to be traced using Fathom */ + bool traced = false; /** gRPC header bytes that are already decompressed */ size_t decompressed_header_bytes = 0; + /** Byte counter for number of bytes written */ + size_t byte_counter = 0; }; /** Transport writing call flow: @@ -779,15 +792,29 @@ void grpc_chttp2_stream_unref(grpc_chttp2_stream* s); grpc_chttp2_ref_transport(t, r, __FILE__, __LINE__) #define GRPC_CHTTP2_UNREF_TRANSPORT(t, r) \ grpc_chttp2_unref_transport(t, r, __FILE__, __LINE__) -void grpc_chttp2_unref_transport(grpc_chttp2_transport* t, const char* reason, - const char* file, int line); -void grpc_chttp2_ref_transport(grpc_chttp2_transport* t, const char* reason, - const char* file, int line); +inline void grpc_chttp2_unref_transport(grpc_chttp2_transport* t, + const char* reason, const char* file, + int line) { + if (t->refs.Unref(grpc_core::DebugLocation(file, line), reason)) { + grpc_core::Delete(t); + } +} +inline void grpc_chttp2_ref_transport(grpc_chttp2_transport* t, + const char* reason, const char* file, + int line) { + t->refs.Ref(grpc_core::DebugLocation(file, line), reason); +} #else #define GRPC_CHTTP2_REF_TRANSPORT(t, r) grpc_chttp2_ref_transport(t) #define GRPC_CHTTP2_UNREF_TRANSPORT(t, r) grpc_chttp2_unref_transport(t) -void grpc_chttp2_unref_transport(grpc_chttp2_transport* t); -void grpc_chttp2_ref_transport(grpc_chttp2_transport* t); +inline void grpc_chttp2_unref_transport(grpc_chttp2_transport* t) { + if (t->refs.Unref()) { + grpc_core::Delete(t); + } +} +inline void grpc_chttp2_ref_transport(grpc_chttp2_transport* t) { + t->refs.Ref(); +} #endif void grpc_chttp2_ack_ping(grpc_chttp2_transport* t, uint64_t id); diff --git a/src/core/ext/transport/chttp2/transport/writing.cc b/src/core/ext/transport/chttp2/transport/writing.cc index d533989444..265d3365d3 100644 --- a/src/core/ext/transport/chttp2/transport/writing.cc +++ b/src/core/ext/transport/chttp2/transport/writing.cc @@ -18,6 +18,7 @@ #include <grpc/support/port_platform.h> +#include "src/core/ext/transport/chttp2/transport/context_list.h" #include "src/core/ext/transport/chttp2/transport/internal.h" #include <limits.h> @@ -362,6 +363,7 @@ class DataSendContext { grpc_chttp2_encode_data(s_->id, &s_->compressed_data_buffer, send_bytes, is_last_frame_, &s_->stats.outgoing, &t_->outbuf); s_->flow_control->SentData(send_bytes); + s_->byte_counter += send_bytes; if (s_->compressed_data_buffer.length == 0) { s_->sending_bytes += s_->uncompressed_data_size; } @@ -496,6 +498,9 @@ class StreamWriteContext { data_send_context.CompressMoreBytes(); } } + if (s_->traced && grpc_endpoint_can_track_err(t_->ep)) { + grpc_core::ContextList::Append(&t_->cl, s_); + } write_context_->ResetPingClock(); if (data_send_context.is_last_frame()) { SentLastFrame(); diff --git a/src/core/ext/transport/inproc/inproc_transport.cc b/src/core/ext/transport/inproc/inproc_transport.cc index 61968de4d5..0b9bf5dd11 100644 --- a/src/core/ext/transport/inproc/inproc_transport.cc +++ b/src/core/ext/transport/inproc/inproc_transport.cc @@ -1236,7 +1236,7 @@ grpc_channel* grpc_inproc_channel_create(grpc_server* server, // TODO(ncteisen): design and support channelz GetSocket for inproc. grpc_server_setup_transport(server, server_transport, nullptr, server_args, - 0); + nullptr); grpc_channel* channel = grpc_channel_create( "inproc", client_args, GRPC_CLIENT_DIRECT_CHANNEL, client_transport); diff --git a/src/core/lib/channel/channelz.cc b/src/core/lib/channel/channelz.cc index 8d589f5983..8a596ad460 100644 --- a/src/core/lib/channel/channelz.cc +++ b/src/core/lib/channel/channelz.cc @@ -203,33 +203,34 @@ ServerNode::ServerNode(grpc_server* server, size_t channel_tracer_max_nodes) ServerNode::~ServerNode() {} -char* ServerNode::RenderServerSockets(intptr_t start_socket_id) { +char* ServerNode::RenderServerSockets(intptr_t start_socket_id, + intptr_t max_results) { + // if user does not set max_results, we choose 500. + size_t pagination_limit = max_results == 0 ? 500 : max_results; grpc_json* top_level_json = grpc_json_create(GRPC_JSON_OBJECT); grpc_json* json = top_level_json; grpc_json* json_iterator = nullptr; - ChildRefsList socket_refs; - // uuids index into entities one-off (idx 0 is really uuid 1, since 0 is - // reserved). However, we want to support requests coming in with - // start_server_id=0, which signifies "give me everything." - size_t start_idx = start_socket_id == 0 ? 0 : start_socket_id - 1; - grpc_server_populate_server_sockets(server_, &socket_refs, start_idx); + ChildSocketsList socket_refs; + grpc_server_populate_server_sockets(server_, &socket_refs, start_socket_id); + // declared early so it can be used outside of the loop. + size_t i = 0; if (!socket_refs.empty()) { // create list of socket refs grpc_json* array_parent = grpc_json_create_child( nullptr, json, "socketRef", nullptr, GRPC_JSON_ARRAY, false); - for (size_t i = 0; i < socket_refs.size(); ++i) { - json_iterator = - grpc_json_create_child(json_iterator, array_parent, nullptr, nullptr, - GRPC_JSON_OBJECT, false); - grpc_json_add_number_string_child(json_iterator, nullptr, "socketId", - socket_refs[i]); + for (i = 0; i < GPR_MIN(socket_refs.size(), pagination_limit); ++i) { + grpc_json* socket_ref_json = grpc_json_create_child( + nullptr, array_parent, nullptr, nullptr, GRPC_JSON_OBJECT, false); + json_iterator = grpc_json_add_number_string_child( + socket_ref_json, nullptr, "socketId", socket_refs[i]->uuid()); + grpc_json_create_child(json_iterator, socket_ref_json, "name", + socket_refs[i]->remote(), GRPC_JSON_STRING, false); } } - // For now we do not have any pagination rules. In the future we could - // pick a constant for max_channels_sent for a GetServers request. - // Tracking: https://github.com/grpc/grpc/issues/16019. - json_iterator = grpc_json_create_child(nullptr, json, "end", nullptr, - GRPC_JSON_TRUE, false); + if (i == socket_refs.size()) { + json_iterator = grpc_json_create_child(nullptr, json, "end", nullptr, + GRPC_JSON_TRUE, false); + } char* json_str = grpc_json_dump_to_string(top_level_json, 0); grpc_json_destroy(top_level_json); return json_str; diff --git a/src/core/lib/channel/channelz.h b/src/core/lib/channel/channelz.h index 64ab5cb3a6..e43792126f 100644 --- a/src/core/lib/channel/channelz.h +++ b/src/core/lib/channel/channelz.h @@ -59,6 +59,9 @@ namespace channelz { // add human readable names as in the channelz.proto typedef InlinedVector<intptr_t, 10> ChildRefsList; +class SocketNode; +typedef InlinedVector<SocketNode*, 10> ChildSocketsList; + namespace testing { class CallCountingHelperPeer; class ChannelNodePeer; @@ -207,7 +210,8 @@ class ServerNode : public BaseNode { grpc_json* RenderJson() override; - char* RenderServerSockets(intptr_t start_socket_id); + char* RenderServerSockets(intptr_t start_socket_id, + intptr_t pagination_limit); // proxy methods to composed classes. void AddTraceEvent(ChannelTrace::Severity severity, grpc_slice data) { @@ -251,6 +255,8 @@ class SocketNode : public BaseNode { gpr_atm_no_barrier_fetch_add(&keepalives_sent_, static_cast<gpr_atm>(1)); } + const char* remote() { return remote_.get(); } + private: gpr_atm streams_started_ = 0; gpr_atm streams_succeeded_ = 0; diff --git a/src/core/lib/channel/channelz_registry.cc b/src/core/lib/channel/channelz_registry.cc index bc23b90a66..7cca247d64 100644 --- a/src/core/lib/channel/channelz_registry.cc +++ b/src/core/lib/channel/channelz_registry.cc @@ -252,7 +252,8 @@ char* grpc_channelz_get_server(intptr_t server_id) { } char* grpc_channelz_get_server_sockets(intptr_t server_id, - intptr_t start_socket_id) { + intptr_t start_socket_id, + intptr_t max_results) { grpc_core::channelz::BaseNode* base_node = grpc_core::channelz::ChannelzRegistry::Get(server_id); if (base_node == nullptr || @@ -263,7 +264,7 @@ char* grpc_channelz_get_server_sockets(intptr_t server_id, // actually a server node grpc_core::channelz::ServerNode* server_node = static_cast<grpc_core::channelz::ServerNode*>(base_node); - return server_node->RenderServerSockets(start_socket_id); + return server_node->RenderServerSockets(start_socket_id, max_results); } char* grpc_channelz_get_channel(intptr_t channel_id) { diff --git a/src/core/lib/debug/trace.cc b/src/core/lib/debug/trace.cc index 01c1e867d9..cafdb15c69 100644 --- a/src/core/lib/debug/trace.cc +++ b/src/core/lib/debug/trace.cc @@ -21,6 +21,7 @@ #include "src/core/lib/debug/trace.h" #include <string.h> +#include <type_traits> #include <grpc/grpc.h> #include <grpc/support/alloc.h> @@ -79,6 +80,8 @@ void TraceFlagList::LogAllTracers() { // Flags register themselves on the list during construction TraceFlag::TraceFlag(bool default_enabled, const char* name) : name_(name) { + static_assert(std::is_trivially_destructible<TraceFlag>::value, + "TraceFlag needs to be trivially destructible."); set_enabled(default_enabled); TraceFlagList::Add(this); } diff --git a/src/core/lib/debug/trace.h b/src/core/lib/debug/trace.h index fe6301a3fc..4623494520 100644 --- a/src/core/lib/debug/trace.h +++ b/src/core/lib/debug/trace.h @@ -53,7 +53,8 @@ void grpc_tracer_enable_flag(grpc_core::TraceFlag* flag); class TraceFlag { public: TraceFlag(bool default_enabled, const char* name); - ~TraceFlag() {} + // This needs to be trivially destructible as it is used as global variable. + ~TraceFlag() = default; const char* name() const { return name_; } @@ -102,8 +103,9 @@ typedef TraceFlag DebugOnlyTraceFlag; #else class DebugOnlyTraceFlag { public: - DebugOnlyTraceFlag(bool default_enabled, const char* name) {} - bool enabled() { return false; } + constexpr DebugOnlyTraceFlag(bool default_enabled, const char* name) {} + constexpr bool enabled() const { return false; } + constexpr const char* name() const { return "DebugOnlyTraceFlag"; } private: void set_enabled(bool enabled) {} diff --git a/src/core/lib/gpr/sync_posix.cc b/src/core/lib/gpr/sync_posix.cc index 69bd609485..c09a7598ac 100644 --- a/src/core/lib/gpr/sync_posix.cc +++ b/src/core/lib/gpr/sync_posix.cc @@ -30,11 +30,18 @@ // For debug of the timer manager crash only. // TODO (mxyan): remove after bug is fixed. #ifdef GRPC_DEBUG_TIMER_MANAGER +#include <string.h> void (*g_grpc_debug_timer_manager_stats)( int64_t timer_manager_init_count, int64_t timer_manager_shutdown_count, int64_t fork_count, int64_t timer_wait_err, int64_t timer_cv_value, int64_t timer_mu_value, int64_t abstime_sec_value, - int64_t abstime_nsec_value) = nullptr; + int64_t abstime_nsec_value, int64_t abs_deadline_sec_value, + int64_t abs_deadline_nsec_value, int64_t now1_sec_value, + int64_t now1_nsec_value, int64_t now2_sec_value, int64_t now2_nsec_value, + int64_t add_result_sec_value, int64_t add_result_nsec_value, + int64_t sub_result_sec_value, int64_t sub_result_nsec_value, + int64_t next_value, int64_t start_time_sec, + int64_t start_time_nsec) = nullptr; int64_t g_timer_manager_init_count = 0; int64_t g_timer_manager_shutdown_count = 0; int64_t g_fork_count = 0; @@ -43,6 +50,19 @@ int64_t g_timer_cv_value = 0; int64_t g_timer_mu_value = 0; int64_t g_abstime_sec_value = -1; int64_t g_abstime_nsec_value = -1; +int64_t g_abs_deadline_sec_value = -1; +int64_t g_abs_deadline_nsec_value = -1; +int64_t g_now1_sec_value = -1; +int64_t g_now1_nsec_value = -1; +int64_t g_now2_sec_value = -1; +int64_t g_now2_nsec_value = -1; +int64_t g_add_result_sec_value = -1; +int64_t g_add_result_nsec_value = -1; +int64_t g_sub_result_sec_value = -1; +int64_t g_sub_result_nsec_value = -1; +int64_t g_next_value = -1; +int64_t g_start_time_sec = -1; +int64_t g_start_time_nsec = -1; #endif // GRPC_DEBUG_TIMER_MANAGER #ifdef GPR_LOW_LEVEL_COUNTERS @@ -90,17 +110,74 @@ void gpr_cv_init(gpr_cv* cv) { void gpr_cv_destroy(gpr_cv* cv) { GPR_ASSERT(pthread_cond_destroy(cv) == 0); } +// For debug of the timer manager crash only. +// TODO (mxyan): remove after bug is fixed. +#ifdef GRPC_DEBUG_TIMER_MANAGER +static gpr_timespec gpr_convert_clock_type_debug_timespec( + gpr_timespec t, gpr_clock_type clock_type, gpr_timespec& now1, + gpr_timespec& now2, gpr_timespec& add_result, gpr_timespec& sub_result) { + if (t.clock_type == clock_type) { + return t; + } + + if (t.tv_sec == INT64_MAX || t.tv_sec == INT64_MIN) { + t.clock_type = clock_type; + return t; + } + + if (clock_type == GPR_TIMESPAN) { + return gpr_time_sub(t, gpr_now(t.clock_type)); + } + + if (t.clock_type == GPR_TIMESPAN) { + return gpr_time_add(gpr_now(clock_type), t); + } + + now1 = gpr_now(t.clock_type); + sub_result = gpr_time_sub(t, now1); + now2 = gpr_now(clock_type); + add_result = gpr_time_add(now2, sub_result); + return add_result; +} + +#define gpr_convert_clock_type_debug(t, clock_type, now1, now2, add_result, \ + sub_result) \ + gpr_convert_clock_type_debug_timespec((t), (clock_type), (now1), (now2), \ + (add_result), (sub_result)) +#else +#define gpr_convert_clock_type_debug(t, clock_type, now1, now2, add_result, \ + sub_result) \ + gpr_convert_clock_type((t), (clock_type)) +#endif + int gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, gpr_timespec abs_deadline) { int err = 0; +#ifdef GRPC_DEBUG_TIMER_MANAGER + // For debug of the timer manager crash only. + // TODO (mxyan): remove after bug is fixed. + gpr_timespec abs_deadline_copy; + abs_deadline_copy.tv_sec = abs_deadline.tv_sec; + abs_deadline_copy.tv_nsec = abs_deadline.tv_nsec; + gpr_timespec now1; + gpr_timespec now2; + gpr_timespec add_result; + gpr_timespec sub_result; + memset(&now1, 0, sizeof(now1)); + memset(&now2, 0, sizeof(now2)); + memset(&add_result, 0, sizeof(add_result)); + memset(&sub_result, 0, sizeof(sub_result)); +#endif if (gpr_time_cmp(abs_deadline, gpr_inf_future(abs_deadline.clock_type)) == 0) { err = pthread_cond_wait(cv, mu); } else { struct timespec abs_deadline_ts; #if GPR_LINUX - abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_MONOTONIC); + abs_deadline = gpr_convert_clock_type_debug( + abs_deadline, GPR_CLOCK_MONOTONIC, now1, now2, add_result, sub_result); #else - abs_deadline = gpr_convert_clock_type(abs_deadline, GPR_CLOCK_REALTIME); + abs_deadline = gpr_convert_clock_type_debug( + abs_deadline, GPR_CLOCK_REALTIME, now1, now2, add_result, sub_result); #endif // GPR_LINUX abs_deadline_ts.tv_sec = static_cast<time_t>(abs_deadline.tv_sec); abs_deadline_ts.tv_nsec = abs_deadline.tv_nsec; @@ -123,10 +200,25 @@ int gpr_cv_wait(gpr_cv* cv, gpr_mu* mu, gpr_timespec abs_deadline) { g_timer_wait_err = err; g_timer_cv_value = (int64_t)cv; g_timer_mu_value = (int64_t)mu; + g_abs_deadline_sec_value = abs_deadline_copy.tv_sec; + g_abs_deadline_nsec_value = abs_deadline_copy.tv_nsec; + g_now1_sec_value = now1.tv_sec; + g_now1_nsec_value = now1.tv_nsec; + g_now2_sec_value = now2.tv_sec; + g_now2_nsec_value = now2.tv_nsec; + g_add_result_sec_value = add_result.tv_sec; + g_add_result_nsec_value = add_result.tv_nsec; + g_sub_result_sec_value = sub_result.tv_sec; + g_sub_result_nsec_value = sub_result.tv_nsec; g_grpc_debug_timer_manager_stats( g_timer_manager_init_count, g_timer_manager_shutdown_count, g_fork_count, g_timer_wait_err, g_timer_cv_value, g_timer_mu_value, - g_abstime_sec_value, g_abstime_nsec_value); + g_abstime_sec_value, g_abstime_nsec_value, g_abs_deadline_sec_value, + g_abs_deadline_nsec_value, g_now1_sec_value, g_now1_nsec_value, + g_now2_sec_value, g_now2_nsec_value, g_add_result_sec_value, + g_add_result_nsec_value, g_sub_result_sec_value, + g_sub_result_nsec_value, g_next_value, g_start_time_sec, + g_start_time_nsec); } } #endif diff --git a/src/core/lib/gprpp/inlined_vector.h b/src/core/lib/gprpp/inlined_vector.h index 65c2b9634f..66dc751a56 100644 --- a/src/core/lib/gprpp/inlined_vector.h +++ b/src/core/lib/gprpp/inlined_vector.h @@ -100,10 +100,7 @@ class InlinedVector { void reserve(size_t capacity) { if (capacity > capacity_) { T* new_dynamic = static_cast<T*>(gpr_malloc(sizeof(T) * capacity)); - for (size_t i = 0; i < size_; ++i) { - new (&new_dynamic[i]) T(std::move(data()[i])); - data()[i].~T(); - } + move_elements(data(), new_dynamic, size_); gpr_free(dynamic_); dynamic_ = new_dynamic; capacity_ = capacity; @@ -131,13 +128,25 @@ class InlinedVector { size_--; } + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + + size_t capacity() const { return capacity_; } + + void clear() { + destroy_elements(); + init_data(); + } + + private: void copy_from(const InlinedVector& v) { - // if v is allocated, copy over the buffer. + // if v is allocated, make sure we have enough capacity. if (v.dynamic_ != nullptr) { reserve(v.capacity_); - memcpy(dynamic_, v.dynamic_, v.size_ * sizeof(T)); - } else { - memcpy(inline_, v.inline_, v.size_ * sizeof(T)); + } + // copy over elements + for (size_t i = 0; i < v.size_; ++i) { + new (&(data()[i])) T(v[i]); } // copy over metadata size_ = v.size_; @@ -145,11 +154,12 @@ class InlinedVector { } void move_from(InlinedVector& v) { - // if v is allocated, then we steal its buffer, else we copy it. + // if v is allocated, then we steal its dynamic array; otherwise, we + // move the elements individually. if (v.dynamic_ != nullptr) { dynamic_ = v.dynamic_; } else { - memcpy(inline_, v.inline_, v.size_ * sizeof(T)); + move_elements(v.data(), data(), v.size_); } // copy over metadata size_ = v.size_; @@ -158,17 +168,13 @@ class InlinedVector { v.init_data(); } - size_t size() const { return size_; } - bool empty() const { return size_ == 0; } - - size_t capacity() const { return capacity_; } - - void clear() { - destroy_elements(); - init_data(); + static void move_elements(T* src, T* dst, size_t num_elements) { + for (size_t i = 0; i < num_elements; ++i) { + new (&dst[i]) T(std::move(src[i])); + src[i].~T(); + } } - private: void init_data() { dynamic_ = nullptr; size_ = 0; diff --git a/src/core/lib/gprpp/memory.h b/src/core/lib/gprpp/memory.h index e90bedcd9b..b4b63ae771 100644 --- a/src/core/lib/gprpp/memory.h +++ b/src/core/lib/gprpp/memory.h @@ -40,15 +40,10 @@ namespace grpc_core { -// The alignment of memory returned by gpr_malloc(). -constexpr size_t kAlignmentForDefaultAllocationInBytes = 8; - // Alternative to new, since we cannot use it (for fear of libstdc++) template <typename T, typename... Args> inline T* New(Args&&... args) { - void* p = alignof(T) > kAlignmentForDefaultAllocationInBytes - ? gpr_malloc_aligned(sizeof(T), alignof(T)) - : gpr_malloc(sizeof(T)); + void* p = gpr_malloc(sizeof(T)); return new (p) T(std::forward<Args>(args)...); } @@ -57,11 +52,7 @@ template <typename T> inline void Delete(T* p) { if (p == nullptr) return; p->~T(); - if (alignof(T) > kAlignmentForDefaultAllocationInBytes) { - gpr_free_aligned(p); - } else { - gpr_free(p); - } + gpr_free(p); } template <typename T> diff --git a/src/core/lib/gprpp/orphanable.h b/src/core/lib/gprpp/orphanable.h index 3123e3f5a3..9053c60111 100644 --- a/src/core/lib/gprpp/orphanable.h +++ b/src/core/lib/gprpp/orphanable.h @@ -31,6 +31,7 @@ #include "src/core/lib/gprpp/abstract.h" #include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/memory.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" namespace grpc_core { @@ -89,107 +90,42 @@ class InternallyRefCounted : public Orphanable { template <typename T> friend class RefCountedPtr; - InternallyRefCounted() { gpr_ref_init(&refs_, 1); } - virtual ~InternallyRefCounted() {} + // TraceFlagT is defined to accept both DebugOnlyTraceFlag and TraceFlag. + // Note: RefCount tracing is only enabled on debug builds, even when a + // TraceFlag is used. + template <typename TraceFlagT = TraceFlag> + explicit InternallyRefCounted(TraceFlagT* trace_flag = nullptr) + : refs_(1, trace_flag) {} + virtual ~InternallyRefCounted() = default; RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT { IncrementRefCount(); return RefCountedPtr<Child>(static_cast<Child*>(this)); } - - void Unref() { - if (gpr_unref(&refs_)) { - Delete(static_cast<Child*>(this)); - } - } - - private: - void IncrementRefCount() { gpr_ref(&refs_); } - - gpr_refcount refs_; -}; - -// An alternative version of the InternallyRefCounted base class that -// supports tracing. This is intended to be used in cases where the -// object will be handled both by idiomatic C++ code using smart -// pointers and legacy code that is manually calling Ref() and Unref(). -// Once all of our code is converted to idiomatic C++, we may be able to -// eliminate this class. -template <typename Child> -class InternallyRefCountedWithTracing : public Orphanable { - public: - // Not copyable nor movable. - InternallyRefCountedWithTracing(const InternallyRefCountedWithTracing&) = - delete; - InternallyRefCountedWithTracing& operator=( - const InternallyRefCountedWithTracing&) = delete; - - GRPC_ABSTRACT_BASE_CLASS - - protected: - GPRC_ALLOW_CLASS_TO_USE_NON_PUBLIC_DELETE - - // Allow RefCountedPtr<> to access Unref() and IncrementRefCount(). - template <typename T> - friend class RefCountedPtr; - - InternallyRefCountedWithTracing() - : InternallyRefCountedWithTracing(static_cast<TraceFlag*>(nullptr)) {} - - explicit InternallyRefCountedWithTracing(TraceFlag* trace_flag) - : trace_flag_(trace_flag) { - gpr_ref_init(&refs_, 1); - } - -#ifdef NDEBUG - explicit InternallyRefCountedWithTracing(DebugOnlyTraceFlag* trace_flag) - : InternallyRefCountedWithTracing() {} -#endif - - virtual ~InternallyRefCountedWithTracing() {} - - RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT { - IncrementRefCount(); - return RefCountedPtr<Child>(static_cast<Child*>(this)); - } - RefCountedPtr<Child> Ref(const DebugLocation& location, const char* reason) GRPC_MUST_USE_RESULT { - if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { - gpr_atm old_refs = gpr_atm_no_barrier_load(&refs_.count); - gpr_log(GPR_INFO, "%s:%p %s:%d ref %" PRIdPTR " -> %" PRIdPTR " %s", - trace_flag_->name(), this, location.file(), location.line(), - old_refs, old_refs + 1, reason); - } - return Ref(); + IncrementRefCount(location, reason); + return RefCountedPtr<Child>(static_cast<Child*>(this)); } - // TODO(roth): Once all of our code is converted to C++ and can use - // RefCountedPtr<> instead of manual ref-counting, make the Unref() methods - // private, since they will only be used by RefCountedPtr<>, which is a - // friend of this class. - void Unref() { - if (gpr_unref(&refs_)) { + if (refs_.Unref()) { Delete(static_cast<Child*>(this)); } } - void Unref(const DebugLocation& location, const char* reason) { - if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { - gpr_atm old_refs = gpr_atm_no_barrier_load(&refs_.count); - gpr_log(GPR_INFO, "%s:%p %s:%d unref %" PRIdPTR " -> %" PRIdPTR " %s", - trace_flag_->name(), this, location.file(), location.line(), - old_refs, old_refs - 1, reason); + if (refs_.Unref(location, reason)) { + Delete(static_cast<Child*>(this)); } - Unref(); } private: - void IncrementRefCount() { gpr_ref(&refs_); } + void IncrementRefCount() { refs_.Ref(); } + void IncrementRefCount(const DebugLocation& location, const char* reason) { + refs_.Ref(location, reason); + } - TraceFlag* trace_flag_ = nullptr; - gpr_refcount refs_; + grpc_core::RefCount refs_; }; } // namespace grpc_core diff --git a/src/core/lib/gprpp/ref_counted.h b/src/core/lib/gprpp/ref_counted.h index 03c293f6ed..fa97ffcfed 100644 --- a/src/core/lib/gprpp/ref_counted.h +++ b/src/core/lib/gprpp/ref_counted.h @@ -21,9 +21,12 @@ #include <grpc/support/port_platform.h> +#include <grpc/support/atm.h> #include <grpc/support/log.h> #include <grpc/support/sync.h> +#include <atomic> +#include <cassert> #include <cinttypes> #include "src/core/lib/debug/trace.h" @@ -34,61 +37,150 @@ namespace grpc_core { -// A base class for reference-counted objects. -// New objects should be created via New() and start with a refcount of 1. -// When the refcount reaches 0, the object will be deleted via Delete(). -// -// This will commonly be used by CRTP (curiously-recurring template pattern) -// e.g., class MyClass : public RefCounted<MyClass> -template <typename Child> -class RefCounted { +// PolymorphicRefCount enforces polymorphic destruction of RefCounted. +class PolymorphicRefCount { public: - RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT { - IncrementRefCount(); - return RefCountedPtr<Child>(static_cast<Child*>(this)); - } + GRPC_ABSTRACT_BASE_CLASS - // TODO(roth): Once all of our code is converted to C++ and can use - // RefCountedPtr<> instead of manual ref-counting, make this method - // private, since it will only be used by RefCountedPtr<>, which is a - // friend of this class. - void Unref() { - if (gpr_unref(&refs_)) { - Delete(static_cast<Child*>(this)); - } - } + protected: + GPRC_ALLOW_CLASS_TO_USE_NON_PUBLIC_DELETE - // Not copyable nor movable. - RefCounted(const RefCounted&) = delete; - RefCounted& operator=(const RefCounted&) = delete; + virtual ~PolymorphicRefCount() = default; +}; +// NonPolymorphicRefCount does not enforce polymorphic destruction of +// RefCounted. Please refer to grpc_core::RefCounted for more details, and +// when in doubt use PolymorphicRefCount. +class NonPolymorphicRefCount { + public: GRPC_ABSTRACT_BASE_CLASS protected: GPRC_ALLOW_CLASS_TO_USE_NON_PUBLIC_DELETE - RefCounted() { gpr_ref_init(&refs_, 1); } + ~NonPolymorphicRefCount() = default; +}; - virtual ~RefCounted() {} +// RefCount is a simple atomic ref-count. +// +// This is a C++ implementation of gpr_refcount, with inline functions. Due to +// inline functions, this class is significantly more efficient than +// gpr_refcount and should be preferred over gpr_refcount whenever possible. +// +// TODO(soheil): Remove gpr_refcount after submitting the GRFC and the paragraph +// above. +class RefCount { + public: + using Value = intptr_t; + + // `init` is the initial refcount stored in this object. + // + // TraceFlagT is defined to accept both DebugOnlyTraceFlag and TraceFlag. + // Note: RefCount tracing is only enabled on debug builds, even when a + // TraceFlag is used. + template <typename TraceFlagT = TraceFlag> + constexpr explicit RefCount(Value init = 1, TraceFlagT* trace_flag = nullptr) + : +#ifndef NDEBUG + trace_flag_(trace_flag), +#endif + value_(init) { + } - private: - // Allow RefCountedPtr<> to access IncrementRefCount(). - template <typename T> - friend class RefCountedPtr; + // Increases the ref-count by `n`. + void Ref(Value n = 1) { + GPR_ATM_INC_ADD_THEN(value_.fetch_add(n, std::memory_order_relaxed)); + } + void Ref(const DebugLocation& location, const char* reason, Value n = 1) { +#ifndef NDEBUG + if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { + const RefCount::Value old_refs = get(); + gpr_log(GPR_INFO, "%s:%p %s:%d ref %" PRIdPTR " -> %" PRIdPTR " %s", + trace_flag_->name(), this, location.file(), location.line(), + old_refs, old_refs + n, reason); + } +#endif + Ref(n); + } + + // Similar to Ref() with an assert on the ref-count being non-zero. + void RefNonZero() { +#ifndef NDEBUG + const Value prior = + GPR_ATM_INC_ADD_THEN(value_.fetch_add(1, std::memory_order_relaxed)); + assert(prior > 0); +#else + Ref(); +#endif + } + void RefNonZero(const DebugLocation& location, const char* reason) { +#ifndef NDEBUG + if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { + const RefCount::Value old_refs = get(); + gpr_log(GPR_INFO, "%s:%p %s:%d ref %" PRIdPTR " -> %" PRIdPTR " %s", + trace_flag_->name(), this, location.file(), location.line(), + old_refs, old_refs + 1, reason); + } +#endif + RefNonZero(); + } - void IncrementRefCount() { gpr_ref(&refs_); } + // Decrements the ref-count and returns true if the ref-count reaches 0. + bool Unref() { + const Value prior = + GPR_ATM_INC_ADD_THEN(value_.fetch_sub(1, std::memory_order_acq_rel)); + GPR_DEBUG_ASSERT(prior > 0); + return prior == 1; + } + bool Unref(const DebugLocation& location, const char* reason) { +#ifndef NDEBUG + if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { + const RefCount::Value old_refs = get(); + gpr_log(GPR_INFO, "%s:%p %s:%d unref %" PRIdPTR " -> %" PRIdPTR " %s", + trace_flag_->name(), this, location.file(), location.line(), + old_refs, old_refs - 1, reason); + } +#endif + return Unref(); + } - gpr_refcount refs_; + private: + Value get() const { return value_.load(std::memory_order_relaxed); } + +#ifndef NDEBUG + TraceFlag* trace_flag_; +#endif + std::atomic<Value> value_; }; -// An alternative version of the RefCounted base class that -// supports tracing. This is intended to be used in cases where the -// object will be handled both by idiomatic C++ code using smart -// pointers and legacy code that is manually calling Ref() and Unref(). -// Once all of our code is converted to idiomatic C++, we may be able to -// eliminate this class. -template <typename Child> -class RefCountedWithTracing { +// A base class for reference-counted objects. +// New objects should be created via New() and start with a refcount of 1. +// When the refcount reaches 0, the object will be deleted via Delete(). +// +// This will commonly be used by CRTP (curiously-recurring template pattern) +// e.g., class MyClass : public RefCounted<MyClass> +// +// Use PolymorphicRefCount and NonPolymorphicRefCount to select between +// different implementations of RefCounted. +// +// Note that NonPolymorphicRefCount does not support polymorphic destruction. +// So, use NonPolymorphicRefCount only when both of the following conditions +// are guaranteed to hold: +// (a) Child is a concrete leaf class in RefCounted<Child>, and +// (b) you are gauranteed to call Unref only on concrete leaf classes and not +// their parents. +// +// The following example is illegal, because calling Unref() will not call +// the dtor of Child. +// +// class Parent : public RefCounted<Parent, NonPolymorphicRefCount> {} +// class Child : public Parent {} +// +// Child* ch; +// ch->Unref(); +// +template <typename Child, typename Impl = PolymorphicRefCount> +class RefCounted : public Impl { public: RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT { IncrementRefCount(); @@ -97,69 +189,55 @@ class RefCountedWithTracing { RefCountedPtr<Child> Ref(const DebugLocation& location, const char* reason) GRPC_MUST_USE_RESULT { - if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { - gpr_atm old_refs = gpr_atm_no_barrier_load(&refs_.count); - gpr_log(GPR_INFO, "%s:%p %s:%d ref %" PRIdPTR " -> %" PRIdPTR " %s", - trace_flag_->name(), this, location.file(), location.line(), - old_refs, old_refs + 1, reason); - } - return Ref(); + IncrementRefCount(location, reason); + return RefCountedPtr<Child>(static_cast<Child*>(this)); } // TODO(roth): Once all of our code is converted to C++ and can use - // RefCountedPtr<> instead of manual ref-counting, make the Unref() methods - // private, since they will only be used by RefCountedPtr<>, which is a + // RefCountedPtr<> instead of manual ref-counting, make this method + // private, since it will only be used by RefCountedPtr<>, which is a // friend of this class. - void Unref() { - if (gpr_unref(&refs_)) { + if (refs_.Unref()) { Delete(static_cast<Child*>(this)); } } - void Unref(const DebugLocation& location, const char* reason) { - if (location.Log() && trace_flag_ != nullptr && trace_flag_->enabled()) { - gpr_atm old_refs = gpr_atm_no_barrier_load(&refs_.count); - gpr_log(GPR_INFO, "%s:%p %s:%d unref %" PRIdPTR " -> %" PRIdPTR " %s", - trace_flag_->name(), this, location.file(), location.line(), - old_refs, old_refs - 1, reason); + if (refs_.Unref(location, reason)) { + Delete(static_cast<Child*>(this)); } - Unref(); } // Not copyable nor movable. - RefCountedWithTracing(const RefCountedWithTracing&) = delete; - RefCountedWithTracing& operator=(const RefCountedWithTracing&) = delete; + RefCounted(const RefCounted&) = delete; + RefCounted& operator=(const RefCounted&) = delete; GRPC_ABSTRACT_BASE_CLASS protected: GPRC_ALLOW_CLASS_TO_USE_NON_PUBLIC_DELETE - RefCountedWithTracing() - : RefCountedWithTracing(static_cast<TraceFlag*>(nullptr)) {} - - explicit RefCountedWithTracing(TraceFlag* trace_flag) - : trace_flag_(trace_flag) { - gpr_ref_init(&refs_, 1); - } - -#ifdef NDEBUG - explicit RefCountedWithTracing(DebugOnlyTraceFlag* trace_flag) - : RefCountedWithTracing() {} -#endif + // TraceFlagT is defined to accept both DebugOnlyTraceFlag and TraceFlag. + // Note: RefCount tracing is only enabled on debug builds, even when a + // TraceFlag is used. + template <typename TraceFlagT = TraceFlag> + explicit RefCounted(TraceFlagT* trace_flag = nullptr) + : refs_(1, trace_flag) {} - virtual ~RefCountedWithTracing() {} + // Note: Depending on the Impl used, this dtor can be implicitly virtual. + ~RefCounted() = default; private: // Allow RefCountedPtr<> to access IncrementRefCount(). template <typename T> friend class RefCountedPtr; - void IncrementRefCount() { gpr_ref(&refs_); } + void IncrementRefCount() { refs_.Ref(); } + void IncrementRefCount(const DebugLocation& location, const char* reason) { + refs_.Ref(location, reason); + } - TraceFlag* trace_flag_ = nullptr; - gpr_refcount refs_; + RefCount refs_; }; } // namespace grpc_core diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h index c2dfbdd90f..19f38d7f01 100644 --- a/src/core/lib/gprpp/ref_counted_ptr.h +++ b/src/core/lib/gprpp/ref_counted_ptr.h @@ -21,8 +21,10 @@ #include <grpc/support/port_platform.h> +#include <type_traits> #include <utility> +#include "src/core/lib/gprpp/debug_location.h" #include "src/core/lib/gprpp/memory.h" namespace grpc_core { @@ -48,21 +50,19 @@ class RefCountedPtr { } template <typename Y> RefCountedPtr(RefCountedPtr<Y>&& other) { - value_ = other.value_; + value_ = static_cast<T*>(other.value_); other.value_ = nullptr; } // Move assignment. RefCountedPtr& operator=(RefCountedPtr&& other) { - if (value_ != nullptr) value_->Unref(); - value_ = other.value_; + reset(other.value_); other.value_ = nullptr; return *this; } template <typename Y> RefCountedPtr& operator=(RefCountedPtr<Y>&& other) { - if (value_ != nullptr) value_->Unref(); - value_ = other.value_; + reset(other.value_); other.value_ = nullptr; return *this; } @@ -74,8 +74,10 @@ class RefCountedPtr { } template <typename Y> RefCountedPtr(const RefCountedPtr<Y>& other) { + static_assert(std::has_virtual_destructor<T>::value, + "T does not have a virtual dtor"); if (other.value_ != nullptr) other.value_->IncrementRefCount(); - value_ = other.value_; + value_ = static_cast<T*>(other.value_); } // Copy assignment. @@ -83,17 +85,17 @@ class RefCountedPtr { // Note: Order of reffing and unreffing is important here in case value_ // and other.value_ are the same object. if (other.value_ != nullptr) other.value_->IncrementRefCount(); - if (value_ != nullptr) value_->Unref(); - value_ = other.value_; + reset(other.value_); return *this; } template <typename Y> RefCountedPtr& operator=(const RefCountedPtr<Y>& other) { + static_assert(std::has_virtual_destructor<T>::value, + "T does not have a virtual dtor"); // Note: Order of reffing and unreffing is important here in case value_ // and other.value_ are the same object. if (other.value_ != nullptr) other.value_->IncrementRefCount(); - if (value_ != nullptr) value_->Unref(); - value_ = other.value_; + reset(other.value_); return *this; } @@ -102,15 +104,29 @@ class RefCountedPtr { } // If value is non-null, we take ownership of a ref to it. - template <typename Y> - void reset(Y* value) { + void reset(T* value = nullptr) { if (value_ != nullptr) value_->Unref(); value_ = value; } - - void reset() { + void reset(const DebugLocation& location, const char* reason, + T* value = nullptr) { + if (value_ != nullptr) value_->Unref(location, reason); + value_ = value; + } + template <typename Y> + void reset(Y* value = nullptr) { + static_assert(std::has_virtual_destructor<T>::value, + "T does not have a virtual dtor"); if (value_ != nullptr) value_->Unref(); - value_ = nullptr; + value_ = static_cast<T*>(value); + } + template <typename Y> + void reset(const DebugLocation& location, const char* reason, + Y* value = nullptr) { + static_assert(std::has_virtual_destructor<T>::value, + "T does not have a virtual dtor"); + if (value_ != nullptr) value_->Unref(location, reason); + value_ = static_cast<T*>(value); } // TODO(roth): This method exists solely as a transition mechanism to allow diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 1c798d368b..fdea7511cc 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -29,119 +29,125 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/ssl_utils.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/tsi/ssl_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* handshaker_factory; - char* secure_peer_name; -} grpc_httpcli_ssl_channel_security_connector; - -static void httpcli_ssl_destroy(grpc_security_connector* sc) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - if (c->handshaker_factory != nullptr) { - tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory); - c->handshaker_factory = nullptr; +class grpc_httpcli_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name) + : grpc_channel_security_connector( + /*url_scheme=*/nullptr, + /*channel_creds=*/nullptr, + /*request_metadata_creds=*/nullptr), + secure_peer_name_(secure_peer_name) {} + + ~grpc_httpcli_ssl_channel_security_connector() override { + if (handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(handshaker_factory_); + } + if (secure_peer_name_ != nullptr) { + gpr_free(secure_peer_name_); + } + } + + tsi_result InitHandshakerFactory(const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store) { + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + return tsi_create_ssl_client_handshaker_factory_with_options( + &options, &handshaker_factory_); } - if (c->secure_peer_name != nullptr) gpr_free(c->secure_peer_name); - gpr_free(sc); -} -static void httpcli_ssl_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - tsi_handshaker* handshaker = nullptr; - if (c->handshaker_factory != nullptr) { - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->handshaker_factory, c->secure_peer_name, &handshaker); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + tsi_handshaker* handshaker = nullptr; + if (handshaker_factory_ != nullptr) { + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + handshaker_factory_, secure_peer_name_, &handshaker); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } } + grpc_handshake_manager_add( + handshake_mgr, grpc_security_handshaker_create(handshaker, this)); } - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(handshaker, &sc->base)); -} -static void httpcli_ssl_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc); - grpc_error* error = GRPC_ERROR_NONE; - - /* Check the peer name. */ - if (c->secure_peer_name != nullptr && - !tsi_ssl_peer_matches_name(&peer, c->secure_peer_name)) { - char* msg; - gpr_asprintf(&msg, "Peer name %s is not in peer certificate", - c->secure_peer_name); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); + tsi_ssl_client_handshaker_factory* handshaker_factory() const { + return handshaker_factory_; } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static int httpcli_ssl_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_httpcli_ssl_channel_security_connector* c1 = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc1); - grpc_httpcli_ssl_channel_security_connector* c2 = - reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc2); - return strcmp(c1->secure_peer_name, c2->secure_peer_name); -} + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* /*auth_context*/, + grpc_closure* on_peer_checked) override { + grpc_error* error = GRPC_ERROR_NONE; + + /* Check the peer name. */ + if (secure_peer_name_ != nullptr && + !tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) { + char* msg; + gpr_asprintf(&msg, "Peer name %s is not in peer certificate", + secure_peer_name_); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); + } -static grpc_security_connector_vtable httpcli_ssl_vtable = { - httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp}; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_httpcli_ssl_channel_security_connector*>( + other_sc); + return strcmp(secure_peer_name_, other->secure_peer_name_); + } -static grpc_security_status httpcli_ssl_channel_security_connector_create( - const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, - const char* secure_peer_name, grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_httpcli_ssl_channel_security_connector* c; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + *error = GRPC_ERROR_NONE; + return true; + } - if (secure_peer_name != nullptr && pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, - "Cannot assert a secure peer name without a trust root."); - return GRPC_SECURITY_ERROR; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); } - c = static_cast<grpc_httpcli_ssl_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_httpcli_ssl_channel_security_connector))); + const char* secure_peer_name() const { return secure_peer_name_; } - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &httpcli_ssl_vtable; - if (secure_peer_name != nullptr) { - c->secure_peer_name = gpr_strdup(secure_peer_name); + private: + tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr; + char* secure_peer_name_; +}; + +static grpc_core::RefCountedPtr<grpc_channel_security_connector> +httpcli_ssl_channel_security_connector_create( + const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, + const char* secure_peer_name) { + if (secure_peer_name != nullptr && pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, + "Cannot assert a secure peer name without a trust root."); + return nullptr; } - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.pem_root_certs = pem_root_certs; - options.root_store = root_store; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->handshaker_factory); + grpc_core::RefCountedPtr<grpc_httpcli_ssl_channel_security_connector> c = + grpc_core::MakeRefCounted<grpc_httpcli_ssl_channel_security_connector>( + secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name)); + tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - httpcli_ssl_destroy(&c->base.base); - *sc = nullptr; - return GRPC_SECURITY_ERROR; + return nullptr; } - // We don't actually need a channel credentials object in this case, - // but we set it to a non-nullptr address so that we don't trigger - // assertions in grpc_channel_security_connector_cmp(). - c->base.channel_creds = (grpc_channel_credentials*)1; - c->base.add_handshakers = httpcli_ssl_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return c; } /* handshaker */ @@ -186,10 +192,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, } c->func = on_done; c->arg = arg; - grpc_channel_security_connector* sc = nullptr; - GPR_ASSERT(httpcli_ssl_channel_security_connector_create( - pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK); - grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base); + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store, + host); + GPR_ASSERT(sc != nullptr); + grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_channel_args args = {1, &channel_arg}; c->handshake_mgr = grpc_handshake_manager_create(); grpc_handshakers_add(HANDSHAKER_CLIENT, &args, @@ -197,7 +204,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, grpc_handshake_manager_do_handshake( c->handshake_mgr, tcp, nullptr /* channel_args */, deadline, nullptr /* acceptor */, on_handshake_done, c /* user_data */); - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "httpcli"); + sc.reset(DEBUG_LOCATION, "httpcli"); } const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake}; diff --git a/src/core/lib/http/parser.h b/src/core/lib/http/parser.h index 1d2e13e831..a8f47c96c8 100644 --- a/src/core/lib/http/parser.h +++ b/src/core/lib/http/parser.h @@ -70,13 +70,13 @@ typedef struct grpc_http_request { /* A response */ typedef struct grpc_http_response { /* HTTP status code */ - int status; + int status = 0; /* Headers: count and key/values */ - size_t hdr_count; - grpc_http_header* hdrs; + size_t hdr_count = 0; + grpc_http_header* hdrs = nullptr; /* Body: length and contents; contents are NOT null-terminated */ - size_t body_length; - char* body; + size_t body_length = 0; + char* body = nullptr; } grpc_http_response; typedef struct { diff --git a/src/core/lib/iomgr/buffer_list.cc b/src/core/lib/iomgr/buffer_list.cc index 6ada23db1c..ace17a108d 100644 --- a/src/core/lib/iomgr/buffer_list.cc +++ b/src/core/lib/iomgr/buffer_list.cc @@ -35,6 +35,9 @@ void TracedBuffer::AddNewEntry(TracedBuffer** head, uint32_t seq_no, TracedBuffer* new_elem = New<TracedBuffer>(seq_no, arg); /* Store the current time as the sendmsg time. */ new_elem->ts_.sendmsg_time = gpr_now(GPR_CLOCK_REALTIME); + new_elem->ts_.scheduled_time = gpr_inf_past(GPR_CLOCK_REALTIME); + new_elem->ts_.sent_time = gpr_inf_past(GPR_CLOCK_REALTIME); + new_elem->ts_.acked_time = gpr_inf_past(GPR_CLOCK_REALTIME); if (*head == nullptr) { *head = new_elem; return; @@ -55,10 +58,16 @@ void fill_gpr_from_timestamp(gpr_timespec* gts, const struct timespec* ts) { gts->clock_type = GPR_CLOCK_REALTIME; } +void default_timestamps_callback(void* arg, grpc_core::Timestamps* ts, + grpc_error* shudown_err) { + gpr_log(GPR_DEBUG, "Timestamps callback has not been registered"); +} + /** The saved callback function that will be invoked when we get all the * timestamps that we are going to get for a TracedBuffer. */ void (*timestamps_callback)(void*, grpc_core::Timestamps*, - grpc_error* shutdown_err); + grpc_error* shutdown_err) = + default_timestamps_callback; } /* namespace */ void TracedBuffer::ProcessTimestamp(TracedBuffer** head, @@ -99,18 +108,20 @@ void TracedBuffer::ProcessTimestamp(TracedBuffer** head, } } -void TracedBuffer::Shutdown(TracedBuffer** head, grpc_error* shutdown_err) { +void TracedBuffer::Shutdown(TracedBuffer** head, void* remaining, + grpc_error* shutdown_err) { GPR_DEBUG_ASSERT(head != nullptr); TracedBuffer* elem = *head; while (elem != nullptr) { - if (timestamps_callback) { - timestamps_callback(elem->arg_, &(elem->ts_), shutdown_err); - } + timestamps_callback(elem->arg_, &(elem->ts_), shutdown_err); auto* next = elem->next_; Delete<TracedBuffer>(elem); elem = next; } *head = nullptr; + if (remaining != nullptr) { + timestamps_callback(remaining, nullptr, shutdown_err); + } GRPC_ERROR_UNREF(shutdown_err); } diff --git a/src/core/lib/iomgr/buffer_list.h b/src/core/lib/iomgr/buffer_list.h index cbbf50a657..627f1bde99 100644 --- a/src/core/lib/iomgr/buffer_list.h +++ b/src/core/lib/iomgr/buffer_list.h @@ -37,6 +37,8 @@ struct Timestamps { gpr_timespec scheduled_time; gpr_timespec sent_time; gpr_timespec acked_time; + + uint32_t byte_offset; /* byte offset relative to the start of the RPC */ }; /** TracedBuffer is a class to keep track of timestamps for a specific buffer in @@ -67,13 +69,13 @@ class TracedBuffer { /** Cleans the list by calling the callback for each traced buffer in the list * with timestamps that it has. */ - static void Shutdown(grpc_core::TracedBuffer** head, + static void Shutdown(grpc_core::TracedBuffer** head, void* remaining, grpc_error* shutdown_err); private: GPRC_ALLOW_CLASS_TO_USE_NON_PUBLIC_NEW - TracedBuffer(int seq_no, void* arg) + TracedBuffer(uint32_t seq_no, void* arg) : seq_no_(seq_no), arg_(arg), next_(nullptr) {} uint32_t seq_no_; /* The sequence number for the last byte in the buffer */ @@ -82,7 +84,12 @@ class TracedBuffer { grpc_core::TracedBuffer* next_; /* The next TracedBuffer in the list */ }; #else /* GRPC_LINUX_ERRQUEUE */ -class TracedBuffer {}; +class TracedBuffer { + public: + /* Dummy shutdown function */ + static void Shutdown(grpc_core::TracedBuffer** head, void* remaining, + grpc_error* shutdown_err) {} +}; #endif /* GRPC_LINUX_ERRQUEUE */ /** Sets the callback function to call when timestamps for a write are diff --git a/src/core/lib/iomgr/call_combiner.cc b/src/core/lib/iomgr/call_combiner.cc index 90dda45ba3..6b5759a036 100644 --- a/src/core/lib/iomgr/call_combiner.cc +++ b/src/core/lib/iomgr/call_combiner.cc @@ -39,10 +39,57 @@ static gpr_atm encode_cancel_state_error(grpc_error* error) { return static_cast<gpr_atm>(1) | (gpr_atm)error; } +#ifdef GRPC_TSAN_ENABLED +static void tsan_closure(void* user_data, grpc_error* error) { + grpc_call_combiner* call_combiner = + static_cast<grpc_call_combiner*>(user_data); + // We ref-count the lock, and check if it's already taken. + // If it was taken, we should do nothing. Otherwise, we will mark it as + // locked. Note that if two different threads try to do this, only one of + // them will be able to mark the lock as acquired, while they both run their + // callbacks. In such cases (which should never happen for call_combiner), + // TSAN will correctly produce an error. + // + // TODO(soheil): This only covers the callbacks scheduled by + // grpc_call_combiner_(start|finish). If in the future, a + // callback gets scheduled using other mechanisms, we will need + // to add APIs to externally lock call combiners. + grpc_core::RefCountedPtr<grpc_call_combiner::TsanLock> lock = + call_combiner->tsan_lock; + bool prev = false; + if (lock->taken.compare_exchange_strong(prev, true)) { + TSAN_ANNOTATE_RWLOCK_ACQUIRED(&lock->taken, true); + } else { + lock.reset(); + } + GRPC_CLOSURE_RUN(call_combiner->original_closure, GRPC_ERROR_REF(error)); + if (lock != nullptr) { + TSAN_ANNOTATE_RWLOCK_RELEASED(&lock->taken, true); + bool prev = true; + GPR_ASSERT(lock->taken.compare_exchange_strong(prev, false)); + } +} +#endif + +static void call_combiner_sched_closure(grpc_call_combiner* call_combiner, + grpc_closure* closure, + grpc_error* error) { +#ifdef GRPC_TSAN_ENABLED + call_combiner->original_closure = closure; + GRPC_CLOSURE_SCHED(&call_combiner->tsan_closure, error); +#else + GRPC_CLOSURE_SCHED(closure, error); +#endif +} + void grpc_call_combiner_init(grpc_call_combiner* call_combiner) { gpr_atm_no_barrier_store(&call_combiner->cancel_state, 0); gpr_atm_no_barrier_store(&call_combiner->size, 0); gpr_mpscq_init(&call_combiner->queue); +#ifdef GRPC_TSAN_ENABLED + GRPC_CLOSURE_INIT(&call_combiner->tsan_closure, tsan_closure, call_combiner, + grpc_schedule_on_exec_ctx); +#endif } void grpc_call_combiner_destroy(grpc_call_combiner* call_combiner) { @@ -87,7 +134,7 @@ void grpc_call_combiner_start(grpc_call_combiner* call_combiner, gpr_log(GPR_INFO, " EXECUTING IMMEDIATELY"); } // Queue was empty, so execute this closure immediately. - GRPC_CLOSURE_SCHED(closure, error); + call_combiner_sched_closure(call_combiner, closure, error); } else { if (grpc_call_combiner_trace.enabled()) { gpr_log(GPR_INFO, " QUEUING"); @@ -134,7 +181,8 @@ void grpc_call_combiner_stop(grpc_call_combiner* call_combiner DEBUG_ARGS, gpr_log(GPR_INFO, " EXECUTING FROM QUEUE: closure=%p error=%s", closure, grpc_error_string(closure->error_data.error)); } - GRPC_CLOSURE_SCHED(closure, closure->error_data.error); + call_combiner_sched_closure(call_combiner, closure, + closure->error_data.error); break; } } else if (grpc_call_combiner_trace.enabled()) { diff --git a/src/core/lib/iomgr/call_combiner.h b/src/core/lib/iomgr/call_combiner.h index c943fb1557..4ec0044f05 100644 --- a/src/core/lib/iomgr/call_combiner.h +++ b/src/core/lib/iomgr/call_combiner.h @@ -27,7 +27,10 @@ #include "src/core/lib/gpr/mpscq.h" #include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/closure.h" +#include "src/core/lib/iomgr/dynamic_annotations.h" // A simple, lock-free mechanism for serializing activity related to a // single call. This is similar to a combiner but is more lightweight. @@ -40,14 +43,38 @@ extern grpc_core::TraceFlag grpc_call_combiner_trace; -typedef struct { +struct grpc_call_combiner { gpr_atm size = 0; // size_t, num closures in queue or currently executing gpr_mpscq queue; // Either 0 (if not cancelled and no cancellation closure set), // a grpc_closure* (if the lowest bit is 0), // or a grpc_error* (if the lowest bit is 1). gpr_atm cancel_state = 0; -} grpc_call_combiner; +#ifdef GRPC_TSAN_ENABLED + // A fake ref-counted lock that is kept alive after the destruction of + // grpc_call_combiner, when we are running the original closure. + // + // Ideally we want to lock and unlock the call combiner as a pointer, when the + // callback is called. However, original_closure is free to trigger + // anything on the call combiner (including destruction of grpc_call). + // Thus, we need a ref-counted structure that can outlive the call combiner. + struct TsanLock + : public grpc_core::RefCounted<TsanLock, + grpc_core::NonPolymorphicRefCount> { + TsanLock() { TSAN_ANNOTATE_RWLOCK_CREATE(&taken); } + ~TsanLock() { TSAN_ANNOTATE_RWLOCK_DESTROY(&taken); } + + // To avoid double-locking by the same thread, we should acquire/release + // the lock only when taken is false. On each acquire taken must be set to + // true. + std::atomic<bool> taken{false}; + }; + grpc_core::RefCountedPtr<TsanLock> tsan_lock = + grpc_core::MakeRefCounted<TsanLock>(); + grpc_closure tsan_closure; + grpc_closure* original_closure; +#endif +}; // Assumes memory was initialized to zero. void grpc_call_combiner_init(grpc_call_combiner* call_combiner); diff --git a/src/core/lib/iomgr/dynamic_annotations.h b/src/core/lib/iomgr/dynamic_annotations.h new file mode 100644 index 0000000000..713928023a --- /dev/null +++ b/src/core/lib/iomgr/dynamic_annotations.h @@ -0,0 +1,67 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_LIB_IOMGR_DYNAMIC_ANNOTATIONS_H +#define GRPC_CORE_LIB_IOMGR_DYNAMIC_ANNOTATIONS_H + +#include <grpc/support/port_platform.h> + +#ifdef GRPC_TSAN_ENABLED + +#define TSAN_ANNOTATE_HAPPENS_BEFORE(addr) \ + AnnotateHappensBefore(__FILE__, __LINE__, (void*)(addr)) +#define TSAN_ANNOTATE_HAPPENS_AFTER(addr) \ + AnnotateHappensAfter(__FILE__, __LINE__, (void*)(addr)) +#define TSAN_ANNOTATE_RWLOCK_CREATE(addr) \ + AnnotateRWLockCreate(__FILE__, __LINE__, (void*)(addr)) +#define TSAN_ANNOTATE_RWLOCK_DESTROY(addr) \ + AnnotateRWLockDestroy(__FILE__, __LINE__, (void*)(addr)) +#define TSAN_ANNOTATE_RWLOCK_ACQUIRED(addr, is_w) \ + AnnotateRWLockAcquired(__FILE__, __LINE__, (void*)(addr), (is_w)) +#define TSAN_ANNOTATE_RWLOCK_RELEASED(addr, is_w) \ + AnnotateRWLockReleased(__FILE__, __LINE__, (void*)(addr), (is_w)) + +#ifdef __cplusplus +extern "C" { +#endif +void AnnotateHappensBefore(const char* file, int line, const volatile void* cv); +void AnnotateHappensAfter(const char* file, int line, const volatile void* cv); +void AnnotateRWLockCreate(const char* file, int line, + const volatile void* lock); +void AnnotateRWLockDestroy(const char* file, int line, + const volatile void* lock); +void AnnotateRWLockAcquired(const char* file, int line, + const volatile void* lock, long is_w); +void AnnotateRWLockReleased(const char* file, int line, + const volatile void* lock, long is_w); +#ifdef __cplusplus +} +#endif + +#else /* GRPC_TSAN_ENABLED */ + +#define TSAN_ANNOTATE_HAPPENS_BEFORE(addr) +#define TSAN_ANNOTATE_HAPPENS_AFTER(addr) +#define TSAN_ANNOTATE_RWLOCK_CREATE(addr) +#define TSAN_ANNOTATE_RWLOCK_DESTROY(addr) +#define TSAN_ANNOTATE_RWLOCK_ACQUIRED(addr, is_w) +#define TSAN_ANNOTATE_RWLOCK_RELEASED(addr, is_w) + +#endif /* GRPC_TSAN_ENABLED */ + +#endif /* GRPC_CORE_LIB_IOMGR_DYNAMIC_ANNOTATIONS_H */ diff --git a/src/core/lib/iomgr/endpoint.cc b/src/core/lib/iomgr/endpoint.cc index 44fb47e19d..06316c6031 100644 --- a/src/core/lib/iomgr/endpoint.cc +++ b/src/core/lib/iomgr/endpoint.cc @@ -61,3 +61,7 @@ int grpc_endpoint_get_fd(grpc_endpoint* ep) { return ep->vtable->get_fd(ep); } grpc_resource_user* grpc_endpoint_get_resource_user(grpc_endpoint* ep) { return ep->vtable->get_resource_user(ep); } + +bool grpc_endpoint_can_track_err(grpc_endpoint* ep) { + return ep->vtable->can_track_err(ep); +} diff --git a/src/core/lib/iomgr/endpoint.h b/src/core/lib/iomgr/endpoint.h index 1f590a80ca..79c8ece263 100644 --- a/src/core/lib/iomgr/endpoint.h +++ b/src/core/lib/iomgr/endpoint.h @@ -47,6 +47,7 @@ struct grpc_endpoint_vtable { grpc_resource_user* (*get_resource_user)(grpc_endpoint* ep); char* (*get_peer)(grpc_endpoint* ep); int (*get_fd)(grpc_endpoint* ep); + bool (*can_track_err)(grpc_endpoint* ep); }; /* When data is available on the connection, calls the callback with slices. @@ -95,6 +96,8 @@ void grpc_endpoint_delete_from_pollset_set(grpc_endpoint* ep, grpc_resource_user* grpc_endpoint_get_resource_user(grpc_endpoint* endpoint); +bool grpc_endpoint_can_track_err(grpc_endpoint* ep); + struct grpc_endpoint { const grpc_endpoint_vtable* vtable; }; diff --git a/src/core/lib/iomgr/endpoint_cfstream.cc b/src/core/lib/iomgr/endpoint_cfstream.cc index df2cf508c8..7c4bc1ace2 100644 --- a/src/core/lib/iomgr/endpoint_cfstream.cc +++ b/src/core/lib/iomgr/endpoint_cfstream.cc @@ -315,6 +315,8 @@ char* CFStreamGetPeer(grpc_endpoint* ep) { int CFStreamGetFD(grpc_endpoint* ep) { return 0; } +bool CFStreamCanTrackErr(grpc_endpoint* ep) { return false; } + void CFStreamAddToPollset(grpc_endpoint* ep, grpc_pollset* pollset) {} void CFStreamAddToPollsetSet(grpc_endpoint* ep, grpc_pollset_set* pollset) {} void CFStreamDeleteFromPollsetSet(grpc_endpoint* ep, @@ -329,7 +331,8 @@ static const grpc_endpoint_vtable vtable = {CFStreamRead, CFStreamDestroy, CFStreamGetResourceUser, CFStreamGetPeer, - CFStreamGetFD}; + CFStreamGetFD, + CFStreamCanTrackErr}; grpc_endpoint* grpc_cfstream_endpoint_create( CFReadStreamRef read_stream, CFWriteStreamRef write_stream, diff --git a/src/core/lib/iomgr/endpoint_pair_posix.cc b/src/core/lib/iomgr/endpoint_pair_posix.cc index 3afbfd7254..5c5c246f99 100644 --- a/src/core/lib/iomgr/endpoint_pair_posix.cc +++ b/src/core/lib/iomgr/endpoint_pair_posix.cc @@ -59,11 +59,11 @@ grpc_endpoint_pair grpc_iomgr_create_endpoint_pair(const char* name, grpc_core::ExecCtx exec_ctx; gpr_asprintf(&final_name, "%s:client", name); - p.client = grpc_tcp_create(grpc_fd_create(sv[1], final_name, true), args, + p.client = grpc_tcp_create(grpc_fd_create(sv[1], final_name, false), args, "socketpair-server"); gpr_free(final_name); gpr_asprintf(&final_name, "%s:server", name); - p.server = grpc_tcp_create(grpc_fd_create(sv[0], final_name, true), args, + p.server = grpc_tcp_create(grpc_fd_create(sv[0], final_name, false), args, "socketpair-client"); gpr_free(final_name); diff --git a/src/core/lib/iomgr/ev_epoll1_linux.cc b/src/core/lib/iomgr/ev_epoll1_linux.cc index 38571b1957..4b8c891e9b 100644 --- a/src/core/lib/iomgr/ev_epoll1_linux.cc +++ b/src/core/lib/iomgr/ev_epoll1_linux.cc @@ -1242,6 +1242,8 @@ static void pollset_set_del_pollset_set(grpc_pollset_set* bag, * Event engine binding */ +static void shutdown_background_closure(void) {} + static void shutdown_engine(void) { fd_global_shutdown(); pollset_global_shutdown(); @@ -1255,6 +1257,7 @@ static void shutdown_engine(void) { static const grpc_event_engine_vtable vtable = { sizeof(grpc_pollset), true, + false, fd_create, fd_wrapped_fd, @@ -1284,6 +1287,7 @@ static const grpc_event_engine_vtable vtable = { pollset_set_add_fd, pollset_set_del_fd, + shutdown_background_closure, shutdown_engine, }; diff --git a/src/core/lib/iomgr/ev_epollex_linux.cc b/src/core/lib/iomgr/ev_epollex_linux.cc index 06a382c556..7a4870db78 100644 --- a/src/core/lib/iomgr/ev_epollex_linux.cc +++ b/src/core/lib/iomgr/ev_epollex_linux.cc @@ -1604,6 +1604,8 @@ static void pollset_set_del_pollset_set(grpc_pollset_set* bag, * Event engine binding */ +static void shutdown_background_closure(void) {} + static void shutdown_engine(void) { fd_global_shutdown(); pollset_global_shutdown(); @@ -1612,6 +1614,7 @@ static void shutdown_engine(void) { static const grpc_event_engine_vtable vtable = { sizeof(grpc_pollset), true, + false, fd_create, fd_wrapped_fd, @@ -1641,6 +1644,7 @@ static const grpc_event_engine_vtable vtable = { pollset_set_add_fd, pollset_set_del_fd, + shutdown_background_closure, shutdown_engine, }; diff --git a/src/core/lib/iomgr/ev_poll_posix.cc b/src/core/lib/iomgr/ev_poll_posix.cc index 16562538a6..67cbfbbd02 100644 --- a/src/core/lib/iomgr/ev_poll_posix.cc +++ b/src/core/lib/iomgr/ev_poll_posix.cc @@ -1782,6 +1782,8 @@ static void global_cv_fd_table_shutdown() { * event engine binding */ +static void shutdown_background_closure(void) {} + static void shutdown_engine(void) { pollset_global_shutdown(); if (grpc_cv_wakeup_fds_enabled()) { @@ -1796,6 +1798,7 @@ static void shutdown_engine(void) { static const grpc_event_engine_vtable vtable = { sizeof(grpc_pollset), false, + false, fd_create, fd_wrapped_fd, @@ -1825,6 +1828,7 @@ static const grpc_event_engine_vtable vtable = { pollset_set_add_fd, pollset_set_del_fd, + shutdown_background_closure, shutdown_engine, }; diff --git a/src/core/lib/iomgr/ev_posix.cc b/src/core/lib/iomgr/ev_posix.cc index 8a7dc7b004..32d1b6c43e 100644 --- a/src/core/lib/iomgr/ev_posix.cc +++ b/src/core/lib/iomgr/ev_posix.cc @@ -36,6 +36,7 @@ #include "src/core/lib/iomgr/ev_epoll1_linux.h" #include "src/core/lib/iomgr/ev_epollex_linux.h" #include "src/core/lib/iomgr/ev_poll_posix.h" +#include "src/core/lib/iomgr/internal_errqueue.h" grpc_core::TraceFlag grpc_polling_trace(false, "polling"); /* Disabled by default */ @@ -236,19 +237,22 @@ void grpc_event_engine_shutdown(void) { } bool grpc_event_engine_can_track_errors(void) { -/* Only track errors if platform supports errqueue. */ -#ifdef GRPC_LINUX_ERRQUEUE - return g_event_engine->can_track_err; -#else + /* Only track errors if platform supports errqueue. */ + if (grpc_core::kernel_supports_errqueue()) { + return g_event_engine->can_track_err; + } return false; -#endif /* GRPC_LINUX_ERRQUEUE */ +} + +bool grpc_event_engine_run_in_background(void) { + return g_event_engine->run_in_background; } grpc_fd* grpc_fd_create(int fd, const char* name, bool track_err) { GRPC_POLLING_API_TRACE("fd_create(%d, %s, %d)", fd, name, track_err); GRPC_FD_TRACE("fd_create(%d, %s, %d)", fd, name, track_err); - return g_event_engine->fd_create(fd, name, - track_err && g_event_engine->can_track_err); + return g_event_engine->fd_create( + fd, name, track_err && grpc_event_engine_can_track_errors()); } int grpc_fd_wrapped_fd(grpc_fd* fd) { @@ -395,4 +399,8 @@ void grpc_pollset_set_del_fd(grpc_pollset_set* pollset_set, grpc_fd* fd) { g_event_engine->pollset_set_del_fd(pollset_set, fd); } +void grpc_shutdown_background_closure(void) { + g_event_engine->shutdown_background_closure(); +} + #endif // GRPC_POSIX_SOCKET_EV diff --git a/src/core/lib/iomgr/ev_posix.h b/src/core/lib/iomgr/ev_posix.h index b8fb8f534b..812c7a0f0f 100644 --- a/src/core/lib/iomgr/ev_posix.h +++ b/src/core/lib/iomgr/ev_posix.h @@ -42,6 +42,7 @@ typedef struct grpc_fd grpc_fd; typedef struct grpc_event_engine_vtable { size_t pollset_size; bool can_track_err; + bool run_in_background; grpc_fd* (*fd_create)(int fd, const char* name, bool track_err); int (*fd_wrapped_fd)(grpc_fd* fd); @@ -79,6 +80,7 @@ typedef struct grpc_event_engine_vtable { void (*pollset_set_add_fd)(grpc_pollset_set* pollset_set, grpc_fd* fd); void (*pollset_set_del_fd)(grpc_pollset_set* pollset_set, grpc_fd* fd); + void (*shutdown_background_closure)(void); void (*shutdown_engine)(void); } grpc_event_engine_vtable; @@ -101,6 +103,11 @@ const char* grpc_get_poll_strategy_name(); */ bool grpc_event_engine_can_track_errors(); +/* Returns true if polling engine runs in the background, false otherwise. + * Currently only 'epollbg' runs in the background. + */ +bool grpc_event_engine_run_in_background(); + /* Create a wrapped file descriptor. Requires fd is a non-blocking file descriptor. \a track_err if true means that error events would be tracked separately @@ -174,6 +181,9 @@ void grpc_pollset_add_fd(grpc_pollset* pollset, struct grpc_fd* fd); void grpc_pollset_set_add_fd(grpc_pollset_set* pollset_set, grpc_fd* fd); void grpc_pollset_set_del_fd(grpc_pollset_set* pollset_set, grpc_fd* fd); +/* Shut down all the closures registered in the background poller. */ +void grpc_shutdown_background_closure(); + /* override to allow tests to hook poll() usage */ typedef int (*grpc_poll_function_type)(struct pollfd*, nfds_t, int); extern grpc_poll_function_type grpc_poll_function; diff --git a/src/core/lib/iomgr/exec_ctx.cc b/src/core/lib/iomgr/exec_ctx.cc index d68fa0714b..683dd2f649 100644 --- a/src/core/lib/iomgr/exec_ctx.cc +++ b/src/core/lib/iomgr/exec_ctx.cc @@ -53,6 +53,13 @@ static void exec_ctx_sched(grpc_closure* closure, grpc_error* error) { static gpr_timespec g_start_time; +// For debug of the timer manager crash only. +// TODO (mxyan): remove after bug is fixed. +#ifdef GRPC_DEBUG_TIMER_MANAGER +extern int64_t g_start_time_sec; +extern int64_t g_start_time_nsec; +#endif // GRPC_DEBUG_TIMER_MANAGER + static grpc_millis timespec_to_millis_round_down(gpr_timespec ts) { ts = gpr_time_sub(ts, g_start_time); double x = GPR_MS_PER_SEC * static_cast<double>(ts.tv_sec) + @@ -117,6 +124,12 @@ void ExecCtx::TestOnlyGlobalInit(gpr_timespec new_val) { void ExecCtx::GlobalInit(void) { g_start_time = gpr_now(GPR_CLOCK_MONOTONIC); + // For debug of the timer manager crash only. + // TODO (mxyan): remove after bug is fixed. +#ifdef GRPC_DEBUG_TIMER_MANAGER + g_start_time_sec = g_start_time.tv_sec; + g_start_time_nsec = g_start_time.tv_nsec; +#endif gpr_tls_init(&exec_ctx_); } diff --git a/src/core/lib/iomgr/fork_posix.cc b/src/core/lib/iomgr/fork_posix.cc index e957bad73d..05ecd2a49b 100644 --- a/src/core/lib/iomgr/fork_posix.cc +++ b/src/core/lib/iomgr/fork_posix.cc @@ -60,7 +60,7 @@ void grpc_prefork() { } if (strcmp(grpc_get_poll_strategy_name(), "epoll1") != 0 && strcmp(grpc_get_poll_strategy_name(), "poll") != 0) { - gpr_log(GPR_ERROR, + gpr_log(GPR_INFO, "Fork support is only compatible with the epoll1 and poll polling " "strategies"); } diff --git a/src/core/lib/iomgr/internal_errqueue.cc b/src/core/lib/iomgr/internal_errqueue.cc index 99c22e9055..982d709f09 100644 --- a/src/core/lib/iomgr/internal_errqueue.cc +++ b/src/core/lib/iomgr/internal_errqueue.cc @@ -20,17 +20,50 @@ #include "src/core/lib/iomgr/port.h" +#include <grpc/impl/codegen/log.h> #include "src/core/lib/iomgr/internal_errqueue.h" #ifdef GRPC_POSIX_SOCKET_TCP -bool kernel_supports_errqueue() { +#include <errno.h> +#include <stdlib.h> +#include <string.h> +#include <sys/utsname.h> + +namespace grpc_core { +static bool errqueue_supported = false; + +bool kernel_supports_errqueue() { return errqueue_supported; } + +void grpc_errqueue_init() { +/* Both-compile time and run-time linux kernel versions should be atleast 4.0.0 + */ #ifdef LINUX_VERSION_CODE #if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 0, 0) - return true; + struct utsname buffer; + if (uname(&buffer) != 0) { + gpr_log(GPR_ERROR, "uname: %s", strerror(errno)); + return; + } + char* release = buffer.release; + if (release == nullptr) { + return; + } + + if (strtol(release, nullptr, 10) >= 4) { + errqueue_supported = true; + } else { + gpr_log(GPR_DEBUG, "ERRQUEUE support not enabled"); + } #endif /* LINUX_VERSION_CODE <= KERNEL_VERSION(4, 0, 0) */ #endif /* LINUX_VERSION_CODE */ - return false; } +} /* namespace grpc_core */ + +#else + +namespace grpc_core { +void grpc_errqueue_init() {} +} /* namespace grpc_core */ #endif /* GRPC_POSIX_SOCKET_TCP */ diff --git a/src/core/lib/iomgr/internal_errqueue.h b/src/core/lib/iomgr/internal_errqueue.h index 9d122808f9..f8644c2536 100644 --- a/src/core/lib/iomgr/internal_errqueue.h +++ b/src/core/lib/iomgr/internal_errqueue.h @@ -76,8 +76,14 @@ constexpr uint32_t kTimestampingRecordingOptions = * Currently allowing only linux kernels above 4.0.0 */ bool kernel_supports_errqueue(); -} // namespace grpc_core + +} /* namespace grpc_core */ #endif /* GRPC_POSIX_SOCKET_TCP */ +namespace grpc_core { +/* Initializes errqueue support */ +void grpc_errqueue_init(); +} /* namespace grpc_core */ + #endif /* GRPC_CORE_LIB_IOMGR_INTERNAL_ERRQUEUE_H */ diff --git a/src/core/lib/iomgr/iomgr.cc b/src/core/lib/iomgr/iomgr.cc index 46afda1774..eb29973514 100644 --- a/src/core/lib/iomgr/iomgr.cc +++ b/src/core/lib/iomgr/iomgr.cc @@ -33,8 +33,10 @@ #include "src/core/lib/gpr/string.h" #include "src/core/lib/gpr/useful.h" #include "src/core/lib/gprpp/thd.h" +#include "src/core/lib/iomgr/buffer_list.h" #include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/iomgr/executor.h" +#include "src/core/lib/iomgr/internal_errqueue.h" #include "src/core/lib/iomgr/iomgr_internal.h" #include "src/core/lib/iomgr/network_status_tracker.h" #include "src/core/lib/iomgr/timer.h" @@ -57,6 +59,7 @@ void grpc_iomgr_init() { g_root_object.name = (char*)"root"; grpc_network_status_init(); grpc_iomgr_platform_init(); + grpc_core::grpc_errqueue_init(); } void grpc_iomgr_start() { grpc_timer_manager_init(); } @@ -154,6 +157,10 @@ void grpc_iomgr_shutdown() { gpr_cv_destroy(&g_rcv); } +void grpc_iomgr_shutdown_background_closure() { + grpc_iomgr_platform_shutdown_background_closure(); +} + void grpc_iomgr_register_object(grpc_iomgr_object* obj, const char* name) { obj->name = gpr_strdup(name); gpr_mu_lock(&g_mu); diff --git a/src/core/lib/iomgr/iomgr.h b/src/core/lib/iomgr/iomgr.h index 537ef8a6ff..8ea9289e06 100644 --- a/src/core/lib/iomgr/iomgr.h +++ b/src/core/lib/iomgr/iomgr.h @@ -35,6 +35,10 @@ void grpc_iomgr_start(); * exec_ctx. */ void grpc_iomgr_shutdown(); +/** Signals the intention to shutdown all the closures registered in the + * background poller. */ +void grpc_iomgr_shutdown_background_closure(); + /* Exposed only for testing */ size_t grpc_iomgr_count_objects_for_testing(); diff --git a/src/core/lib/iomgr/iomgr_custom.cc b/src/core/lib/iomgr/iomgr_custom.cc index d34c8e7cd1..4b112c9097 100644 --- a/src/core/lib/iomgr/iomgr_custom.cc +++ b/src/core/lib/iomgr/iomgr_custom.cc @@ -40,9 +40,11 @@ static void iomgr_platform_init(void) { } static void iomgr_platform_flush(void) {} static void iomgr_platform_shutdown(void) { grpc_pollset_global_shutdown(); } +static void iomgr_platform_shutdown_background_closure(void) {} static grpc_iomgr_platform_vtable vtable = { - iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown}; + iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure}; void grpc_custom_iomgr_init(grpc_socket_vtable* socket, grpc_custom_resolver_vtable* resolver, diff --git a/src/core/lib/iomgr/iomgr_internal.cc b/src/core/lib/iomgr/iomgr_internal.cc index 32dbabb79d..b6c9211865 100644 --- a/src/core/lib/iomgr/iomgr_internal.cc +++ b/src/core/lib/iomgr/iomgr_internal.cc @@ -41,3 +41,7 @@ void grpc_iomgr_platform_init() { iomgr_platform_vtable->init(); } void grpc_iomgr_platform_flush() { iomgr_platform_vtable->flush(); } void grpc_iomgr_platform_shutdown() { iomgr_platform_vtable->shutdown(); } + +void grpc_iomgr_platform_shutdown_background_closure() { + iomgr_platform_vtable->shutdown_background_closure(); +} diff --git a/src/core/lib/iomgr/iomgr_internal.h b/src/core/lib/iomgr/iomgr_internal.h index b011d9c7b1..bca7409907 100644 --- a/src/core/lib/iomgr/iomgr_internal.h +++ b/src/core/lib/iomgr/iomgr_internal.h @@ -35,6 +35,7 @@ typedef struct grpc_iomgr_platform_vtable { void (*init)(void); void (*flush)(void); void (*shutdown)(void); + void (*shutdown_background_closure)(void); } grpc_iomgr_platform_vtable; void grpc_iomgr_register_object(grpc_iomgr_object* obj, const char* name); @@ -52,6 +53,9 @@ void grpc_iomgr_platform_flush(void); /** tear down all platform specific global iomgr structures */ void grpc_iomgr_platform_shutdown(void); +/** shut down all the closures registered in the background poller */ +void grpc_iomgr_platform_shutdown_background_closure(void); + bool grpc_iomgr_abort_on_leaks(void); #endif /* GRPC_CORE_LIB_IOMGR_IOMGR_INTERNAL_H */ diff --git a/src/core/lib/iomgr/iomgr_posix.cc b/src/core/lib/iomgr/iomgr_posix.cc index ca7334c9a4..9386adf060 100644 --- a/src/core/lib/iomgr/iomgr_posix.cc +++ b/src/core/lib/iomgr/iomgr_posix.cc @@ -51,8 +51,13 @@ static void iomgr_platform_shutdown(void) { grpc_wakeup_fd_global_destroy(); } +static void iomgr_platform_shutdown_background_closure(void) { + grpc_shutdown_background_closure(); +} + static grpc_iomgr_platform_vtable vtable = { - iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown}; + iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure}; void grpc_set_default_iomgr_platform() { grpc_set_tcp_client_impl(&grpc_posix_tcp_client_vtable); diff --git a/src/core/lib/iomgr/iomgr_posix_cfstream.cc b/src/core/lib/iomgr/iomgr_posix_cfstream.cc index 235a9e0712..552ef4309c 100644 --- a/src/core/lib/iomgr/iomgr_posix_cfstream.cc +++ b/src/core/lib/iomgr/iomgr_posix_cfstream.cc @@ -54,8 +54,13 @@ static void iomgr_platform_shutdown(void) { grpc_wakeup_fd_global_destroy(); } +static void iomgr_platform_shutdown_background_closure(void) { + grpc_shutdown_background_closure(); +} + static grpc_iomgr_platform_vtable vtable = { - iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown}; + iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure}; void grpc_set_default_iomgr_platform() { char* enable_cfstream = getenv(grpc_cfstream_env_var); diff --git a/src/core/lib/iomgr/iomgr_windows.cc b/src/core/lib/iomgr/iomgr_windows.cc index cdef89cbf0..24ef0dba7b 100644 --- a/src/core/lib/iomgr/iomgr_windows.cc +++ b/src/core/lib/iomgr/iomgr_windows.cc @@ -71,8 +71,11 @@ static void iomgr_platform_shutdown(void) { winsock_shutdown(); } +static void iomgr_platform_shutdown_background_closure(void) {} + static grpc_iomgr_platform_vtable vtable = { - iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown}; + iomgr_platform_init, iomgr_platform_flush, iomgr_platform_shutdown, + iomgr_platform_shutdown_background_closure}; void grpc_set_default_iomgr_platform() { grpc_set_tcp_client_impl(&grpc_windows_tcp_client_vtable); diff --git a/src/core/lib/iomgr/port.h b/src/core/lib/iomgr/port.h index bf56a7298d..c8046b21dc 100644 --- a/src/core/lib/iomgr/port.h +++ b/src/core/lib/iomgr/port.h @@ -62,8 +62,7 @@ #define GRPC_HAVE_UNIX_SOCKET 1 #ifdef LINUX_VERSION_CODE #if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 0, 0) -/* TODO(yashykt): Re-enable once Fathom changes are commited. -#define GRPC_LINUX_ERRQUEUE 1 */ +#define GRPC_LINUX_ERRQUEUE 1 #endif /* LINUX_VERSION_CODE >= KERNEL_VERSION(4, 0, 0) */ #endif /* LINUX_VERSION_CODE */ #define GRPC_LINUX_MULTIPOLL_WITH_EPOLL 1 diff --git a/src/core/lib/iomgr/resolve_address.h b/src/core/lib/iomgr/resolve_address.h index 6afe94a7a9..7016ffc31a 100644 --- a/src/core/lib/iomgr/resolve_address.h +++ b/src/core/lib/iomgr/resolve_address.h @@ -65,7 +65,7 @@ void grpc_set_resolver_impl(grpc_address_resolver_vtable* vtable); /* Asynchronously resolve addr. Use default_port if a port isn't designated in addr, otherwise use the port in addr. */ -/* TODO(ctiller): add a timeout here */ +/* TODO(apolcyn): add a timeout here */ void grpc_resolve_address(const char* addr, const char* default_port, grpc_pollset_set* interested_parties, grpc_closure* on_done, diff --git a/src/core/lib/iomgr/sockaddr_utils.cc b/src/core/lib/iomgr/sockaddr_utils.cc index 1b66dceb13..0839bdfef2 100644 --- a/src/core/lib/iomgr/sockaddr_utils.cc +++ b/src/core/lib/iomgr/sockaddr_utils.cc @@ -217,6 +217,7 @@ void grpc_string_to_sockaddr(grpc_resolved_address* out, char* addr, int port) { } char* grpc_sockaddr_to_uri(const grpc_resolved_address* resolved_addr) { + if (resolved_addr->len == 0) return nullptr; grpc_resolved_address addr_normalized; if (grpc_sockaddr_is_v4mapped(resolved_addr, &addr_normalized)) { resolved_addr = &addr_normalized; diff --git a/src/core/lib/iomgr/tcp_client_posix.cc b/src/core/lib/iomgr/tcp_client_posix.cc index 8553ed0db4..0bff74e88b 100644 --- a/src/core/lib/iomgr/tcp_client_posix.cc +++ b/src/core/lib/iomgr/tcp_client_posix.cc @@ -76,6 +76,8 @@ static grpc_error* prepare_socket(const grpc_resolved_address* addr, int fd, if (!grpc_is_unix_socket(addr)) { err = grpc_set_socket_low_latency(fd, 1); if (err != GRPC_ERROR_NONE) goto error; + err = grpc_set_socket_reuse_addr(fd, 1); + if (err != GRPC_ERROR_NONE) goto error; err = grpc_set_socket_tcp_user_timeout(fd, channel_args, true /* is_client */); if (err != GRPC_ERROR_NONE) goto error; diff --git a/src/core/lib/iomgr/tcp_custom.cc b/src/core/lib/iomgr/tcp_custom.cc index e02a1898f2..f7a5f36cdc 100644 --- a/src/core/lib/iomgr/tcp_custom.cc +++ b/src/core/lib/iomgr/tcp_custom.cc @@ -326,6 +326,8 @@ static grpc_resource_user* endpoint_get_resource_user(grpc_endpoint* ep) { static int endpoint_get_fd(grpc_endpoint* ep) { return -1; } +static bool endpoint_can_track_err(grpc_endpoint* ep) { return false; } + static grpc_endpoint_vtable vtable = {endpoint_read, endpoint_write, endpoint_add_to_pollset, @@ -335,7 +337,8 @@ static grpc_endpoint_vtable vtable = {endpoint_read, endpoint_destroy, endpoint_get_resource_user, endpoint_get_peer, - endpoint_get_fd}; + endpoint_get_fd, + endpoint_can_track_err}; grpc_endpoint* custom_tcp_endpoint_create(grpc_custom_socket* socket, grpc_resource_quota* resource_quota, diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc index aa2704ce26..c268c18664 100644 --- a/src/core/lib/iomgr/tcp_posix.cc +++ b/src/core/lib/iomgr/tcp_posix.cc @@ -126,6 +126,7 @@ struct grpc_tcp { int bytes_counter; bool socket_ts_enabled; /* True if timestamping options are set on the socket */ + bool ts_capable; /* Cache whether we can set timestamping options */ gpr_atm stop_error_notification; /* Set to 1 if we do not want to be notified on errors anymore */ @@ -260,10 +261,17 @@ static void notify_on_write(grpc_tcp* tcp) { if (grpc_tcp_trace.enabled()) { gpr_log(GPR_INFO, "TCP:%p notify_on_write", tcp); } - cover_self(tcp); - GRPC_CLOSURE_INIT(&tcp->write_done_closure, - tcp_drop_uncovered_then_handle_write, tcp, - grpc_schedule_on_exec_ctx); + if (grpc_event_engine_run_in_background()) { + // If there is a polling engine always running in the background, there is + // no need to run the backup poller. + GRPC_CLOSURE_INIT(&tcp->write_done_closure, tcp_handle_write, tcp, + grpc_schedule_on_exec_ctx); + } else { + cover_self(tcp); + GRPC_CLOSURE_INIT(&tcp->write_done_closure, + tcp_drop_uncovered_then_handle_write, tcp, + grpc_schedule_on_exec_ctx); + } grpc_fd_notify_on_write(tcp->em_fd, &tcp->write_done_closure); } @@ -384,6 +392,12 @@ static void tcp_destroy(grpc_endpoint* ep) { grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep); grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); if (grpc_event_engine_can_track_errors()) { + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::Shutdown( + &tcp->tb_head, tcp->outgoing_buffer_arg, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed")); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); grpc_fd_set_error(tcp->em_fd); } @@ -576,7 +590,7 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) { */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, grpc_error** error); + ssize_t* sent_length); /** The callback function to be invoked when we get an error on the socket. */ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); @@ -584,13 +598,11 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error); #ifdef GRPC_LINUX_ERRQUEUE static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { if (!tcp->socket_ts_enabled) { uint32_t opt = grpc_core::kTimestampingSocketOptions; if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING, static_cast<void*>(&opt), sizeof(opt)) != 0) { - *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "setsockopt"), tcp); grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); if (grpc_tcp_trace.enabled()) { gpr_log(GPR_ERROR, "Failed to set timestamping options on the socket."); @@ -621,7 +633,7 @@ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, if (sending_length == static_cast<size_t>(length)) { gpr_mu_lock(&tcp->tb_mu); grpc_core::TracedBuffer::AddNewEntry( - &tcp->tb_head, static_cast<int>(tcp->bytes_counter + length), + &tcp->tb_head, static_cast<uint32_t>(tcp->bytes_counter + length), tcp->outgoing_buffer_arg); gpr_mu_unlock(&tcp->tb_mu); tcp->outgoing_buffer_arg = nullptr; @@ -673,11 +685,9 @@ struct cmsghdr* process_timestamp(grpc_tcp* tcp, msghdr* msg, } /** For linux platforms, reads the socket's error queue and processes error - * messages from the queue. Returns true if all the errors processed were - * timestamps. Returns false if any of the errors were not timestamps. For - * non-linux platforms, error processing is not used/enabled currently. + * messages from the queue. */ -static bool process_errors(grpc_tcp* tcp) { +static void process_errors(grpc_tcp* tcp) { while (true) { struct iovec iov; iov.iov_base = nullptr; @@ -706,10 +716,10 @@ static bool process_errors(grpc_tcp* tcp) { } while (r < 0 && saved_errno == EINTR); if (r == -1 && saved_errno == EAGAIN) { - return true; /* No more errors to process */ + return; /* No more errors to process */ } if (r == -1) { - return false; + return; } if (grpc_tcp_trace.enabled()) { if ((msg.msg_flags & MSG_CTRUNC) == 1) { @@ -719,8 +729,9 @@ static bool process_errors(grpc_tcp* tcp) { if (msg.msg_controllen == 0) { /* There was no control message found. It was probably spurious. */ - return true; + return; } + bool seen = false; for (auto cmsg = CMSG_FIRSTHDR(&msg); cmsg && cmsg->cmsg_len; cmsg = CMSG_NXTHDR(&msg, cmsg)) { if (cmsg->cmsg_level != SOL_SOCKET || @@ -732,9 +743,13 @@ static bool process_errors(grpc_tcp* tcp) { "unknown control message cmsg_level:%d cmsg_type:%d", cmsg->cmsg_level, cmsg->cmsg_type); } - return false; + return; } - process_timestamp(tcp, &msg, cmsg); + cmsg = process_timestamp(tcp, &msg, cmsg); + seen = true; + } + if (!seen) { + return; } } } @@ -749,20 +764,17 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { static_cast<bool>(gpr_atm_acq_load(&tcp->stop_error_notification))) { /* We aren't going to register to hear on error anymore, so it is safe to * unref. */ - grpc_core::TracedBuffer::Shutdown(&tcp->tb_head, GRPC_ERROR_REF(error)); TCP_UNREF(tcp, "error-tracking"); return; } /* We are still interested in collecting timestamps, so let's try reading * them. */ - if (!process_errors(tcp)) { - /* This was not a timestamps error. This was an actual error. Set the - * read and write closures to be ready. - */ - grpc_fd_set_readable(tcp->em_fd); - grpc_fd_set_writable(tcp->em_fd); - } + process_errors(tcp); + /* This might not a timestamps error. Set the read and write closures to be + * ready. */ + grpc_fd_set_readable(tcp->em_fd); + grpc_fd_set_writable(tcp->em_fd); GRPC_CLOSURE_INIT(&tcp->error_closure, tcp_handle_error, tcp, grpc_schedule_on_exec_ctx); grpc_fd_notify_on_error(tcp->em_fd, &tcp->error_closure); @@ -771,8 +783,7 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { #else /* GRPC_LINUX_ERRQUEUE */ static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg, size_t sending_length, - ssize_t* sent_length, - grpc_error** error) { + ssize_t* sent_length) { gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform"); GPR_ASSERT(0); return false; @@ -784,6 +795,19 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) { } #endif /* GRPC_LINUX_ERRQUEUE */ +/* If outgoing_buffer_arg is filled, shuts down the list early, so that any + * release operations needed can be performed on the arg */ +void tcp_shutdown_buffer_list(grpc_tcp* tcp) { + if (tcp->outgoing_buffer_arg) { + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::Shutdown( + &tcp->tb_head, tcp->outgoing_buffer_arg, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("TracedBuffer list shutdown")); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; + } +} + /* returns true if done, false if pending; if returning true, *error is set */ #if defined(IOV_MAX) && IOV_MAX < 1000 #define MAX_WRITE_IOVEC IOV_MAX @@ -794,7 +818,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { struct msghdr msg; struct iovec iov[MAX_WRITE_IOVEC]; msg_iovlen_type iov_size; - ssize_t sent_length; + ssize_t sent_length = 0; size_t sending_length; size_t trailing; size_t unwind_slice_idx; @@ -829,11 +853,19 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { msg.msg_iov = iov; msg.msg_iovlen = iov_size; msg.msg_flags = 0; + bool tried_sending_message = false; if (tcp->outgoing_buffer_arg != nullptr) { - if (!tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length, - error)) - return true; /* something went wrong with timestamps */ - } else { + if (!tcp->ts_capable || + !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length)) { + /* We could not set socket options to collect Fathom timestamps. + * Fallback on writing without timestamps. */ + tcp->ts_capable = false; + tcp_shutdown_buffer_list(tcp); + } else { + tried_sending_message = true; + } + } + if (!tried_sending_message) { msg.msg_control = nullptr; msg.msg_controllen = 0; @@ -856,10 +888,12 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) { } else if (errno == EPIPE) { *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); + tcp_shutdown_buffer_list(tcp); return true; } else { *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "sendmsg"), tcp); grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer); + tcp_shutdown_buffer_list(tcp); return true; } } @@ -936,17 +970,18 @@ static void tcp_write(grpc_endpoint* ep, grpc_slice_buffer* buf, GPR_ASSERT(tcp->write_cb == nullptr); + tcp->outgoing_buffer_arg = arg; if (buf->length == 0) { GRPC_CLOSURE_SCHED( cb, grpc_fd_is_shutdown(tcp->em_fd) ? tcp_annotate_error( GRPC_ERROR_CREATE_FROM_STATIC_STRING("EOF"), tcp) : GRPC_ERROR_NONE); + tcp_shutdown_buffer_list(tcp); return; } tcp->outgoing_buffer = buf; tcp->outgoing_byte_idx = 0; - tcp->outgoing_buffer_arg = arg; if (arg) { GPR_ASSERT(grpc_event_engine_can_track_errors()); } @@ -999,6 +1034,22 @@ static grpc_resource_user* tcp_get_resource_user(grpc_endpoint* ep) { return tcp->resource_user; } +static bool tcp_can_track_err(grpc_endpoint* ep) { + grpc_tcp* tcp = reinterpret_cast<grpc_tcp*>(ep); + if (!grpc_event_engine_can_track_errors()) { + return false; + } + struct sockaddr addr; + socklen_t len = sizeof(addr); + if (getsockname(tcp->fd, &addr, &len) < 0) { + return false; + } + if (addr.sa_family == AF_INET || addr.sa_family == AF_INET6) { + return true; + } + return false; +} + static const grpc_endpoint_vtable vtable = {tcp_read, tcp_write, tcp_add_to_pollset, @@ -1008,7 +1059,8 @@ static const grpc_endpoint_vtable vtable = {tcp_read, tcp_destroy, tcp_get_resource_user, tcp_get_peer, - tcp_get_fd}; + tcp_get_fd, + tcp_can_track_err}; #define MAX_CHUNK_SIZE 32 * 1024 * 1024 @@ -1069,6 +1121,8 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd, tcp->is_first_read = true; tcp->bytes_counter = -1; tcp->socket_ts_enabled = false; + tcp->ts_capable = true; + tcp->outgoing_buffer_arg = nullptr; /* paired with unref in grpc_tcp_destroy */ gpr_ref_init(&tcp->refcount, 1); gpr_atm_no_barrier_store(&tcp->shutdown_count, 0); @@ -1113,6 +1167,12 @@ void grpc_tcp_destroy_and_release_fd(grpc_endpoint* ep, int* fd, grpc_slice_buffer_reset_and_unref_internal(&tcp->last_read_buffer); if (grpc_event_engine_can_track_errors()) { /* Stop errors notification. */ + gpr_mu_lock(&tcp->tb_mu); + grpc_core::TracedBuffer::Shutdown( + &tcp->tb_head, tcp->outgoing_buffer_arg, + GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed")); + gpr_mu_unlock(&tcp->tb_mu); + tcp->outgoing_buffer_arg = nullptr; gpr_atm_no_barrier_store(&tcp->stop_error_notification, true); grpc_fd_set_error(tcp->em_fd); } diff --git a/src/core/lib/iomgr/tcp_windows.cc b/src/core/lib/iomgr/tcp_windows.cc index 64c4a56ae9..4b5250803d 100644 --- a/src/core/lib/iomgr/tcp_windows.cc +++ b/src/core/lib/iomgr/tcp_windows.cc @@ -427,6 +427,8 @@ static grpc_resource_user* win_get_resource_user(grpc_endpoint* ep) { static int win_get_fd(grpc_endpoint* ep) { return -1; } +static bool win_can_track_err(grpc_endpoint* ep) { return false; } + static grpc_endpoint_vtable vtable = {win_read, win_write, win_add_to_pollset, @@ -436,7 +438,8 @@ static grpc_endpoint_vtable vtable = {win_read, win_destroy, win_get_resource_user, win_get_peer, - win_get_fd}; + win_get_fd, + win_can_track_err}; grpc_endpoint* grpc_tcp_create(grpc_winsocket* socket, grpc_channel_args* channel_args, diff --git a/src/core/lib/iomgr/timer_manager.cc b/src/core/lib/iomgr/timer_manager.cc index ceba79f678..cb123298cf 100644 --- a/src/core/lib/iomgr/timer_manager.cc +++ b/src/core/lib/iomgr/timer_manager.cc @@ -67,6 +67,7 @@ static void timer_thread(void* completed_thread_ptr); extern int64_t g_timer_manager_init_count; extern int64_t g_timer_manager_shutdown_count; extern int64_t g_fork_count; +extern int64_t g_next_value; #endif // GRPC_DEBUG_TIMER_MANAGER static void gc_completed_threads(void) { @@ -193,6 +194,11 @@ static bool wait_until(grpc_millis next) { gpr_log(GPR_INFO, "sleep until kicked"); } + // For debug of the timer manager crash only. + // TODO (mxyan): remove after bug is fixed. +#ifdef GRPC_DEBUG_TIMER_MANAGER + g_next_value = next; +#endif gpr_cv_wait(&g_cv_wait, &g_mu, grpc_millis_to_timespec(next, GPR_CLOCK_MONOTONIC)); diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc index 16f40b4f55..8443ee0695 100644 --- a/src/core/lib/security/context/security_context.cc +++ b/src/core/lib/security/context/security_context.cc @@ -23,6 +23,8 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/arena.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/surface/api_trace.h" #include "src/core/lib/surface/call.h" @@ -50,13 +52,11 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call, ctx = static_cast<grpc_client_security_context*>( grpc_call_context_get(call, GRPC_CONTEXT_SECURITY)); if (ctx == nullptr) { - ctx = grpc_client_security_context_create(grpc_call_get_arena(call)); - ctx->creds = grpc_call_credentials_ref(creds); + ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds); grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx, grpc_client_security_context_destroy); } else { - grpc_call_credentials_unref(ctx->creds); - ctx->creds = grpc_call_credentials_ref(creds); + ctx->creds = creds != nullptr ? creds->Ref() : nullptr; } return GRPC_CALL_OK; @@ -66,33 +66,45 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) { void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY); GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call)); if (sec_ctx == nullptr) return nullptr; - return grpc_call_is_client(call) - ? GRPC_AUTH_CONTEXT_REF( - ((grpc_client_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context client") - : GRPC_AUTH_CONTEXT_REF( - ((grpc_server_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context server"); + if (grpc_call_is_client(call)) { + auto* sc = static_cast<grpc_client_security_context*>(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context client") + .release(); + } + } else { + auto* sc = static_cast<grpc_server_security_context*>(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context server") + .release(); + } + } } void grpc_auth_context_release(grpc_auth_context* context) { GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context)); - GRPC_AUTH_CONTEXT_UNREF(context, "grpc_auth_context_unref"); + if (context == nullptr) return; + context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref"); } /* --- grpc_client_security_context --- */ grpc_client_security_context::~grpc_client_security_context() { - grpc_call_credentials_unref(creds); - GRPC_AUTH_CONTEXT_UNREF(auth_context, "client_security_context"); + auth_context.reset(DEBUG_LOCATION, "client_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } } grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena) { + gpr_arena* arena, grpc_call_credentials* creds) { return new (gpr_arena_alloc(arena, sizeof(grpc_client_security_context))) - grpc_client_security_context(); + grpc_client_security_context(creds != nullptr ? creds->Ref() : nullptr); } void grpc_client_security_context_destroy(void* ctx) { @@ -104,7 +116,7 @@ void grpc_client_security_context_destroy(void* ctx) { /* --- grpc_server_security_context --- */ grpc_server_security_context::~grpc_server_security_context() { - GRPC_AUTH_CONTEXT_UNREF(auth_context, "server_security_context"); + auth_context.reset(DEBUG_LOCATION, "server_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } @@ -126,69 +138,11 @@ void grpc_server_security_context_destroy(void* ctx) { static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr}; -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained) { - grpc_auth_context* ctx = - static_cast<grpc_auth_context*>(gpr_zalloc(sizeof(grpc_auth_context))); - gpr_ref_init(&ctx->refcount, 1); - if (chained != nullptr) { - ctx->chained = GRPC_AUTH_CONTEXT_REF(chained, "chained"); - ctx->peer_identity_property_name = - ctx->chained->peer_identity_property_name; - } - return ctx; -} - -#ifndef NDEBUG -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx, - const char* file, int line, - const char* reason) { - if (ctx == nullptr) return nullptr; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val + 1, reason); - } -#else -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx) { - if (ctx == nullptr) return nullptr; -#endif - gpr_ref(&ctx->refcount); - return ctx; -} - -#ifndef NDEBUG -void grpc_auth_context_unref(grpc_auth_context* ctx, const char* file, int line, - const char* reason) { - if (ctx == nullptr) return; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val - 1, reason); - } -#else -void grpc_auth_context_unref(grpc_auth_context* ctx) { - if (ctx == nullptr) return; -#endif - if (gpr_unref(&ctx->refcount)) { - size_t i; - GRPC_AUTH_CONTEXT_UNREF(ctx->chained, "chained"); - if (ctx->properties.array != nullptr) { - for (i = 0; i < ctx->properties.count; i++) { - grpc_auth_property_reset(&ctx->properties.array[i]); - } - gpr_free(ctx->properties.array); - } - gpr_free(ctx); - } -} - const char* grpc_auth_context_peer_identity_property_name( const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name; + return ctx->peer_identity_property_name(); } int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, @@ -204,13 +158,13 @@ int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, name != nullptr ? name : "NULL"); return 0; } - ctx->peer_identity_property_name = prop->name; + ctx->set_peer_identity_property_name(prop->name); return 1; } int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name == nullptr ? 0 : 1; + return ctx->is_authenticated(); } grpc_auth_property_iterator grpc_auth_context_property_iterator( @@ -226,16 +180,17 @@ const grpc_auth_property* grpc_auth_property_iterator_next( grpc_auth_property_iterator* it) { GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it)); if (it == nullptr || it->ctx == nullptr) return nullptr; - while (it->index == it->ctx->properties.count) { - if (it->ctx->chained == nullptr) return nullptr; - it->ctx = it->ctx->chained; + while (it->index == it->ctx->properties().count) { + if (it->ctx->chained() == nullptr) return nullptr; + it->ctx = it->ctx->chained(); it->index = 0; } if (it->name == nullptr) { - return &it->ctx->properties.array[it->index++]; + return &it->ctx->properties().array[it->index++]; } else { - while (it->index < it->ctx->properties.count) { - const grpc_auth_property* prop = &it->ctx->properties.array[it->index++]; + while (it->index < it->ctx->properties().count) { + const grpc_auth_property* prop = + &it->ctx->properties().array[it->index++]; GPR_ASSERT(prop->name != nullptr); if (strcmp(it->name, prop->name) == 0) { return prop; @@ -262,49 +217,56 @@ grpc_auth_property_iterator grpc_auth_context_peer_identity( GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx)); if (ctx == nullptr) return empty_iterator; return grpc_auth_context_find_properties_by_name( - ctx, ctx->peer_identity_property_name); + ctx, ctx->peer_identity_property_name()); } -static void ensure_auth_context_capacity(grpc_auth_context* ctx) { - if (ctx->properties.count == ctx->properties.capacity) { - ctx->properties.capacity = - GPR_MAX(ctx->properties.capacity + 8, ctx->properties.capacity * 2); - ctx->properties.array = static_cast<grpc_auth_property*>( - gpr_realloc(ctx->properties.array, - ctx->properties.capacity * sizeof(grpc_auth_property))); +void grpc_auth_context::ensure_capacity() { + if (properties_.count == properties_.capacity) { + properties_.capacity = + GPR_MAX(properties_.capacity + 8, properties_.capacity * 2); + properties_.array = static_cast<grpc_auth_property*>(gpr_realloc( + properties_.array, properties_.capacity * sizeof(grpc_auth_property))); } } +void grpc_auth_context::add_property(const char* name, const char* value, + size_t value_length) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; + prop->name = gpr_strdup(name); + prop->value = static_cast<char*>(gpr_malloc(value_length + 1)); + memcpy(prop->value, value, value_length); + prop->value[value_length] = '\0'; + prop->value_length = value_length; +} + void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name, const char* value, size_t value_length) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, " "value_length=%lu)", 6, (ctx, name, (int)value_length, (int)value_length, value, (unsigned long)value_length)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; + ctx->add_property(name, value, value_length); +} + +void grpc_auth_context::add_cstring_property(const char* name, + const char* value) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; prop->name = gpr_strdup(name); - prop->value = static_cast<char*>(gpr_malloc(value_length + 1)); - memcpy(prop->value, value, value_length); - prop->value[value_length] = '\0'; - prop->value_length = value_length; + prop->value = gpr_strdup(value); + prop->value_length = strlen(value); } void grpc_auth_context_add_cstring_property(grpc_auth_context* ctx, const char* name, const char* value) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3, (ctx, name, value)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; - prop->name = gpr_strdup(name); - prop->value = gpr_strdup(value); - prop->value_length = strlen(value); + ctx->add_cstring_property(name, value); } void grpc_auth_property_reset(grpc_auth_property* property) { @@ -314,12 +276,17 @@ void grpc_auth_property_reset(grpc_auth_property* property) { } static void auth_context_pointer_arg_destroy(void* p) { - GRPC_AUTH_CONTEXT_UNREF((grpc_auth_context*)p, "auth_context_pointer_arg"); + if (p != nullptr) { + static_cast<grpc_auth_context*>(p)->Unref(DEBUG_LOCATION, + "auth_context_pointer_arg"); + } } static void* auth_context_pointer_arg_copy(void* p) { - return GRPC_AUTH_CONTEXT_REF((grpc_auth_context*)p, - "auth_context_pointer_arg"); + auto* ctx = static_cast<grpc_auth_context*>(p); + return ctx == nullptr + ? nullptr + : ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release(); } static int auth_context_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h index e45415f63b..b43ee5e62d 100644 --- a/src/core/lib/security/context/security_context.h +++ b/src/core/lib/security/context/security_context.h @@ -21,6 +21,8 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/security/credentials/credentials.h" @@ -40,39 +42,59 @@ struct grpc_auth_property_array { size_t capacity = 0; }; -struct grpc_auth_context { - grpc_auth_context() { gpr_ref_init(&refcount, 0); } +void grpc_auth_property_reset(grpc_auth_property* property); - struct grpc_auth_context* chained = nullptr; - grpc_auth_property_array properties; - gpr_refcount refcount; - const char* peer_identity_property_name = nullptr; - grpc_pollset* pollset = nullptr; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_auth_context + : public grpc_core::RefCounted<grpc_auth_context, + grpc_core::NonPolymorphicRefCount> { + public: + explicit grpc_auth_context( + grpc_core::RefCountedPtr<grpc_auth_context> chained) + : grpc_core::RefCounted<grpc_auth_context, + grpc_core::NonPolymorphicRefCount>( + &grpc_trace_auth_context_refcount), + chained_(std::move(chained)) { + if (chained_ != nullptr) { + peer_identity_property_name_ = chained_->peer_identity_property_name_; + } + } + + ~grpc_auth_context() { + chained_.reset(DEBUG_LOCATION, "chained"); + if (properties_.array != nullptr) { + for (size_t i = 0; i < properties_.count; i++) { + grpc_auth_property_reset(&properties_.array[i]); + } + gpr_free(properties_.array); + } + } + + const grpc_auth_context* chained() const { return chained_.get(); } + const grpc_auth_property_array& properties() const { return properties_; } + + bool is_authenticated() const { + return peer_identity_property_name_ != nullptr; + } + const char* peer_identity_property_name() const { + return peer_identity_property_name_; + } + void set_peer_identity_property_name(const char* name) { + peer_identity_property_name_ = name; + } + + void ensure_capacity(); + void add_property(const char* name, const char* value, size_t value_length); + void add_cstring_property(const char* name, const char* value); + + private: + grpc_core::RefCountedPtr<grpc_auth_context> chained_; + grpc_auth_property_array properties_; + const char* peer_identity_property_name_ = nullptr; }; -/* Creation. */ -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained); - -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_AUTH_CONTEXT_REF(p, r) \ - grpc_auth_context_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) \ - grpc_auth_context_unref((p), __FILE__, __LINE__, (r)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy, - const char* file, int line, - const char* reason); -void grpc_auth_context_unref(grpc_auth_context* policy, const char* file, - int line, const char* reason); -#else -#define GRPC_AUTH_CONTEXT_REF(p, r) grpc_auth_context_ref((p)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) grpc_auth_context_unref((p)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy); -void grpc_auth_context_unref(grpc_auth_context* policy); -#endif - -void grpc_auth_property_reset(grpc_auth_property* property); - /* --- grpc_security_context_extension --- Extension to the security context that may be set in a filter and accessed @@ -88,16 +110,18 @@ struct grpc_security_context_extension { Internal client-side security context. */ struct grpc_client_security_context { - grpc_client_security_context() = default; + explicit grpc_client_security_context( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) + : creds(std::move(creds)) {} ~grpc_client_security_context(); - grpc_call_credentials* creds = nullptr; - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> creds; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; grpc_security_context_extension extension; }; grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena); + gpr_arena* arena, grpc_call_credentials* creds); void grpc_client_security_context_destroy(void* ctx); /* --- grpc_server_security_context --- @@ -108,7 +132,7 @@ struct grpc_server_security_context { grpc_server_security_context() = default; ~grpc_server_security_context(); - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; grpc_security_context_extension extension; }; diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc index 1fbef4ae0c..06546492bc 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.cc +++ b/src/core/lib/security/credentials/alts/alts_credentials.cc @@ -33,40 +33,47 @@ #define GRPC_CREDENTIALS_TYPE_ALTS "Alts" #define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal:8080" -static void alts_credentials_destruct(grpc_channel_credentials* creds) { - grpc_alts_credentials* alts_creds = - reinterpret_cast<grpc_alts_credentials*>(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); -} - -static void alts_server_credentials_destruct(grpc_server_credentials* creds) { - grpc_alts_server_credentials* alts_creds = - reinterpret_cast<grpc_alts_server_credentials*>(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); +grpc_alts_credentials::grpc_alts_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_alts_credentials::~grpc_alts_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); } -static grpc_security_status alts_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_alts_channel_security_connector_create( - creds, request_metadata_creds, target_name, sc); + this->Ref(), std::move(call_creds), target_name); } -static grpc_security_status alts_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_alts_server_security_connector_create(creds, sc); +grpc_alts_server_credentials::grpc_alts_server_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_credentials::create_security_connector() { + return grpc_alts_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable alts_credentials_vtable = { - alts_credentials_destruct, alts_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable alts_server_credentials_vtable = { - alts_server_credentials_destruct, alts_server_create_security_connector}; +grpc_alts_server_credentials::~grpc_alts_server_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); +} grpc_channel_credentials* grpc_alts_credentials_create_customized( const grpc_alts_credentials_options* options, @@ -74,17 +81,7 @@ grpc_channel_credentials* grpc_alts_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast<grpc_alts_credentials*>( - gpr_zalloc(sizeof(grpc_alts_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_alts_credentials>(options, handshaker_service_url); } grpc_server_credentials* grpc_alts_server_credentials_create_customized( @@ -93,17 +90,8 @@ grpc_server_credentials* grpc_alts_server_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast<grpc_alts_server_credentials*>( - gpr_zalloc(sizeof(grpc_alts_server_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_alts_server_credentials>(options, + handshaker_service_url); } grpc_channel_credentials* grpc_alts_credentials_create( diff --git a/src/core/lib/security/credentials/alts/alts_credentials.h b/src/core/lib/security/credentials/alts/alts_credentials.h index 810117f2be..cc6d5222b1 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.h +++ b/src/core/lib/security/credentials/alts/alts_credentials.h @@ -27,18 +27,45 @@ #include "src/core/lib/security/credentials/credentials.h" /* Main struct for grpc ALTS channel credential. */ -typedef struct grpc_alts_credentials { - grpc_channel_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_credentials; +class grpc_alts_credentials final : public grpc_channel_credentials { + public: + grpc_alts_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_credentials() override; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /* Main struct for grpc ALTS server credential. */ -typedef struct grpc_alts_server_credentials { - grpc_server_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_server_credentials; +class grpc_alts_server_credentials final : public grpc_server_credentials { + public: + grpc_alts_server_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_server_credentials() override; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /** * This method creates an ALTS channel credential object with customized diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc index b8f409260f..586bbed778 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.cc +++ b/src/core/lib/security/credentials/composite/composite_credentials.cc @@ -20,8 +20,10 @@ #include "src/core/lib/security/credentials/composite/composite_credentials.h" -#include <string.h> +#include <cstring> +#include <new> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/surface/api_trace.h" @@ -31,35 +33,44 @@ /* -- Composite call credentials. -- */ -typedef struct { +static void composite_call_metadata_cb(void* arg, grpc_error* error); + +namespace { +struct grpc_composite_call_credentials_metadata_context { + grpc_composite_call_credentials_metadata_context( + grpc_composite_call_credentials* composite_creds, + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata) + : composite_creds(composite_creds), + pollent(pollent), + auth_md_context(auth_md_context), + md_array(md_array), + on_request_metadata(on_request_metadata) { + GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb, + this, grpc_schedule_on_exec_ctx); + } + grpc_composite_call_credentials* composite_creds; - size_t creds_index; + size_t creds_index = 0; grpc_polling_entity* pollent; grpc_auth_metadata_context auth_md_context; grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_closure internal_on_request_metadata; -} grpc_composite_call_credentials_metadata_context; - -static void composite_call_destruct(grpc_call_credentials* creds) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); - for (size_t i = 0; i < c->inner.num_creds; i++) { - grpc_call_credentials_unref(c->inner.creds_array[i]); - } - gpr_free(c->inner.creds_array); -} +}; +} // namespace static void composite_call_metadata_cb(void* arg, grpc_error* error) { grpc_composite_call_credentials_metadata_context* ctx = static_cast<grpc_composite_call_credentials_metadata_context*>(arg); if (error == GRPC_ERROR_NONE) { + const grpc_composite_call_credentials::CallCredentialsList& inner = + ctx->composite_creds->inner(); /* See if we need to get some more metadata. */ - if (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, + if (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, &ctx->internal_on_request_metadata, &error)) { // Synchronous response, so call ourselves recursively. composite_call_metadata_cb(arg, error); @@ -73,29 +84,18 @@ static void composite_call_metadata_cb(void* arg, grpc_error* error) { gpr_free(ctx); } -static bool composite_call_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context auth_md_context, +bool grpc_composite_call_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, grpc_error** error) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); grpc_composite_call_credentials_metadata_context* ctx; - ctx = static_cast<grpc_composite_call_credentials_metadata_context*>( - gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context))); - ctx->composite_creds = c; - ctx->pollent = pollent; - ctx->auth_md_context = auth_md_context; - ctx->md_array = md_array; - ctx->on_request_metadata = on_request_metadata; - GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata, - composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx); + ctx = grpc_core::New<grpc_composite_call_credentials_metadata_context>( + this, pollent, auth_md_context, md_array, on_request_metadata); bool synchronous = true; - while (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, + const CallCredentialsList& inner = ctx->composite_creds->inner(); + while (ctx->creds_index < inner.size()) { + if (inner[ctx->creds_index++]->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, &ctx->internal_on_request_metadata, error)) { if (*error != GRPC_ERROR_NONE) break; } else { @@ -103,46 +103,66 @@ static bool composite_call_get_request_metadata( break; } } - if (synchronous) gpr_free(ctx); + if (synchronous) grpc_core::Delete(ctx); return synchronous; } -static void composite_call_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_composite_call_credentials* c = - reinterpret_cast<grpc_composite_call_credentials*>(creds); - for (size_t i = 0; i < c->inner.num_creds; ++i) { - grpc_call_credentials_cancel_get_request_metadata( - c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error)); +void grpc_composite_call_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + for (size_t i = 0; i < inner_.size(); ++i) { + inner_[i]->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable composite_call_credentials_vtable = { - composite_call_destruct, composite_call_get_request_metadata, - composite_call_cancel_get_request_metadata}; +static size_t get_creds_array_size(const grpc_call_credentials* creds, + bool is_composite) { + return is_composite + ? static_cast<const grpc_composite_call_credentials*>(creds) + ->inner() + .size() + : 1; +} -static grpc_call_credentials_array get_creds_array( - grpc_call_credentials** creds_addr) { - grpc_call_credentials_array result; - grpc_call_credentials* creds = *creds_addr; - result.creds_array = creds_addr; - result.num_creds = 1; - if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - result = *grpc_composite_call_credentials_get_credentials(creds); +void grpc_composite_call_credentials::push_to_inner( + grpc_core::RefCountedPtr<grpc_call_credentials> creds, bool is_composite) { + if (!is_composite) { + inner_.push_back(std::move(creds)); + return; + } + auto composite_creds = + static_cast<grpc_composite_call_credentials*>(creds.get()); + for (size_t i = 0; i < composite_creds->inner().size(); ++i) { + inner_.push_back(std::move(composite_creds->inner_[i])); } - return result; +} + +grpc_composite_call_credentials::grpc_composite_call_credentials( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) { + const bool creds1_is_composite = + strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const bool creds2_is_composite = + strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) + + get_creds_array_size(creds2.get(), creds2_is_composite); + inner_.reserve(size); + push_to_inner(std::move(creds1), creds1_is_composite); + push_to_inner(std::move(creds2), creds2_is_composite); +} + +static grpc_core::RefCountedPtr<grpc_call_credentials> +composite_call_credentials_create( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2) { + return grpc_core::MakeRefCounted<grpc_composite_call_credentials>( + std::move(creds1), std::move(creds2)); } grpc_call_credentials* grpc_composite_call_credentials_create( grpc_call_credentials* creds1, grpc_call_credentials* creds2, void* reserved) { - size_t i; - size_t creds_array_byte_size; - grpc_call_credentials_array creds1_array; - grpc_call_credentials_array creds2_array; - grpc_composite_call_credentials* c; GRPC_API_TRACE( "grpc_composite_call_credentials_create(creds1=%p, creds2=%p, " "reserved=%p)", @@ -150,120 +170,40 @@ grpc_call_credentials* grpc_composite_call_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(creds1 != nullptr); GPR_ASSERT(creds2 != nullptr); - c = static_cast<grpc_composite_call_credentials*>( - gpr_zalloc(sizeof(grpc_composite_call_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE; - c->base.vtable = &composite_call_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - creds1_array = get_creds_array(&creds1); - creds2_array = get_creds_array(&creds2); - c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds; - creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*); - c->inner.creds_array = - static_cast<grpc_call_credentials**>(gpr_zalloc(creds_array_byte_size)); - for (i = 0; i < creds1_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds1_array.creds_array[i]; - c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds); - } - for (i = 0; i < creds2_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds2_array.creds_array[i]; - c->inner.creds_array[i + creds1_array.num_creds] = - grpc_call_credentials_ref(cur_creds); - } - return &c->base; -} - -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) { - const grpc_composite_call_credentials* c = - reinterpret_cast<const grpc_composite_call_credentials*>(creds); - GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); - return &c->inner; -} -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds) { - size_t i; - if (strcmp(creds->type, type) == 0) { - if (composite_creds != nullptr) *composite_creds = nullptr; - return creds; - } else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - const grpc_call_credentials_array* inner_creds_array = - grpc_composite_call_credentials_get_credentials(creds); - for (i = 0; i < inner_creds_array->num_creds; i++) { - if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) { - if (composite_creds != nullptr) *composite_creds = creds; - return inner_creds_array->creds_array[i]; - } - } - } - return nullptr; + return composite_call_credentials_create(creds1->Ref(), creds2->Ref()) + .release(); } /* -- Composite channel credentials. -- */ -static void composite_channel_destruct(grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - grpc_channel_credentials_unref(c->inner_creds); - grpc_call_credentials_unref(c->call_creds); -} - -static grpc_security_status composite_channel_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_composite_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - grpc_security_status status = GRPC_SECURITY_ERROR; - - GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr && - c->inner_creds->vtable != nullptr && - c->inner_creds->vtable->create_security_connector != nullptr); + grpc_channel_args** new_args) { + GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr); /* If we are passed a call_creds, create a call composite to pass it downstream. */ if (call_creds != nullptr) { - grpc_call_credentials* composite_call_creds = - grpc_composite_call_credentials_create(c->call_creds, call_creds, - nullptr); - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, composite_call_creds, target, args, sc, new_args); - grpc_call_credentials_unref(composite_call_creds); + return inner_creds_->create_security_connector( + composite_call_credentials_create(call_creds_, std::move(call_creds)), + target, args, new_args); } else { - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, c->call_creds, target, args, sc, new_args); + return inner_creds_->create_security_connector(call_creds_, target, args, + new_args); } - return status; } -static grpc_channel_credentials* -composite_channel_duplicate_without_call_credentials( - grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast<grpc_composite_channel_credentials*>(creds); - return grpc_channel_credentials_ref(c->inner_creds); -} - -static grpc_channel_credentials_vtable composite_channel_credentials_vtable = { - composite_channel_destruct, composite_channel_create_security_connector, - composite_channel_duplicate_without_call_credentials}; - grpc_channel_credentials* grpc_composite_channel_credentials_create( grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds, void* reserved) { - grpc_composite_channel_credentials* c = - static_cast<grpc_composite_channel_credentials*>(gpr_zalloc(sizeof(*c))); GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr && reserved == nullptr); GRPC_API_TRACE( "grpc_composite_channel_credentials_create(channel_creds=%p, " "call_creds=%p, reserved=%p)", 3, (channel_creds, call_creds, reserved)); - c->base.type = channel_creds->type; - c->base.vtable = &composite_channel_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->inner_creds = grpc_channel_credentials_ref(channel_creds); - c->call_creds = grpc_call_credentials_ref(call_creds); - return &c->base; + return grpc_core::New<grpc_composite_channel_credentials>( + channel_creds->Ref(), call_creds->Ref()); } diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h index a952ad57f1..7a1c7d5e42 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.h +++ b/src/core/lib/security/credentials/composite/composite_credentials.h @@ -21,39 +21,75 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/inlined_vector.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials** creds_array; - size_t num_creds; -} grpc_call_credentials_array; +/* -- Composite channel credentials. -- */ -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials( - grpc_call_credentials* composite_creds); +class grpc_composite_channel_credentials : public grpc_channel_credentials { + public: + grpc_composite_channel_credentials( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds) + : grpc_channel_credentials(channel_creds->type()), + inner_creds_(std::move(channel_creds)), + call_creds_(std::move(call_creds)) {} -/* Returns creds if creds is of the specified type or the inner creds of the - specified type (if found), if the creds is of type COMPOSITE. - If composite_creds is not NULL, *composite_creds will point to creds if of - type COMPOSITE in case of success. */ -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds); + ~grpc_composite_channel_credentials() override = default; -/* -- Composite channel credentials. -- */ + grpc_core::RefCountedPtr<grpc_channel_credentials> + duplicate_without_call_credentials() override { + return inner_creds_; + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* inner_creds; - grpc_call_credentials* call_creds; -} grpc_composite_channel_credentials; + const grpc_channel_credentials* inner_creds() const { + return inner_creds_.get(); + } + const grpc_call_credentials* call_creds() const { return call_creds_.get(); } + grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); } + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> inner_creds_; + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds_; +}; /* -- Composite call credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_call_credentials_array inner; -} grpc_composite_call_credentials; +class grpc_composite_call_credentials : public grpc_call_credentials { + public: + using CallCredentialsList = + grpc_core::InlinedVector<grpc_core::RefCountedPtr<grpc_call_credentials>, + 2>; + + grpc_composite_call_credentials( + grpc_core::RefCountedPtr<grpc_call_credentials> creds1, + grpc_core::RefCountedPtr<grpc_call_credentials> creds2); + ~grpc_composite_call_credentials() override = default; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const CallCredentialsList& inner() const { return inner_; } + + private: + void push_to_inner(grpc_core::RefCountedPtr<grpc_call_credentials> creds, + bool is_composite); + + CallCredentialsList inner_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_COMPOSITE_COMPOSITE_CREDENTIALS_H \ */ diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc index c43cb440eb..90452d68d6 100644 --- a/src/core/lib/security/credentials/credentials.cc +++ b/src/core/lib/security/credentials/credentials.cc @@ -39,120 +39,24 @@ /* -- Common. -- */ -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds) { - grpc_credentials_metadata_request* r = - static_cast<grpc_credentials_metadata_request*>( - gpr_zalloc(sizeof(grpc_credentials_metadata_request))); - r->creds = grpc_call_credentials_ref(creds); - return r; -} - -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r) { - grpc_call_credentials_unref(r->creds); - grpc_http_response_destroy(&r->response); - gpr_free(r); -} - -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_channel_credentials_unref(grpc_channel_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } -} - void grpc_channel_credentials_release(grpc_channel_credentials* creds) { GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_channel_credentials_unref(creds); -} - -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_call_credentials_unref(grpc_call_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } + if (creds) creds->Unref(); } void grpc_call_credentials_release(grpc_call_credentials* creds) { GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials_unref(creds); -} - -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - if (creds == nullptr || creds->vtable->get_request_metadata == nullptr) { - return true; - } - return creds->vtable->get_request_metadata(creds, pollent, context, md_array, - on_request_metadata, error); -} - -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - if (creds == nullptr || - creds->vtable->cancel_get_request_metadata == nullptr) { - return; - } - creds->vtable->cancel_get_request_metadata(creds, md_array, error); -} - -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* channel_creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args) { - *new_args = nullptr; - if (channel_creds == nullptr) { - return GRPC_SECURITY_ERROR; - } - GPR_ASSERT(channel_creds->vtable->create_security_connector != nullptr); - return channel_creds->vtable->create_security_connector( - channel_creds, nullptr, target, args, sc, new_args); -} - -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* channel_creds) { - if (channel_creds != nullptr && channel_creds->vtable != nullptr && - channel_creds->vtable->duplicate_without_call_credentials != nullptr) { - return channel_creds->vtable->duplicate_without_call_credentials( - channel_creds); - } else { - return grpc_channel_credentials_ref(channel_creds); - } + if (creds) creds->Unref(); } static void credentials_pointer_arg_destroy(void* p) { - grpc_channel_credentials_unref(static_cast<grpc_channel_credentials*>(p)); + static_cast<grpc_channel_credentials*>(p)->Unref(); } static void* credentials_pointer_arg_copy(void* p) { - return grpc_channel_credentials_ref( - static_cast<grpc_channel_credentials*>(p)); + return static_cast<grpc_channel_credentials*>(p)->Ref().release(); } static int credentials_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } @@ -191,63 +95,35 @@ grpc_channel_credentials* grpc_channel_credentials_find_in_args( return nullptr; } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_server_credentials_unref(grpc_server_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - gpr_free(creds); - } -} - void grpc_server_credentials_release(grpc_server_credentials* creds) { GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_server_credentials_unref(creds); + if (creds) creds->Unref(); } -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - if (creds == nullptr || creds->vtable->create_security_connector == nullptr) { - gpr_log(GPR_ERROR, "Server credentials cannot create security context."); - return GRPC_SECURITY_ERROR; - } - return creds->vtable->create_security_connector(creds, sc); -} - -void grpc_server_credentials_set_auth_metadata_processor( - grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { +void grpc_server_credentials::set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor) { GRPC_API_TRACE( "grpc_server_credentials_set_auth_metadata_processor(" "creds=%p, " "processor=grpc_auth_metadata_processor { process: %p, state: %p })", - 3, (creds, (void*)(intptr_t)processor.process, processor.state)); - if (creds == nullptr) return; - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - creds->processor = processor; + 3, (this, (void*)(intptr_t)processor.process, processor.state)); + DestroyProcessor(); + processor_ = processor; +} + +void grpc_server_credentials_set_auth_metadata_processor( + grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { + GPR_DEBUG_ASSERT(creds != nullptr); + creds->set_auth_metadata_processor(processor); } static void server_credentials_pointer_arg_destroy(void* p) { - grpc_server_credentials_unref(static_cast<grpc_server_credentials*>(p)); + static_cast<grpc_server_credentials*>(p)->Unref(); } static void* server_credentials_pointer_arg_copy(void* p) { - return grpc_server_credentials_ref(static_cast<grpc_server_credentials*>(p)); + return static_cast<grpc_server_credentials*>(p)->Ref().release(); } static int server_credentials_pointer_cmp(void* a, void* b) { diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index 3878958b38..4091ef3dfb 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -26,6 +26,7 @@ #include <grpc/support/sync.h> #include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/polling_entity.h" @@ -90,44 +91,46 @@ void grpc_override_well_known_credentials_path_getter( #define GRPC_ARG_CHANNEL_CREDENTIALS "grpc.channel_credentials" -typedef struct { - void (*destruct)(grpc_channel_credentials* c); - - grpc_security_status (*create_security_connector)( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_channel_credentials + : grpc_core::RefCounted<grpc_channel_credentials> { + public: + explicit grpc_channel_credentials(const char* type) : type_(type) {} + virtual ~grpc_channel_credentials() = default; + + // Creates a security connector for the channel. May also create new channel + // args for the channel to be used in place of the passed in const args if + // returned non NULL. In that case the caller is responsible for destroying + // new_args after channel creation. + virtual grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args); - - grpc_channel_credentials* (*duplicate_without_call_credentials)( - grpc_channel_credentials* c); -} grpc_channel_credentials_vtable; - -struct grpc_channel_credentials { - const grpc_channel_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; + grpc_channel_args** new_args) { + // Tell clang-tidy that call_creds cannot be passed as const-ref. + call_creds.reset(); + GRPC_ABSTRACT; + } + + // Creates a version of the channel credentials without any attached call + // credentials. This can be used in order to open a channel to a non-trusted + // gRPC load balancer. + virtual grpc_core::RefCountedPtr<grpc_channel_credentials> + duplicate_without_call_credentials() { + // By default we just increment the refcount. + return Ref(); + } + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds); -void grpc_channel_credentials_unref(grpc_channel_credentials* creds); - -/* Creates a security connector for the channel. May also create new channel - args for the channel to be used in place of the passed in const args if - returned non NULL. In that case the caller is responsible for destroying - new_args after channel creation. */ -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args); - -/* Creates a version of the channel credentials without any attached call - credentials. This can be used in order to open a channel to a non-trusted - gRPC load balancer. */ -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* creds); - /* Util to encapsulate the channel credentials in a channel arg. */ grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials); @@ -158,44 +161,39 @@ void grpc_credentials_mdelem_array_destroy(grpc_credentials_mdelem_array* list); /* --- grpc_call_credentials. --- */ -typedef struct { - void (*destruct)(grpc_call_credentials* c); - bool (*get_request_metadata)(grpc_call_credentials* c, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error); - void (*cancel_get_request_metadata)(grpc_call_credentials* c, - grpc_credentials_mdelem_array* md_array, - grpc_error* error); -} grpc_call_credentials_vtable; - -struct grpc_call_credentials { - const grpc_call_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_call_credentials + : public grpc_core::RefCounted<grpc_call_credentials> { + public: + explicit grpc_call_credentials(const char* type) : type_(type) {} + virtual ~grpc_call_credentials() = default; + + // Returns true if completed synchronously, in which case \a error will + // be set to indicate the result. Otherwise, \a on_request_metadata will + // be invoked asynchronously when complete. \a md_array will be populated + // with the resulting metadata once complete. + virtual bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) GRPC_ABSTRACT; + + // Cancels a pending asynchronous operation started by + // grpc_call_credentials_get_request_metadata() with the corresponding + // value of \a md_array. + virtual void cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) GRPC_ABSTRACT; + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds); -void grpc_call_credentials_unref(grpc_call_credentials* creds); - -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_request_metadata will -/// be invoked asynchronously when complete. \a md_array will be populated -/// with the resulting metadata once complete. -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error); - -/// Cancels a pending asynchronous operation started by -/// grpc_call_credentials_get_request_metadata() with the corresponding -/// value of \a md_array. -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error); - /* Metadata-only credentials with the specified key and value where asynchronicity can be simulated for testing. */ grpc_call_credentials* grpc_md_only_test_credentials_create( @@ -203,26 +201,40 @@ grpc_call_credentials* grpc_md_only_test_credentials_create( /* --- grpc_server_credentials. --- */ -typedef struct { - void (*destruct)(grpc_server_credentials* c); - grpc_security_status (*create_security_connector)( - grpc_server_credentials* c, grpc_server_security_connector** sc); -} grpc_server_credentials_vtable; - -struct grpc_server_credentials { - const grpc_server_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; - grpc_auth_metadata_processor processor; -}; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_server_credentials + : public grpc_core::RefCounted<grpc_server_credentials> { + public: + explicit grpc_server_credentials(const char* type) : type_(type) {} -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc); + virtual ~grpc_server_credentials() { DestroyProcessor(); } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds); + virtual grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() GRPC_ABSTRACT; -void grpc_server_credentials_unref(grpc_server_credentials* creds); + const char* type() const { return type_; } + + const grpc_auth_metadata_processor& auth_metadata_processor() const { + return processor_; + } + void set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor); + + GRPC_ABSTRACT_BASE_CLASS + + private: + void DestroyProcessor() { + if (processor_.destroy != nullptr && processor_.state != nullptr) { + processor_.destroy(processor_.state); + } + } + + const char* type_; + grpc_auth_metadata_processor processor_ = + grpc_auth_metadata_processor(); // Zero-initialize the C struct. +}; #define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials" @@ -233,15 +245,27 @@ grpc_server_credentials* grpc_find_server_credentials_in_args( /* -- Credentials Metadata Request. -- */ -typedef struct { - grpc_call_credentials* creds; +struct grpc_credentials_metadata_request { + explicit grpc_credentials_metadata_request( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) + : creds(std::move(creds)) {} + ~grpc_credentials_metadata_request() { + grpc_http_response_destroy(&response); + } + + grpc_core::RefCountedPtr<grpc_call_credentials> creds; grpc_http_response response; -} grpc_credentials_metadata_request; +}; -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds); +inline grpc_credentials_metadata_request* +grpc_credentials_metadata_request_create( + grpc_core::RefCountedPtr<grpc_call_credentials> creds) { + return grpc_core::New<grpc_credentials_metadata_request>(std::move(creds)); +} -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r); +inline void grpc_credentials_metadata_request_destroy( + grpc_credentials_metadata_request* r) { + grpc_core::Delete(r); +} #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index d3e0e8c816..337dd7679f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -33,49 +33,45 @@ /* -- Fake transport security credentials. -- */ -static grpc_security_status fake_transport_security_create_security_connector( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, - const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - *sc = - grpc_fake_channel_security_connector_create(c, call_creds, target, args); - return GRPC_SECURITY_OK; -} - -static grpc_security_status -fake_transport_security_server_create_security_connector( - grpc_server_credentials* c, grpc_server_security_connector** sc) { - *sc = grpc_fake_server_security_connector_create(c); - return GRPC_SECURITY_OK; -} +namespace { +class grpc_fake_channel_credentials final : public grpc_channel_credentials { + public: + grpc_fake_channel_credentials() + : grpc_channel_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_channel_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override { + return grpc_fake_channel_security_connector_create( + this->Ref(), std::move(call_creds), target, args); + } +}; + +class grpc_fake_server_credentials final : public grpc_server_credentials { + public: + grpc_fake_server_credentials() + : grpc_server_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_server_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override { + return grpc_fake_server_security_connector_create(this->Ref()); + } +}; +} // namespace -static grpc_channel_credentials_vtable - fake_transport_security_credentials_vtable = { - nullptr, fake_transport_security_create_security_connector, nullptr}; - -static grpc_server_credentials_vtable - fake_transport_security_server_credentials_vtable = { - nullptr, fake_transport_security_server_create_security_connector}; - -grpc_channel_credentials* grpc_fake_transport_security_credentials_create( - void) { - grpc_channel_credentials* c = static_cast<grpc_channel_credentials*>( - gpr_zalloc(sizeof(grpc_channel_credentials))); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - c->vtable = &fake_transport_security_credentials_vtable; - gpr_ref_init(&c->refcount, 1); - return c; +grpc_channel_credentials* grpc_fake_transport_security_credentials_create() { + return grpc_core::New<grpc_fake_channel_credentials>(); } -grpc_server_credentials* grpc_fake_transport_security_server_credentials_create( - void) { - grpc_server_credentials* c = static_cast<grpc_server_credentials*>( - gpr_malloc(sizeof(grpc_server_credentials))); - memset(c, 0, sizeof(grpc_server_credentials)); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - gpr_ref_init(&c->refcount, 1); - c->vtable = &fake_transport_security_server_credentials_vtable; - return c; +grpc_server_credentials* +grpc_fake_transport_security_server_credentials_create() { + return grpc_core::New<grpc_fake_server_credentials>(); } grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) { @@ -92,46 +88,25 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only test credentials. -- */ -static void md_only_test_destruct(grpc_call_credentials* creds) { - grpc_md_only_test_credentials* c = - reinterpret_cast<grpc_md_only_test_credentials*>(creds); - GRPC_MDELEM_UNREF(c->md); -} - -static bool md_only_test_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_md_only_test_credentials* c = - reinterpret_cast<grpc_md_only_test_credentials*>(creds); - grpc_credentials_mdelem_array_add(md_array, c->md); - if (c->is_async) { +bool grpc_md_only_test_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, md_); + if (is_async_) { GRPC_CLOSURE_SCHED(on_request_metadata, GRPC_ERROR_NONE); return false; } return true; } -static void md_only_test_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_md_only_test_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable md_only_test_vtable = { - md_only_test_destruct, md_only_test_get_request_metadata, - md_only_test_cancel_get_request_metadata}; - grpc_call_credentials* grpc_md_only_test_credentials_create( const char* md_key, const char* md_value, bool is_async) { - grpc_md_only_test_credentials* c = - static_cast<grpc_md_only_test_credentials*>( - gpr_zalloc(sizeof(grpc_md_only_test_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &md_only_test_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->md = grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), - grpc_slice_from_copied_string(md_value)); - c->is_async = is_async; - return &c->base; + return grpc_core::New<grpc_md_only_test_credentials>(md_key, md_value, + is_async); } diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index e89e6e24cc..b7f6a1909f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -55,10 +55,28 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only Test credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_mdelem md; - bool is_async; -} grpc_md_only_test_credentials; +class grpc_md_only_test_credentials : public grpc_call_credentials { + public: + grpc_md_only_test_credentials(const char* md_key, const char* md_value, + bool is_async) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + md_(grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), + grpc_slice_from_copied_string(md_value))), + is_async_(is_async) {} + ~grpc_md_only_test_credentials() override { GRPC_MDELEM_UNREF(md_); } + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem md_; + bool is_async_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_FAKE_FAKE_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc index fcab252959..a86a17d586 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.cc +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/load_file.h" @@ -49,9 +50,16 @@ /* -- Default credentials. -- */ -static int g_compute_engine_detection_done = 0; -static int g_need_compute_engine_creds = 0; +/* A sticky bit that will be set only if the result of metadata server detection + * is positive. We do not set the bit if the result is negative. Because it + * means the detection is done via network test that is unreliable and the + * unreliable result should not be referred by successive calls. */ +static int g_metadata_server_available = 0; +static int g_is_on_gce = 0; static gpr_mu g_state_mu; +/* Protect a metadata_server_detector instance that can be modified by more than + * one gRPC threads */ +static gpr_mu* g_polling_mu; static gpr_once g_once = GPR_ONCE_INIT; static grpc_core::internal::grpc_gce_tenancy_checker g_gce_tenancy_checker = grpc_alts_is_running_on_gcp; @@ -63,22 +71,13 @@ typedef struct { int is_done; int success; grpc_http_response response; -} compute_engine_detector; - -static void google_default_credentials_destruct( - grpc_channel_credentials* creds) { - grpc_google_default_channel_credentials* c = - reinterpret_cast<grpc_google_default_channel_credentials*>(creds); - grpc_channel_credentials_unref(c->alts_creds); - grpc_channel_credentials_unref(c->ssl_creds); -} +} metadata_server_detector; -static grpc_security_status google_default_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_google_default_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_google_default_channel_credentials* c = - reinterpret_cast<grpc_google_default_channel_credentials*>(creds); + grpc_channel_args** new_args) { bool is_grpclb_load_balancer = grpc_channel_arg_get_bool( grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER), false); @@ -88,11 +87,17 @@ static grpc_security_status google_default_create_security_connector( false); bool use_alts = is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer; - grpc_security_status status = GRPC_SECURITY_ERROR; - status = use_alts ? c->alts_creds->vtable->create_security_connector( - c->alts_creds, call_creds, target, args, sc, new_args) - : c->ssl_creds->vtable->create_security_connector( - c->ssl_creds, call_creds, target, args, sc, new_args); + /* Return failure if ALTS is selected but not running on GCE. */ + if (use_alts && !g_is_on_gce) { + gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE."); + return nullptr; + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + use_alts ? alts_creds_->create_security_connector(call_creds, target, + args, new_args) + : ssl_creds_->create_security_connector(call_creds, target, args, + new_args); /* grpclb-specific channel args are removed from the channel args set * to ensure backends and fallback adresses will have the same set of channel * args. By doing that, it guarantees the connections to backends will not be @@ -106,20 +111,103 @@ static grpc_security_status google_default_create_security_connector( *new_args = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0); } - return status; + return sc; +} + +static void on_metadata_server_detection_http_response(void* user_data, + grpc_error* error) { + metadata_server_detector* detector = + static_cast<metadata_server_detector*>(user_data); + if (error == GRPC_ERROR_NONE && detector->response.status == 200 && + detector->response.hdr_count > 0) { + /* Internet providers can return a generic response to all requests, so + it is necessary to check that metadata header is present also. */ + size_t i; + for (i = 0; i < detector->response.hdr_count; i++) { + grpc_http_header* header = &detector->response.hdrs[i]; + if (strcmp(header->key, "Metadata-Flavor") == 0 && + strcmp(header->value, "Google") == 0) { + detector->success = 1; + break; + } + } + } + gpr_mu_lock(g_polling_mu); + detector->is_done = 1; + GRPC_LOG_IF_ERROR( + "Pollset kick", + grpc_pollset_kick(grpc_polling_entity_pollset(&detector->pollent), + nullptr)); + gpr_mu_unlock(g_polling_mu); } -static grpc_channel_credentials_vtable google_default_credentials_vtable = { - google_default_credentials_destruct, - google_default_create_security_connector, nullptr}; +static void destroy_pollset(void* p, grpc_error* e) { + grpc_pollset_destroy(static_cast<grpc_pollset*>(p)); +} + +static int is_metadata_server_reachable() { + metadata_server_detector detector; + grpc_httpcli_request request; + grpc_httpcli_context context; + grpc_closure destroy_closure; + /* The http call is local. If it takes more than one sec, it is for sure not + on compute engine. */ + grpc_millis max_detection_delay = GPR_MS_PER_SEC; + grpc_pollset* pollset = + static_cast<grpc_pollset*>(gpr_zalloc(grpc_pollset_size())); + grpc_pollset_init(pollset, &g_polling_mu); + detector.pollent = grpc_polling_entity_create_from_pollset(pollset); + detector.is_done = 0; + detector.success = 0; + memset(&detector.response, 0, sizeof(detector.response)); + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = (char*)GRPC_COMPUTE_ENGINE_DETECTION_HOST; + request.http.path = (char*)"/"; + grpc_httpcli_context_init(&context); + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("google_default_credentials"); + grpc_httpcli_get( + &context, &detector.pollent, resource_quota, &request, + grpc_core::ExecCtx::Get()->Now() + max_detection_delay, + GRPC_CLOSURE_CREATE(on_metadata_server_detection_http_response, &detector, + grpc_schedule_on_exec_ctx), + &detector.response); + grpc_resource_quota_unref_internal(resource_quota); + grpc_core::ExecCtx::Get()->Flush(); + /* Block until we get the response. This is not ideal but this should only be + called once for the lifetime of the process by the default credentials. */ + gpr_mu_lock(g_polling_mu); + while (!detector.is_done) { + grpc_pollset_worker* worker = nullptr; + if (!GRPC_LOG_IF_ERROR( + "pollset_work", + grpc_pollset_work(grpc_polling_entity_pollset(&detector.pollent), + &worker, GRPC_MILLIS_INF_FUTURE))) { + detector.is_done = 1; + detector.success = 0; + } + } + gpr_mu_unlock(g_polling_mu); + grpc_httpcli_context_destroy(&context); + GRPC_CLOSURE_INIT(&destroy_closure, destroy_pollset, + grpc_polling_entity_pollset(&detector.pollent), + grpc_schedule_on_exec_ctx); + grpc_pollset_shutdown(grpc_polling_entity_pollset(&detector.pollent), + &destroy_closure); + g_polling_mu = nullptr; + grpc_core::ExecCtx::Get()->Flush(); + gpr_free(grpc_polling_entity_pollset(&detector.pollent)); + grpc_http_response_destroy(&detector.response); + return detector.success; +} /* Takes ownership of creds_path if not NULL. */ static grpc_error* create_default_creds_from_path( - char* creds_path, grpc_call_credentials** creds) { + char* creds_path, grpc_core::RefCountedPtr<grpc_call_credentials>* creds) { grpc_json* json = nullptr; grpc_auth_json_key key; grpc_auth_refresh_token token; - grpc_call_credentials* result = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> result; grpc_slice creds_data = grpc_empty_slice(); grpc_error* error = GRPC_ERROR_NONE; if (creds_path == nullptr) { @@ -176,13 +264,12 @@ end: return error; } -grpc_channel_credentials* grpc_google_default_credentials_create(void) { +grpc_channel_credentials* grpc_google_default_credentials_create() { grpc_channel_credentials* result = nullptr; - grpc_call_credentials* call_creds = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds; grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Failed to create Google credentials"); grpc_error* err; - int need_compute_engine_creds = 0; grpc_core::ExecCtx exec_ctx; GRPC_API_TRACE("grpc_google_default_credentials_create(void)", 0, ()); @@ -202,17 +289,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { error = grpc_error_add_child(error, err); gpr_mu_lock(&g_state_mu); - /* At last try to see if we're on compute engine (do the detection only once - since it requires a network test). */ - if (!g_compute_engine_detection_done) { - g_need_compute_engine_creds = g_gce_tenancy_checker(); - g_compute_engine_detection_done = 1; + + /* Try a platform-provided hint for GCE. */ + if (!g_metadata_server_available) { + g_is_on_gce = g_gce_tenancy_checker(); + g_metadata_server_available = g_is_on_gce; + } + /* TODO: Add a platform-provided hint for GAE. */ + + /* Do a network test for metadata server. */ + if (!g_metadata_server_available) { + g_metadata_server_available = is_metadata_server_reachable(); } - need_compute_engine_creds = g_need_compute_engine_creds; gpr_mu_unlock(&g_state_mu); - if (need_compute_engine_creds) { - call_creds = grpc_google_compute_engine_credentials_create(nullptr); + if (g_metadata_server_available) { + call_creds = grpc_core::RefCountedPtr<grpc_call_credentials>( + grpc_google_compute_engine_credentials_create(nullptr)); if (call_creds == nullptr) { error = grpc_error_add_child( error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( @@ -223,23 +316,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { end: if (call_creds != nullptr) { /* Create google default credentials. */ - auto creds = static_cast<grpc_google_default_channel_credentials*>( - gpr_zalloc(sizeof(grpc_google_default_channel_credentials))); - creds->base.vtable = &google_default_credentials_vtable; - creds->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT; - gpr_ref_init(&creds->base.refcount, 1); - creds->ssl_creds = + grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); - GPR_ASSERT(creds->ssl_creds != nullptr); + GPR_ASSERT(ssl_creds != nullptr); grpc_alts_credentials_options* options = grpc_alts_credentials_client_options_create(); - creds->alts_creds = grpc_alts_credentials_create(options); + grpc_channel_credentials* alts_creds = + grpc_alts_credentials_create(options); grpc_alts_credentials_options_destroy(options); - result = grpc_composite_channel_credentials_create(&creds->base, call_creds, - nullptr); + auto creds = + grpc_core::MakeRefCounted<grpc_google_default_channel_credentials>( + alts_creds != nullptr ? alts_creds->Ref() : nullptr, + ssl_creds != nullptr ? ssl_creds->Ref() : nullptr); + if (ssl_creds) ssl_creds->Unref(); + if (alts_creds) alts_creds->Unref(); + result = grpc_composite_channel_credentials_create( + creds.get(), call_creds.get(), nullptr); GPR_ASSERT(result != nullptr); - grpc_channel_credentials_unref(&creds->base); - grpc_call_credentials_unref(call_creds); } else { gpr_log(GPR_ERROR, "Could not create google default credentials: %s", grpc_error_string(error)); @@ -259,7 +352,7 @@ void grpc_flush_cached_google_default_credentials(void) { grpc_core::ExecCtx exec_ctx; gpr_once_init(&g_once, init_default_credentials); gpr_mu_lock(&g_state_mu); - g_compute_engine_detection_done = 0; + g_metadata_server_available = 0; gpr_mu_unlock(&g_state_mu); } diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.h b/src/core/lib/security/credentials/google_default/google_default_credentials.h index b9e2efb04f..bf00f7285a 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.h +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.h @@ -21,6 +21,7 @@ #include <grpc/support/port_platform.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" #define GRPC_GOOGLE_CLOUD_SDK_CONFIG_DIRECTORY "gcloud" @@ -39,11 +40,33 @@ "/" GRPC_GOOGLE_WELL_KNOWN_CREDENTIALS_FILE #endif -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* alts_creds; - grpc_channel_credentials* ssl_creds; -} grpc_google_default_channel_credentials; +class grpc_google_default_channel_credentials + : public grpc_channel_credentials { + public: + grpc_google_default_channel_credentials( + grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds, + grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT), + alts_creds_(std::move(alts_creds)), + ssl_creds_(std::move(ssl_creds)) {} + + ~grpc_google_default_channel_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_channel_credentials* alts_creds() const { + return alts_creds_.get(); + } + const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); } + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds_; + grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds_; +}; namespace grpc_core { namespace internal { diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc index 5d92fa88c4..5cd561f676 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.cc +++ b/src/core/lib/security/credentials/iam/iam_credentials.cc @@ -22,6 +22,7 @@ #include <string.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include <grpc/support/alloc.h> @@ -29,32 +30,37 @@ #include <grpc/support/string_util.h> #include <grpc/support/sync.h> -static void iam_destruct(grpc_call_credentials* creds) { - grpc_google_iam_credentials* c = - reinterpret_cast<grpc_google_iam_credentials*>(creds); - grpc_credentials_mdelem_array_destroy(&c->md_array); +grpc_google_iam_credentials::~grpc_google_iam_credentials() { + grpc_credentials_mdelem_array_destroy(&md_array_); } -static bool iam_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_google_iam_credentials* c = - reinterpret_cast<grpc_google_iam_credentials*>(creds); - grpc_credentials_mdelem_array_append(md_array, &c->md_array); +bool grpc_google_iam_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_append(md_array, &md_array_); return true; } -static void iam_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_google_iam_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable iam_vtable = { - iam_destruct, iam_get_request_metadata, iam_cancel_get_request_metadata}; +grpc_google_iam_credentials::grpc_google_iam_credentials( + const char* token, const char* authority_selector) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM) { + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), + grpc_slice_from_copied_string(token)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); + md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), + grpc_slice_from_copied_string(authority_selector)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); +} grpc_call_credentials* grpc_google_iam_credentials_create( const char* token, const char* authority_selector, void* reserved) { @@ -66,21 +72,7 @@ grpc_call_credentials* grpc_google_iam_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(token != nullptr); GPR_ASSERT(authority_selector != nullptr); - grpc_google_iam_credentials* c = - static_cast<grpc_google_iam_credentials*>(gpr_zalloc(sizeof(*c))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_IAM; - c->base.vtable = &iam_vtable; - gpr_ref_init(&c->base.refcount, 1); - grpc_mdelem md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), - grpc_slice_from_copied_string(token)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), - grpc_slice_from_copied_string(authority_selector)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - - return &c->base; + return grpc_core::MakeRefCounted<grpc_google_iam_credentials>( + token, authority_selector) + .release(); } diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h index a45710fe0f..36f5ee8930 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.h +++ b/src/core/lib/security/credentials/iam/iam_credentials.h @@ -23,9 +23,23 @@ #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials base; - grpc_credentials_mdelem_array md_array; -} grpc_google_iam_credentials; +class grpc_google_iam_credentials : public grpc_call_credentials { + public: + grpc_google_iam_credentials(const char* token, + const char* authority_selector); + ~grpc_google_iam_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_credentials_mdelem_array md_array_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_IAM_IAM_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc index 05c08a68b0..f2591a1ea5 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.cc +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc @@ -23,6 +23,8 @@ #include <inttypes.h> #include <string.h> +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include <grpc/support/alloc.h> @@ -30,71 +32,66 @@ #include <grpc/support/string_util.h> #include <grpc/support/sync.h> -static void jwt_reset_cache(grpc_service_account_jwt_access_credentials* c) { - GRPC_MDELEM_UNREF(c->cached.jwt_md); - c->cached.jwt_md = GRPC_MDNULL; - if (c->cached.service_url != nullptr) { - gpr_free(c->cached.service_url); - c->cached.service_url = nullptr; +void grpc_service_account_jwt_access_credentials::reset_cache() { + GRPC_MDELEM_UNREF(cached_.jwt_md); + cached_.jwt_md = GRPC_MDNULL; + if (cached_.service_url != nullptr) { + gpr_free(cached_.service_url); + cached_.service_url = nullptr; } - c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); + cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); } -static void jwt_destruct(grpc_call_credentials* creds) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds); - grpc_auth_json_key_destruct(&c->key); - jwt_reset_cache(c); - gpr_mu_destroy(&c->cache_mu); +grpc_service_account_jwt_access_credentials:: + ~grpc_service_account_jwt_access_credentials() { + grpc_auth_json_key_destruct(&key_); + reset_cache(); + gpr_mu_destroy(&cache_mu_); } -static bool jwt_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds); +bool grpc_service_account_jwt_access_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { gpr_timespec refresh_threshold = gpr_time_from_seconds( GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN); /* See if we can return a cached jwt. */ grpc_mdelem jwt_md = GRPC_MDNULL; { - gpr_mu_lock(&c->cache_mu); - if (c->cached.service_url != nullptr && - strcmp(c->cached.service_url, context.service_url) == 0 && - !GRPC_MDISNULL(c->cached.jwt_md) && - (gpr_time_cmp(gpr_time_sub(c->cached.jwt_expiration, - gpr_now(GPR_CLOCK_REALTIME)), - refresh_threshold) > 0)) { - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + gpr_mu_lock(&cache_mu_); + if (cached_.service_url != nullptr && + strcmp(cached_.service_url, context.service_url) == 0 && + !GRPC_MDISNULL(cached_.jwt_md) && + (gpr_time_cmp( + gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)), + refresh_threshold) > 0)) { + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (GRPC_MDISNULL(jwt_md)) { char* jwt = nullptr; /* Generate a new jwt. */ - gpr_mu_lock(&c->cache_mu); - jwt_reset_cache(c); - jwt = grpc_jwt_encode_and_sign(&c->key, context.service_url, - c->jwt_lifetime, nullptr); + gpr_mu_lock(&cache_mu_); + reset_cache(); + jwt = grpc_jwt_encode_and_sign(&key_, context.service_url, jwt_lifetime_, + nullptr); if (jwt != nullptr) { char* md_value; gpr_asprintf(&md_value, "Bearer %s", jwt); gpr_free(jwt); - c->cached.jwt_expiration = - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c->jwt_lifetime); - c->cached.service_url = gpr_strdup(context.service_url); - c->cached.jwt_md = grpc_mdelem_from_slices( + cached_.jwt_expiration = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_); + cached_.service_url = gpr_strdup(context.service_url); + cached_.jwt_md = grpc_mdelem_from_slices( grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), grpc_slice_from_copied_string(md_value)); gpr_free(md_value); - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (!GRPC_MDISNULL(jwt_md)) { @@ -106,29 +103,15 @@ static bool jwt_get_request_metadata(grpc_call_credentials* creds, return true; } -static void jwt_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable jwt_vtable = { - jwt_destruct, jwt_get_request_metadata, jwt_cancel_get_request_metadata}; - -grpc_call_credentials* -grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key key, gpr_timespec token_lifetime) { - grpc_service_account_jwt_access_credentials* c; - if (!grpc_auth_json_key_is_valid(&key)) { - gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); - return nullptr; - } - c = static_cast<grpc_service_account_jwt_access_credentials*>( - gpr_zalloc(sizeof(grpc_service_account_jwt_access_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_JWT; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &jwt_vtable; - c->key = key; +grpc_service_account_jwt_access_credentials:: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) { gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime(); if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) { gpr_log(GPR_INFO, @@ -136,10 +119,20 @@ grpc_service_account_jwt_access_credentials_create_from_auth_json_key( static_cast<int>(max_token_lifetime.tv_sec)); token_lifetime = grpc_max_auth_token_lifetime(); } - c->jwt_lifetime = token_lifetime; - gpr_mu_init(&c->cache_mu); - jwt_reset_cache(c); - return &c->base; + jwt_lifetime_ = token_lifetime; + gpr_mu_init(&cache_mu_); + reset_cache(); +} + +grpc_core::RefCountedPtr<grpc_call_credentials> +grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key key, gpr_timespec token_lifetime) { + if (!grpc_auth_json_key_is_valid(&key)) { + gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); + return nullptr; + } + return grpc_core::MakeRefCounted<grpc_service_account_jwt_access_credentials>( + key, token_lifetime); } static char* redact_private_key(const char* json_key) { @@ -182,9 +175,7 @@ grpc_call_credentials* grpc_service_account_jwt_access_credentials_create( } GPR_ASSERT(reserved == nullptr); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials* creds = - grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key_create_from_string(json_key), token_lifetime); - - return creds; + return grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key_create_from_string(json_key), token_lifetime) + .release(); } diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h index 5c3d34aa56..5af909f44d 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.h +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h @@ -24,25 +24,44 @@ #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/jwt/json_token.h" -typedef struct { - grpc_call_credentials base; +class grpc_service_account_jwt_access_credentials + : public grpc_call_credentials { + public: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime); + ~grpc_service_account_jwt_access_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; } + const grpc_auth_json_key& key() const { return key_; } + + private: + void reset_cache(); // Have a simple cache for now with just 1 entry. We could have a map based on // the service_url for a more sophisticated one. - gpr_mu cache_mu; + gpr_mu cache_mu_; struct { - grpc_mdelem jwt_md; - char* service_url; + grpc_mdelem jwt_md = GRPC_MDNULL; + char* service_url = nullptr; gpr_timespec jwt_expiration; - } cached; + } cached_; - grpc_auth_json_key key; - gpr_timespec jwt_lifetime; -} grpc_service_account_jwt_access_credentials; + grpc_auth_json_key key_; + gpr_timespec jwt_lifetime_; +}; // Private constructor for jwt credentials from an already parsed json key. // Takes ownership of the key. -grpc_call_credentials* +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_service_account_jwt_access_credentials_create_from_auth_json_key( grpc_auth_json_key key, gpr_timespec token_lifetime); diff --git a/src/core/lib/security/credentials/jwt/jwt_verifier.cc b/src/core/lib/security/credentials/jwt/jwt_verifier.cc index c7d1b36ff0..cdef0f322a 100644 --- a/src/core/lib/security/credentials/jwt/jwt_verifier.cc +++ b/src/core/lib/security/credentials/jwt/jwt_verifier.cc @@ -31,7 +31,9 @@ #include <grpc/support/sync.h> extern "C" { +#include <openssl/bn.h> #include <openssl/pem.h> +#include <openssl/rsa.h> } #include "src/core/lib/gpr/string.h" diff --git a/src/core/lib/security/credentials/local/local_credentials.cc b/src/core/lib/security/credentials/local/local_credentials.cc index 3ccfa2b908..6f6f95a34a 100644 --- a/src/core/lib/security/credentials/local/local_credentials.cc +++ b/src/core/lib/security/credentials/local/local_credentials.cc @@ -29,49 +29,36 @@ #define GRPC_CREDENTIALS_TYPE_LOCAL "Local" -static void local_credentials_destruct(grpc_channel_credentials* creds) {} - -static void local_server_credentials_destruct(grpc_server_credentials* creds) {} - -static grpc_security_status local_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_local_channel_security_connector_create( - creds, request_metadata_creds, args, target_name, sc); + this->Ref(), std::move(request_metadata_creds), args, target_name); } -static grpc_security_status local_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_local_server_security_connector_create(creds, sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_credentials::create_security_connector() { + return grpc_local_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable local_credentials_vtable = { - local_credentials_destruct, local_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable local_server_credentials_vtable = { - local_server_credentials_destruct, local_server_create_security_connector}; +grpc_local_credentials::grpc_local_credentials( + grpc_local_connect_type connect_type) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} grpc_channel_credentials* grpc_local_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast<grpc_local_credentials*>( - gpr_zalloc(sizeof(grpc_local_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_local_credentials>(connect_type); } +grpc_local_server_credentials::grpc_local_server_credentials( + grpc_local_connect_type connect_type) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} + grpc_server_credentials* grpc_local_server_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast<grpc_local_server_credentials*>( - gpr_zalloc(sizeof(grpc_local_server_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New<grpc_local_server_credentials>(connect_type); } diff --git a/src/core/lib/security/credentials/local/local_credentials.h b/src/core/lib/security/credentials/local/local_credentials.h index 47358b04bc..60a8a4f64c 100644 --- a/src/core/lib/security/credentials/local/local_credentials.h +++ b/src/core/lib/security/credentials/local/local_credentials.h @@ -25,16 +25,37 @@ #include "src/core/lib/security/credentials/credentials.h" -/* Main struct for grpc local channel credential. */ -typedef struct grpc_local_credentials { - grpc_channel_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_credentials; - -/* Main struct for grpc local server credential. */ -typedef struct grpc_local_server_credentials { - grpc_server_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_server_credentials; +/* Main class for grpc local channel credential. */ +class grpc_local_credentials final : public grpc_channel_credentials { + public: + explicit grpc_local_credentials(grpc_local_connect_type connect_type); + ~grpc_local_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; + +/* Main class for grpc local server credential. */ +class grpc_local_server_credentials final : public grpc_server_credentials { + public: + explicit grpc_local_server_credentials(grpc_local_connect_type connect_type); + ~grpc_local_server_credentials() override = default; + + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_LOCAL_LOCAL_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc index 44b093557f..ad63b01e75 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc @@ -22,6 +22,7 @@ #include <string.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/util/json_util.h" #include "src/core/lib/surface/api_trace.h" @@ -105,13 +106,12 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) { // Oauth2 Token Fetcher credentials. // -static void oauth2_token_fetcher_destruct(grpc_call_credentials* creds) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); - GRPC_MDELEM_UNREF(c->access_token_md); - gpr_mu_destroy(&c->mu); - grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&c->pollent)); - grpc_httpcli_context_destroy(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials:: + ~grpc_oauth2_token_fetcher_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); + gpr_mu_destroy(&mu_); + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); + grpc_httpcli_context_destroy(&httpcli_context_); } grpc_credentials_status @@ -209,25 +209,29 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, grpc_credentials_metadata_request* r = static_cast<grpc_credentials_metadata_request*>(user_data); grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds); + reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds.get()); + c->on_http_response(r, error); +} + +void grpc_oauth2_token_fetcher_credentials::on_http_response( + grpc_credentials_metadata_request* r, grpc_error* error) { grpc_mdelem access_token_md = GRPC_MDNULL; grpc_millis token_lifetime; grpc_credentials_status status = grpc_oauth2_token_fetcher_credentials_parse_server_response( &r->response, &access_token_md, &token_lifetime); // Update cache and grab list of pending requests. - gpr_mu_lock(&c->mu); - c->token_fetch_pending = false; - c->access_token_md = GRPC_MDELEM_REF(access_token_md); - c->token_expiration = + gpr_mu_lock(&mu_); + token_fetch_pending_ = false; + access_token_md_ = GRPC_MDELEM_REF(access_token_md); + token_expiration_ = status == GRPC_CREDENTIALS_OK ? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_millis(token_lifetime, GPR_TIMESPAN)) : gpr_inf_past(GPR_CLOCK_MONOTONIC); - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; - c->pending_requests = nullptr; - gpr_mu_unlock(&c->mu); + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; + pending_requests_ = nullptr; + gpr_mu_unlock(&mu_); // Invoke callbacks for all pending requests. while (pending_request != nullptr) { if (status == GRPC_CREDENTIALS_OK) { @@ -239,42 +243,40 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, } GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, error); grpc_polling_entity_del_from_pollset_set( - pending_request->pollent, grpc_polling_entity_pollset_set(&c->pollent)); + pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_)); grpc_oauth2_pending_get_request_metadata* prev = pending_request; pending_request = pending_request->next; gpr_free(prev); } GRPC_MDELEM_UNREF(access_token_md); - grpc_call_credentials_unref(r->creds); + Unref(); grpc_credentials_metadata_request_destroy(r); } -static bool oauth2_token_fetcher_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); +bool grpc_oauth2_token_fetcher_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { // Check if we can use the cached token. grpc_millis refresh_threshold = GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC; grpc_mdelem cached_access_token_md = GRPC_MDNULL; - gpr_mu_lock(&c->mu); - if (!GRPC_MDISNULL(c->access_token_md) && + gpr_mu_lock(&mu_); + if (!GRPC_MDISNULL(access_token_md_) && gpr_time_cmp( - gpr_time_sub(c->token_expiration, gpr_now(GPR_CLOCK_MONOTONIC)), + gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)), gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN)) > 0) { - cached_access_token_md = GRPC_MDELEM_REF(c->access_token_md); + cached_access_token_md = GRPC_MDELEM_REF(access_token_md_); } if (!GRPC_MDISNULL(cached_access_token_md)) { - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); grpc_credentials_mdelem_array_add(md_array, cached_access_token_md); GRPC_MDELEM_UNREF(cached_access_token_md); return true; } // Couldn't get the token from the cache. - // Add request to c->pending_requests and start a new fetch if needed. + // Add request to pending_requests_ and start a new fetch if needed. grpc_oauth2_pending_get_request_metadata* pending_request = static_cast<grpc_oauth2_pending_get_request_metadata*>( gpr_malloc(sizeof(*pending_request))); @@ -282,41 +284,37 @@ static bool oauth2_token_fetcher_get_request_metadata( pending_request->on_request_metadata = on_request_metadata; pending_request->pollent = pollent; grpc_polling_entity_add_to_pollset_set( - pollent, grpc_polling_entity_pollset_set(&c->pollent)); - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; + pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_request->next = pending_requests_; + pending_requests_ = pending_request; bool start_fetch = false; - if (!c->token_fetch_pending) { - c->token_fetch_pending = true; + if (!token_fetch_pending_) { + token_fetch_pending_ = true; start_fetch = true; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); if (start_fetch) { - grpc_call_credentials_ref(creds); - c->fetch_func(grpc_credentials_metadata_request_create(creds), - &c->httpcli_context, &c->pollent, - on_oauth2_token_fetcher_http_response, - grpc_core::ExecCtx::Get()->Now() + refresh_threshold); + Ref().release(); + fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()), + &httpcli_context_, &pollent_, + on_oauth2_token_fetcher_http_response, + grpc_core::ExecCtx::Get()->Now() + refresh_threshold); } return false; } -static void oauth2_token_fetcher_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds); - gpr_mu_lock(&c->mu); +void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); grpc_oauth2_pending_get_request_metadata* prev = nullptr; - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; while (pending_request != nullptr) { if (pending_request->md_array == md_array) { // Remove matching pending request from the list. if (prev != nullptr) { prev->next = pending_request->next; } else { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } // Invoke the callback immediately with an error. GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, @@ -327,96 +325,89 @@ static void oauth2_token_fetcher_cancel_get_request_metadata( prev = pending_request; pending_request = pending_request->next; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials* c, - grpc_fetch_oauth2_func fetch_func) { - memset(c, 0, sizeof(grpc_oauth2_token_fetcher_credentials)); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - gpr_ref_init(&c->base.refcount, 1); - gpr_mu_init(&c->mu); - c->token_expiration = gpr_inf_past(GPR_CLOCK_MONOTONIC); - c->fetch_func = fetch_func; - c->pollent = - grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create()); - grpc_httpcli_context_init(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials() + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)), + pollent_(grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set_create())) { + gpr_mu_init(&mu_); + grpc_httpcli_context_init(&httpcli_context_); } // // Google Compute Engine credentials. // -static grpc_call_credentials_vtable compute_engine_vtable = { - oauth2_token_fetcher_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; +namespace { + +class grpc_compute_engine_token_fetcher_credentials + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_compute_engine_token_fetcher_credentials() = default; + ~grpc_compute_engine_token_fetcher_credentials() override = default; + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* http_context, + grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, + grpc_millis deadline) override { + grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; + request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; + request.http.hdr_count = 1; + request.http.hdrs = &header; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. */ + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("oauth2_credentials"); + grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline, + GRPC_CLOSURE_CREATE(response_cb, metadata_req, + grpc_schedule_on_exec_ctx), + &metadata_req->response); + grpc_resource_quota_unref_internal(resource_quota); + } +}; -static void compute_engine_fetch_oauth2( - grpc_credentials_metadata_request* metadata_req, - grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, - grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; - grpc_httpcli_request request; - memset(&request, 0, sizeof(grpc_httpcli_request)); - request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; - request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; - request.http.hdr_count = 1; - request.http.hdrs = &header; - /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host - channel. This would allow us to cancel an authentication query when under - extreme memory pressure. */ - grpc_resource_quota* resource_quota = - grpc_resource_quota_create("oauth2_credentials"); - grpc_httpcli_get( - httpcli_context, pollent, resource_quota, &request, deadline, - GRPC_CLOSURE_CREATE(response_cb, metadata_req, grpc_schedule_on_exec_ctx), - &metadata_req->response); - grpc_resource_quota_unref_internal(resource_quota); -} +} // namespace grpc_call_credentials* grpc_google_compute_engine_credentials_create( void* reserved) { - grpc_oauth2_token_fetcher_credentials* c = - static_cast<grpc_oauth2_token_fetcher_credentials*>( - gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials))); GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - init_oauth2_token_fetcher(c, compute_engine_fetch_oauth2); - c->base.vtable = &compute_engine_vtable; - return &c->base; + return grpc_core::MakeRefCounted< + grpc_compute_engine_token_fetcher_credentials>() + .release(); } // // Google Refresh Token credentials. // -static void refresh_token_destruct(grpc_call_credentials* creds) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast<grpc_google_refresh_token_credentials*>(creds); - grpc_auth_refresh_token_destruct(&c->refresh_token); - oauth2_token_fetcher_destruct(&c->base.base); +grpc_google_refresh_token_credentials:: + ~grpc_google_refresh_token_credentials() { + grpc_auth_refresh_token_destruct(&refresh_token_); } -static grpc_call_credentials_vtable refresh_token_vtable = { - refresh_token_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; - -static void refresh_token_fetch_oauth2( +void grpc_google_refresh_token_credentials::fetch_oauth2( grpc_credentials_metadata_request* metadata_req, grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast<grpc_google_refresh_token_credentials*>( - metadata_req->creds); grpc_http_header header = {(char*)"Content-Type", (char*)"application/x-www-form-urlencoded"}; grpc_httpcli_request request; char* body = nullptr; gpr_asprintf(&body, GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING, - c->refresh_token.client_id, c->refresh_token.client_secret, - c->refresh_token.refresh_token); + refresh_token_.client_id, refresh_token_.client_secret, + refresh_token_.refresh_token); memset(&request, 0, sizeof(grpc_httpcli_request)); request.host = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_HOST; request.http.path = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH; @@ -437,20 +428,19 @@ static void refresh_token_fetch_oauth2( gpr_free(body); } -grpc_call_credentials* +grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials( + grpc_auth_refresh_token refresh_token) + : refresh_token_(refresh_token) {} + +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token refresh_token) { - grpc_google_refresh_token_credentials* c; if (!grpc_auth_refresh_token_is_valid(&refresh_token)) { gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation"); return nullptr; } - c = static_cast<grpc_google_refresh_token_credentials*>( - gpr_zalloc(sizeof(grpc_google_refresh_token_credentials))); - init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2); - c->base.base.vtable = &refresh_token_vtable; - c->refresh_token = refresh_token; - return &c->base.base; + return grpc_core::MakeRefCounted<grpc_google_refresh_token_credentials>( + refresh_token); } static char* create_loggable_refresh_token(grpc_auth_refresh_token* token) { @@ -478,59 +468,50 @@ grpc_call_credentials* grpc_google_refresh_token_credentials_create( gpr_free(loggable_token); } GPR_ASSERT(reserved == nullptr); - return grpc_refresh_token_credentials_create_from_auth_refresh_token(token); + return grpc_refresh_token_credentials_create_from_auth_refresh_token(token) + .release(); } // // Oauth2 Access Token credentials. // -static void access_token_destruct(grpc_call_credentials* creds) { - grpc_access_token_credentials* c = - reinterpret_cast<grpc_access_token_credentials*>(creds); - GRPC_MDELEM_UNREF(c->access_token_md); +grpc_access_token_credentials::~grpc_access_token_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); } -static bool access_token_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_access_token_credentials* c = - reinterpret_cast<grpc_access_token_credentials*>(creds); - grpc_credentials_mdelem_array_add(md_array, c->access_token_md); +bool grpc_access_token_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, access_token_md_); return true; } -static void access_token_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_access_token_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable access_token_vtable = { - access_token_destruct, access_token_get_request_metadata, - access_token_cancel_get_request_metadata}; +grpc_access_token_credentials::grpc_access_token_credentials( + const char* access_token) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) { + char* token_md_value; + gpr_asprintf(&token_md_value, "Bearer %s", access_token); + grpc_core::ExecCtx exec_ctx; + access_token_md_ = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_copied_string(token_md_value)); + gpr_free(token_md_value); +} grpc_call_credentials* grpc_access_token_credentials_create( const char* access_token, void* reserved) { - grpc_access_token_credentials* c = - static_cast<grpc_access_token_credentials*>( - gpr_zalloc(sizeof(grpc_access_token_credentials))); GRPC_API_TRACE( "grpc_access_token_credentials_create(access_token=<redacted>, " "reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &access_token_vtable; - gpr_ref_init(&c->base.refcount, 1); - char* token_md_value; - gpr_asprintf(&token_md_value, "Bearer %s", access_token); - grpc_core::ExecCtx exec_ctx; - c->access_token_md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), - grpc_slice_from_copied_string(token_md_value)); - - gpr_free(token_md_value); - return &c->base; + return grpc_core::MakeRefCounted<grpc_access_token_credentials>(access_token) + .release(); } diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h index 12a1d4484f..510a78b484 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h @@ -54,46 +54,91 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token); // This object is a base for credentials that need to acquire an oauth2 token // from an http service. -typedef void (*grpc_fetch_oauth2_func)(grpc_credentials_metadata_request* req, - grpc_httpcli_context* http_context, - grpc_polling_entity* pollent, - grpc_iomgr_cb_func cb, - grpc_millis deadline); - -typedef struct grpc_oauth2_pending_get_request_metadata { +struct grpc_oauth2_pending_get_request_metadata { grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_polling_entity* pollent; struct grpc_oauth2_pending_get_request_metadata* next; -} grpc_oauth2_pending_get_request_metadata; - -typedef struct { - grpc_call_credentials base; - gpr_mu mu; - grpc_mdelem access_token_md; - gpr_timespec token_expiration; - bool token_fetch_pending; - grpc_oauth2_pending_get_request_metadata* pending_requests; - grpc_httpcli_context httpcli_context; - grpc_fetch_oauth2_func fetch_func; - grpc_polling_entity pollent; -} grpc_oauth2_token_fetcher_credentials; +}; + +class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials { + public: + grpc_oauth2_token_fetcher_credentials(); + ~grpc_oauth2_token_fetcher_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + void on_http_response(grpc_credentials_metadata_request* r, + grpc_error* error); + + GRPC_ABSTRACT_BASE_CLASS + + protected: + virtual void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) GRPC_ABSTRACT; + + private: + gpr_mu mu_; + grpc_mdelem access_token_md_ = GRPC_MDNULL; + gpr_timespec token_expiration_; + bool token_fetch_pending_ = false; + grpc_oauth2_pending_get_request_metadata* pending_requests_ = nullptr; + grpc_httpcli_context httpcli_context_; + grpc_polling_entity pollent_; +}; // Google refresh token credentials. -typedef struct { - grpc_oauth2_token_fetcher_credentials base; - grpc_auth_refresh_token refresh_token; -} grpc_google_refresh_token_credentials; +class grpc_google_refresh_token_credentials final + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_google_refresh_token_credentials(grpc_auth_refresh_token refresh_token); + ~grpc_google_refresh_token_credentials() override; + + const grpc_auth_refresh_token& refresh_token() const { + return refresh_token_; + } + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) override; + + private: + grpc_auth_refresh_token refresh_token_; +}; // Access token credentials. -typedef struct { - grpc_call_credentials base; - grpc_mdelem access_token_md; -} grpc_access_token_credentials; +class grpc_access_token_credentials final : public grpc_call_credentials { + public: + grpc_access_token_credentials(const char* access_token); + ~grpc_access_token_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem access_token_md_; +}; // Private constructor for refresh token credentials from an already parsed // refresh token. Takes ownership of the refresh token. -grpc_call_credentials* +grpc_core::RefCountedPtr<grpc_call_credentials> grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token token); diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc index 4015124298..52982fdb8f 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.cc +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc @@ -35,20 +35,17 @@ grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials"); -static void plugin_destruct(grpc_call_credentials* creds) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); - gpr_mu_destroy(&c->mu); - if (c->plugin.state != nullptr && c->plugin.destroy != nullptr) { - c->plugin.destroy(c->plugin.state); +grpc_plugin_credentials::~grpc_plugin_credentials() { + gpr_mu_destroy(&mu_); + if (plugin_.state != nullptr && plugin_.destroy != nullptr) { + plugin_.destroy(plugin_.state); } } -static void pending_request_remove_locked( - grpc_plugin_credentials* c, - grpc_plugin_credentials_pending_request* pending_request) { +void grpc_plugin_credentials::pending_request_remove_locked( + pending_request* pending_request) { if (pending_request->prev == nullptr) { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } else { pending_request->prev->next = pending_request->next; } @@ -62,17 +59,17 @@ static void pending_request_remove_locked( // cancelled out from under us. // When this returns, r->cancelled indicates whether the request was // cancelled before completion. -static void pending_request_complete( - grpc_plugin_credentials_pending_request* r) { - gpr_mu_lock(&r->creds->mu); - if (!r->cancelled) pending_request_remove_locked(r->creds, r); - gpr_mu_unlock(&r->creds->mu); +void grpc_plugin_credentials::pending_request_complete(pending_request* r) { + GPR_DEBUG_ASSERT(r->creds == this); + gpr_mu_lock(&mu_); + if (!r->cancelled) pending_request_remove_locked(r); + gpr_mu_unlock(&mu_); // Ref to credentials not needed anymore. - grpc_call_credentials_unref(&r->creds->base); + Unref(); } static grpc_error* process_plugin_result( - grpc_plugin_credentials_pending_request* r, const grpc_metadata* md, + grpc_plugin_credentials::pending_request* r, const grpc_metadata* md, size_t num_md, grpc_status_code status, const char* error_details) { grpc_error* error = GRPC_ERROR_NONE; if (status != GRPC_STATUS_OK) { @@ -119,8 +116,8 @@ static void plugin_md_request_metadata_ready(void* request, /* called from application code */ grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED | GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP); - grpc_plugin_credentials_pending_request* r = - static_cast<grpc_plugin_credentials_pending_request*>(request); + grpc_plugin_credentials::pending_request* r = + static_cast<grpc_plugin_credentials::pending_request*>(request); if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " @@ -128,7 +125,7 @@ static void plugin_md_request_metadata_ready(void* request, r->creds, r); } // Remove request from pending list if not previously cancelled. - pending_request_complete(r); + r->creds->pending_request_complete(r); // If it has not been cancelled, process it. if (!r->cancelled) { grpc_error* error = @@ -143,65 +140,59 @@ static void plugin_md_request_metadata_ready(void* request, gpr_free(r); } -static bool plugin_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); +bool grpc_plugin_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { bool retval = true; // Synchronous return. - if (c->plugin.get_metadata != nullptr) { + if (plugin_.get_metadata != nullptr) { // Create pending_request object. - grpc_plugin_credentials_pending_request* pending_request = - static_cast<grpc_plugin_credentials_pending_request*>( - gpr_zalloc(sizeof(*pending_request))); - pending_request->creds = c; - pending_request->md_array = md_array; - pending_request->on_request_metadata = on_request_metadata; + pending_request* request = + static_cast<pending_request*>(gpr_zalloc(sizeof(*request))); + request->creds = this; + request->md_array = md_array; + request->on_request_metadata = on_request_metadata; // Add it to the pending list. - gpr_mu_lock(&c->mu); - if (c->pending_requests != nullptr) { - c->pending_requests->prev = pending_request; + gpr_mu_lock(&mu_); + if (pending_requests_ != nullptr) { + pending_requests_->prev = request; } - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; - gpr_mu_unlock(&c->mu); + request->next = pending_requests_; + pending_requests_ = request; + gpr_mu_unlock(&mu_); // Invoke the plugin. The callback holds a ref to us. if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin", - c, pending_request); + this, request); } - grpc_call_credentials_ref(creds); + Ref().release(); grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX]; size_t num_creds_md = 0; grpc_status_code status = GRPC_STATUS_OK; const char* error_details = nullptr; - if (!c->plugin.get_metadata(c->plugin.state, context, - plugin_md_request_metadata_ready, - pending_request, creds_md, &num_creds_md, - &status, &error_details)) { + if (!plugin_.get_metadata( + plugin_.state, context, plugin_md_request_metadata_ready, request, + creds_md, &num_creds_md, &status, &error_details)) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin will return " "asynchronously", - c, pending_request); + this, request); } return false; // Asynchronous return. } // Returned synchronously. // Remove request from pending list if not previously cancelled. - pending_request_complete(pending_request); + request->creds->pending_request_complete(request); // If the request was cancelled, the error will have been returned // asynchronously by plugin_cancel_get_request_metadata(), so return // false. Otherwise, process the result. - if (pending_request->cancelled) { + if (request->cancelled) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p was cancelled, error " "will be returned asynchronously", - c, pending_request); + this, request); } retval = false; } else { @@ -209,10 +200,10 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " "synchronously", - c, pending_request); + this, request); } - *error = process_plugin_result(pending_request, creds_md, num_creds_md, - status, error_details); + *error = process_plugin_result(request, creds_md, num_creds_md, status, + error_details); } // Clean up. for (size_t i = 0; i < num_creds_md; ++i) { @@ -220,51 +211,42 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, grpc_slice_unref_internal(creds_md[i].value); } gpr_free((void*)error_details); - gpr_free(pending_request); + gpr_free(request); } return retval; } -static void plugin_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_plugin_credentials* c = - reinterpret_cast<grpc_plugin_credentials*>(creds); - gpr_mu_lock(&c->mu); - for (grpc_plugin_credentials_pending_request* pending_request = - c->pending_requests; +void grpc_plugin_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); + for (pending_request* pending_request = pending_requests_; pending_request != nullptr; pending_request = pending_request->next) { if (pending_request->md_array == md_array) { if (grpc_plugin_credentials_trace.enabled()) { - gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", c, + gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this, pending_request); } pending_request->cancelled = true; GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, GRPC_ERROR_REF(error)); - pending_request_remove_locked(c, pending_request); + pending_request_remove_locked(pending_request); break; } } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable plugin_vtable = { - plugin_destruct, plugin_get_request_metadata, - plugin_cancel_get_request_metadata}; +grpc_plugin_credentials::grpc_plugin_credentials( + grpc_metadata_credentials_plugin plugin) + : grpc_call_credentials(plugin.type), plugin_(plugin) { + gpr_mu_init(&mu_); +} grpc_call_credentials* grpc_metadata_credentials_create_from_plugin( grpc_metadata_credentials_plugin plugin, void* reserved) { - grpc_plugin_credentials* c = - static_cast<grpc_plugin_credentials*>(gpr_zalloc(sizeof(*c))); GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = plugin.type; - c->base.vtable = &plugin_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->plugin = plugin; - gpr_mu_init(&c->mu); - return &c->base; + return grpc_core::New<grpc_plugin_credentials>(plugin); } diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h index caf990efa1..77a957e513 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.h +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h @@ -25,22 +25,45 @@ extern grpc_core::TraceFlag grpc_plugin_credentials_trace; -struct grpc_plugin_credentials; - -typedef struct grpc_plugin_credentials_pending_request { - bool cancelled; - struct grpc_plugin_credentials* creds; - grpc_credentials_mdelem_array* md_array; - grpc_closure* on_request_metadata; - struct grpc_plugin_credentials_pending_request* prev; - struct grpc_plugin_credentials_pending_request* next; -} grpc_plugin_credentials_pending_request; - -typedef struct grpc_plugin_credentials { - grpc_call_credentials base; - grpc_metadata_credentials_plugin plugin; - gpr_mu mu; - grpc_plugin_credentials_pending_request* pending_requests; -} grpc_plugin_credentials; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_plugin_credentials final : public grpc_call_credentials { + public: + struct pending_request { + bool cancelled; + struct grpc_plugin_credentials* creds; + grpc_credentials_mdelem_array* md_array; + grpc_closure* on_request_metadata; + struct pending_request* prev; + struct pending_request* next; + }; + + explicit grpc_plugin_credentials(grpc_metadata_credentials_plugin plugin); + ~grpc_plugin_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + // Checks if the request has been cancelled. + // If not, removes it from the pending list, so that it cannot be + // cancelled out from under us. + // When this returns, r->cancelled indicates whether the request was + // cancelled before completion. + void pending_request_complete(pending_request* r); + + private: + void pending_request_remove_locked(pending_request* pending_request); + + grpc_metadata_credentials_plugin plugin_; + gpr_mu mu_; + pending_request* pending_requests_ = nullptr; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_PLUGIN_PLUGIN_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index 3d6f2f200a..83db86f1ea 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -44,22 +44,27 @@ void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp, gpr_free(kp); } -static void ssl_destruct(grpc_channel_credentials* creds) { - grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds); - gpr_free(c->config.pem_root_certs); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pair, 1); - if (c->config.verify_options.verify_peer_destruct != nullptr) { - c->config.verify_options.verify_peer_destruct( - c->config.verify_options.verify_peer_callback_userdata); +grpc_ssl_credentials::grpc_ssl_credentials( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + build_config(pem_root_certs, pem_key_cert_pair, verify_options); +} + +grpc_ssl_credentials::~grpc_ssl_credentials() { + gpr_free(config_.pem_root_certs); + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1); + if (config_.verify_options.verify_peer_destruct != nullptr) { + config_.verify_options.verify_peer_destruct( + config_.verify_options.verify_peer_callback_userdata); } } -static grpc_security_status ssl_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_credentials::create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds); - grpc_security_status status = GRPC_SECURITY_OK; + grpc_channel_args** new_args) { const char* overridden_target_name = nullptr; tsi_ssl_session_cache* ssl_session_cache = nullptr; for (size_t i = 0; args && i < args->num_args; i++) { @@ -74,52 +79,47 @@ static grpc_security_status ssl_create_security_connector( static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p); } } - status = grpc_ssl_channel_security_connector_create( - creds, call_creds, &c->config, target, overridden_target_name, - ssl_session_cache, sc); - if (status != GRPC_SECURITY_OK) { - return status; + grpc_core::RefCountedPtr<grpc_channel_security_connector> sc = + grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name, ssl_session_cache); + if (sc == nullptr) { + return sc; } grpc_arg new_arg = grpc_channel_arg_string_create( (char*)GRPC_ARG_HTTP2_SCHEME, (char*)"https"); *new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1); - return status; + return sc; } -static grpc_channel_credentials_vtable ssl_vtable = { - ssl_destruct, ssl_create_security_connector, nullptr}; - -static void ssl_build_config(const char* pem_root_certs, - grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, - const verify_peer_options* verify_options, - grpc_ssl_config* config) { - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } +void grpc_ssl_credentials::build_config( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) { + config_.pem_root_certs = gpr_strdup(pem_root_certs); if (pem_key_cert_pair != nullptr) { GPR_ASSERT(pem_key_cert_pair->private_key != nullptr); GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr); - config->pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>( + config_.pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>( gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair))); - config->pem_key_cert_pair->cert_chain = + config_.pem_key_cert_pair->cert_chain = gpr_strdup(pem_key_cert_pair->cert_chain); - config->pem_key_cert_pair->private_key = + config_.pem_key_cert_pair->private_key = gpr_strdup(pem_key_cert_pair->private_key); + } else { + config_.pem_key_cert_pair = nullptr; } if (verify_options != nullptr) { - memcpy(&config->verify_options, verify_options, + memcpy(&config_.verify_options, verify_options, sizeof(verify_peer_options)); } else { // Otherwise set all options to default values - memset(&config->verify_options, 0, sizeof(verify_peer_options)); + memset(&config_.verify_options, 0, sizeof(verify_peer_options)); } } grpc_channel_credentials* grpc_ssl_credentials_create( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const verify_peer_options* verify_options, void* reserved) { - grpc_ssl_credentials* c = static_cast<grpc_ssl_credentials*>( - gpr_zalloc(sizeof(grpc_ssl_credentials))); GRPC_API_TRACE( "grpc_ssl_credentials_create(pem_root_certs=%s, " "pem_key_cert_pair=%p, " @@ -127,12 +127,9 @@ grpc_channel_credentials* grpc_ssl_credentials_create( "reserved=%p)", 4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - c->base.vtable = &ssl_vtable; - gpr_ref_init(&c->base.refcount, 1); - ssl_build_config(pem_root_certs, pem_key_cert_pair, verify_options, - &c->config); - return &c->base; + + return grpc_core::New<grpc_ssl_credentials>(pem_root_certs, pem_key_cert_pair, + verify_options); } // @@ -145,21 +142,29 @@ struct grpc_ssl_server_credentials_options { grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher; }; -static void ssl_server_destruct(grpc_server_credentials* creds) { - grpc_ssl_server_credentials* c = - reinterpret_cast<grpc_ssl_server_credentials*>(creds); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pairs, - c->config.num_key_cert_pairs); - gpr_free(c->config.pem_root_certs); +grpc_ssl_server_credentials::grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options) + : grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + if (options.certificate_config_fetcher != nullptr) { + config_.client_certificate_request = options.client_certificate_request; + certificate_config_fetcher_ = *options.certificate_config_fetcher; + } else { + build_config(options.certificate_config->pem_root_certs, + options.certificate_config->pem_key_cert_pairs, + options.certificate_config->num_key_cert_pairs, + options.client_certificate_request); + } } -static grpc_security_status ssl_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_ssl_server_security_connector_create(creds, sc); +grpc_ssl_server_credentials::~grpc_ssl_server_credentials() { + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs, + config_.num_key_cert_pairs); + gpr_free(config_.pem_root_certs); +} +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_credentials::create_security_connector() { + return grpc_ssl_server_security_connector_create(this->Ref()); } - -static grpc_server_credentials_vtable ssl_server_vtable = { - ssl_server_destruct, ssl_server_create_security_connector}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, @@ -179,18 +184,15 @@ tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( return tsi_pairs; } -static void ssl_build_server_config( +void grpc_ssl_server_credentials::build_config( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, - grpc_ssl_client_certificate_request_type client_certificate_request, - grpc_ssl_server_config* config) { - config->client_certificate_request = client_certificate_request; - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } - config->pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + grpc_ssl_client_certificate_request_type client_certificate_request) { + config_.client_certificate_request = client_certificate_request; + config_.pem_root_certs = gpr_strdup(pem_root_certs); + config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( pem_key_cert_pairs, num_key_cert_pairs); - config->num_key_cert_pairs = num_key_cert_pairs; + config_.num_key_cert_pairs = num_key_cert_pairs; } grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( @@ -200,9 +202,7 @@ grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( grpc_ssl_server_certificate_config* config = static_cast<grpc_ssl_server_certificate_config*>( gpr_zalloc(sizeof(grpc_ssl_server_certificate_config))); - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } + config->pem_root_certs = gpr_strdup(pem_root_certs); if (num_key_cert_pairs > 0) { GPR_ASSERT(pem_key_cert_pairs != nullptr); config->pem_key_cert_pairs = static_cast<grpc_ssl_pem_key_cert_pair*>( @@ -311,7 +311,6 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_ex( grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( grpc_ssl_server_credentials_options* options) { grpc_server_credentials* retval = nullptr; - grpc_ssl_server_credentials* c = nullptr; if (options == nullptr) { gpr_log(GPR_ERROR, @@ -331,23 +330,7 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( goto done; } - c = static_cast<grpc_ssl_server_credentials*>( - gpr_zalloc(sizeof(grpc_ssl_server_credentials))); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &ssl_server_vtable; - - if (options->certificate_config_fetcher != nullptr) { - c->config.client_certificate_request = options->client_certificate_request; - c->certificate_config_fetcher = *options->certificate_config_fetcher; - } else { - ssl_build_server_config(options->certificate_config->pem_root_certs, - options->certificate_config->pem_key_cert_pairs, - options->certificate_config->num_key_cert_pairs, - options->client_certificate_request, &c->config); - } - - retval = &c->base; + retval = grpc_core::New<grpc_ssl_server_credentials>(*options); done: grpc_ssl_server_credentials_options_destroy(options); diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 0fba413876..e1174327b3 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -24,27 +24,70 @@ #include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h" -typedef struct { - grpc_channel_credentials base; - grpc_ssl_config config; -} grpc_ssl_credentials; +class grpc_ssl_credentials : public grpc_channel_credentials { + public: + grpc_ssl_credentials(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + ~grpc_ssl_credentials() override; + + grpc_core::RefCountedPtr<grpc_channel_security_connector> + create_security_connector( + grpc_core::RefCountedPtr<grpc_call_credentials> call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + private: + void build_config(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + grpc_ssl_config config_; +}; struct grpc_ssl_server_certificate_config { - grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; }; -typedef struct { - grpc_ssl_server_certificate_config_callback cb; +struct grpc_ssl_server_certificate_config_fetcher { + grpc_ssl_server_certificate_config_callback cb = nullptr; void* user_data; -} grpc_ssl_server_certificate_config_fetcher; +}; + +class grpc_ssl_server_credentials final : public grpc_server_credentials { + public: + grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options); + ~grpc_ssl_server_credentials() override; -typedef struct { - grpc_server_credentials base; - grpc_ssl_server_config config; - grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher; -} grpc_ssl_server_credentials; + grpc_core::RefCountedPtr<grpc_server_security_connector> + create_security_connector() override; + + bool has_cert_config_fetcher() const { + return certificate_config_fetcher_.cb != nullptr; + } + + grpc_ssl_certificate_config_reload_status FetchCertConfig( + grpc_ssl_server_certificate_config** config) { + GPR_DEBUG_ASSERT(has_cert_config_fetcher()); + return certificate_config_fetcher_.cb(certificate_config_fetcher_.user_data, + config); + } + + const grpc_ssl_server_config& config() const { return config_; } + + private: + void build_config( + const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, + grpc_ssl_client_certificate_request_type client_certificate_request); + + grpc_ssl_server_config config_; + grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher_; +}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index dd71c8bc60..3ad0cc353c 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -28,6 +28,7 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/alts/alts_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" @@ -35,64 +36,9 @@ #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_alts_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_alts_server_security_connector; - -static void alts_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void alts_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void alts_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - grpc_alts_credentials* creds = - reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, c->target_name, creds->handshaker_service_url, - true, interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc); - grpc_alts_server_credentials* creds = - reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, nullptr, creds->handshaker_service_url, false, - interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_set_rpc_protocol_versions( +void alts_set_rpc_protocol_versions( grpc_gcp_rpc_protocol_versions* rpc_versions) { grpc_gcp_rpc_protocol_versions_set_max(rpc_versions, GRPC_PROTOCOL_VERSION_MAX_MAJOR, @@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions( GRPC_PROTOCOL_VERSION_MIN_MINOR); } +void alts_check_peer(tsi_peer peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = + grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer); + tsi_peer_destruct(&peer); + grpc_error* error = + *auth_context != nullptr + ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not get ALTS auth context from TSI peer"); + GRPC_CLOSURE_SCHED(on_peer_checked, error); +} + +class grpc_alts_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_alts_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(/*url_scheme=*/nullptr, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) { + grpc_alts_credentials* creds = + static_cast<grpc_alts_credentials*>(mutable_channel_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + + ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_credentials* creds = + static_cast<const grpc_alts_credentials*>(channel_creds()); + GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_, + creds->handshaker_service_url(), true, + interested_parties, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_alts_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ALTS call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + private: + char* target_name_; +}; + +class grpc_alts_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_alts_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(/*url_scheme=*/nullptr, + std::move(server_creds)) { + grpc_alts_server_credentials* creds = + reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + ~grpc_alts_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_server_credentials* creds = + static_cast<const grpc_alts_server_credentials*>(server_creds()); + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), nullptr, creds->handshaker_service_url(), + false, interested_parties, &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + alts_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } +}; +} // namespace + namespace grpc_core { namespace internal { - -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx) { - if (peer == nullptr || ctx == nullptr) { +grpc_core::RefCountedPtr<grpc_auth_context> +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { + if (peer == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to grpc_alts_auth_context_from_tsi_peer()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - *ctx = nullptr; /* Validate certificate type. */ const tsi_peer_property* cert_type_prop = tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); @@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE, cert_type_prop->value.length) != 0) { gpr_log(GPR_ERROR, "Invalid or missing certificate type property."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Validate RPC protocol versions. */ const tsi_peer_property* rpc_versions_prop = tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS); if (rpc_versions_prop == nullptr) { gpr_log(GPR_ERROR, "Missing rpc protocol versions property."); - return GRPC_SECURITY_ERROR; + return nullptr; } grpc_gcp_rpc_protocol_versions local_versions, peer_versions; alts_set_rpc_protocol_versions(&local_versions); @@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( grpc_slice_unref_internal(slice); if (!decode_result) { gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* TODO: Pass highest common rpc protocol version to grpc caller. */ bool check_result = grpc_gcp_rpc_protocol_versions_check( &local_versions, &peer_versions, nullptr); if (!check_result) { gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + auto ctx = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_ALTS_TRANSPORT_SECURITY_TYPE); size_t i = 0; for (i = 0; i < peer->property_count; i++) { @@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( /* Add service account to auth context. */ if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) { grpc_auth_context_add_property( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data, - tsi_prop->value.length); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + tsi_prop->value.data, tsi_prop->value.length); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); } } - if (!grpc_auth_context_peer_is_authenticated(*ctx)) { + if (!grpc_auth_context_peer_is_authenticated(ctx.get())) { gpr_log(GPR_ERROR, "Invalid unauthenticated peer."); - GRPC_AUTH_CONTEXT_UNREF(*ctx, "test"); - *ctx = nullptr; - return GRPC_SECURITY_ERROR; + ctx.reset(DEBUG_LOCATION, "test"); + return nullptr; } - return GRPC_SECURITY_OK; + return ctx; } } // namespace internal } // namespace grpc_core -static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; - status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer( - &peer, auth_context); - tsi_peer_destruct(&peer); - grpc_error* error = - status == GRPC_SECURITY_OK - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not get ALTS auth context from TSI peer"); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} - -static int alts_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_channel_security_connector* c1 = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc1); - grpc_alts_channel_security_connector* c2 = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int alts_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_server_security_connector* c1 = - reinterpret_cast<grpc_alts_server_security_connector*>(sc1); - grpc_alts_server_security_connector* c2 = - reinterpret_cast<grpc_alts_server_security_connector*>(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_connector_vtable alts_channel_vtable = { - alts_channel_destroy, alts_check_peer, alts_channel_cmp}; - -static grpc_security_connector_vtable alts_server_vtable = { - alts_server_destroy, alts_check_peer, alts_server_cmp}; - -static bool alts_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_alts_channel_security_connector* alts_sc = - reinterpret_cast<grpc_alts_channel_security_connector*>(sc); - if (host == nullptr || alts_sc == nullptr || - strcmp(host, alts_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "ALTS call host does not match target name"); - } - return true; -} - -static void alts_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_alts_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_alts_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_channel_vtable; - c->base.add_handshakers = alts_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = alts_check_call_host; - c->base.cancel_check_call_host = alts_cancel_check_call_host; - grpc_alts_credentials* creds = - reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_alts_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), target_name); } -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_alts_server_security_connector*>( - gpr_zalloc(sizeof(grpc_alts_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.add_handshakers = alts_server_add_handshakers; - grpc_alts_server_credentials* creds = - reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_alts_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.h b/src/core/lib/security/security_connector/alts/alts_security_connector.h index d2e057a76a..b96dc36b30 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.h +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.h @@ -36,12 +36,13 @@ * - sc: address of ALTS channel security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name); /** * This method creates an ALTS server security connector. @@ -50,17 +51,18 @@ grpc_security_status grpc_alts_channel_security_connector_create( * - sc: address of ALTS server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); namespace grpc_core { namespace internal { /* Exposed only for testing. */ -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx); +grpc_core::RefCountedPtr<grpc_auth_context> +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer); } // namespace internal } // namespace grpc_core diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc index 5c0c89b88f..e3b8affb36 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -31,6 +31,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/fake/fake_credentials.h" @@ -38,91 +39,183 @@ #include "src/core/lib/security/transport/target_authority_table.h" #include "src/core/tsi/fake_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target; - char* expected_targets; - bool is_lb_channel; - char* target_name_override; -} grpc_fake_channel_security_connector; +namespace { +class grpc_fake_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_fake_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args) + : grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_(gpr_strdup(target)), + expected_targets_( + gpr_strdup(grpc_fake_transport_get_expected_targets(args))), + is_lb_channel_(grpc_core::FindTargetAuthorityTableInArgs(args) != + nullptr) { + const grpc_arg* target_name_override_arg = + grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); + if (target_name_override_arg != nullptr) { + target_name_override_ = + gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + } else { + target_name_override_ = nullptr; + } + } -static void fake_channel_destroy(grpc_security_connector* sc) { - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - gpr_free(c->target); - gpr_free(c->expected_targets); - gpr_free(c->target_name_override); - gpr_free(c); -} + ~grpc_fake_channel_security_connector() override { + gpr_free(target_); + gpr_free(expected_targets_); + if (target_name_override_ != nullptr) gpr_free(target_name_override_); + } -static void fake_server_destroy(grpc_security_connector* sc) { gpr_free(sc); } + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override; -static bool fake_check_target(const char* target_type, const char* target, - const char* set_str) { - GPR_ASSERT(target_type != nullptr); - GPR_ASSERT(target != nullptr); - char** set = nullptr; - size_t set_size = 0; - gpr_string_split(set_str, ",", &set, &set_size); - bool found = false; - for (size_t i = 0; i < set_size; ++i) { - if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_fake_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_, other->target_); + if (c != 0) return c; + if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) { + c = GPR_ICMP(expected_targets_, other->expected_targets_); + } else { + c = strcmp(expected_targets_, other->expected_targets_); + } + if (c != 0) return c; + return GPR_ICMP(is_lb_channel_, other->is_lb_channel_); } - for (size_t i = 0; i < set_size; ++i) { - gpr_free(set[i]); + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*is_client=*/true), this)); } - gpr_free(set); - return found; -} -static void fake_secure_name_check(const char* target, - const char* expected_targets, - bool is_lb_channel) { - if (expected_targets == nullptr) return; - char** lbs_and_backends = nullptr; - size_t lbs_and_backends_size = 0; - bool success = false; - gpr_string_split(expected_targets, ";", &lbs_and_backends, - &lbs_and_backends_size); - if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { - gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", - expected_targets); - goto done; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + char* authority_hostname = nullptr; + char* authority_ignored_port = nullptr; + char* target_hostname = nullptr; + char* target_ignored_port = nullptr; + gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); + gpr_split_host_port(target_, &target_hostname, &target_ignored_port); + if (target_name_override_ != nullptr) { + char* fake_security_target_name_override_hostname = nullptr; + char* fake_security_target_name_override_ignored_port = nullptr; + gpr_split_host_port(target_name_override_, + &fake_security_target_name_override_hostname, + &fake_security_target_name_override_ignored_port); + if (strcmp(authority_hostname, + fake_security_target_name_override_hostname) != 0) { + gpr_log(GPR_ERROR, + "Authority (host) '%s' != Fake Security Target override '%s'", + host, fake_security_target_name_override_hostname); + abort(); + } + gpr_free(fake_security_target_name_override_hostname); + gpr_free(fake_security_target_name_override_ignored_port); + } else if (strcmp(authority_hostname, target_hostname) != 0) { + gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", + authority_hostname, target_hostname); + abort(); + } + gpr_free(authority_hostname); + gpr_free(authority_ignored_port); + gpr_free(target_hostname); + gpr_free(target_ignored_port); + return true; } - if (is_lb_channel) { - if (lbs_and_backends_size != 2) { - gpr_log(GPR_ERROR, - "Invalid expected targets arg value: '%s'. Expectations for LB " - "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", - expected_targets); - goto done; + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + char* target() const { return target_; } + char* expected_targets() const { return expected_targets_; } + bool is_lb_channel() const { return is_lb_channel_; } + char* target_name_override() const { return target_name_override_; } + + private: + bool fake_check_target(const char* target_type, const char* target, + const char* set_str) const { + GPR_ASSERT(target_type != nullptr); + GPR_ASSERT(target != nullptr); + char** set = nullptr; + size_t set_size = 0; + gpr_string_split(set_str, ",", &set, &set_size); + bool found = false; + for (size_t i = 0; i < set_size; ++i) { + if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; } - if (!fake_check_target("LB", target, lbs_and_backends[1])) { - gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", - target, lbs_and_backends[1]); - goto done; + for (size_t i = 0; i < set_size; ++i) { + gpr_free(set[i]); } - success = true; - } else { - if (!fake_check_target("Backend", target, lbs_and_backends[0])) { - gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", - target, lbs_and_backends[0]); + gpr_free(set); + return found; + } + + void fake_secure_name_check() const { + if (expected_targets_ == nullptr) return; + char** lbs_and_backends = nullptr; + size_t lbs_and_backends_size = 0; + bool success = false; + gpr_string_split(expected_targets_, ";", &lbs_and_backends, + &lbs_and_backends_size); + if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { + gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", + expected_targets_); goto done; } - success = true; - } -done: - for (size_t i = 0; i < lbs_and_backends_size; ++i) { - gpr_free(lbs_and_backends[i]); + if (is_lb_channel_) { + if (lbs_and_backends_size != 2) { + gpr_log(GPR_ERROR, + "Invalid expected targets arg value: '%s'. Expectations for LB " + "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", + expected_targets_); + goto done; + } + if (!fake_check_target("LB", target_, lbs_and_backends[1])) { + gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", + target_, lbs_and_backends[1]); + goto done; + } + success = true; + } else { + if (!fake_check_target("Backend", target_, lbs_and_backends[0])) { + gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", + target_, lbs_and_backends[0]); + goto done; + } + success = true; + } + done: + for (size_t i = 0; i < lbs_and_backends_size; ++i) { + gpr_free(lbs_and_backends[i]); + } + gpr_free(lbs_and_backends); + if (!success) abort(); } - gpr_free(lbs_and_backends); - if (!success) abort(); -} -static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { + char* target_; + char* expected_targets_; + bool is_lb_channel_; + char* target_name_override_; +}; + +static void fake_check_peer( + grpc_security_connector* sc, tsi_peer peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { const char* prop_name; grpc_error* error = GRPC_ERROR_NONE; *auth_context = nullptr; @@ -147,164 +240,66 @@ static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, "Invalid value for cert type property."); goto end; } - *auth_context = grpc_auth_context_create(nullptr); + *auth_context = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_FAKE_TRANSPORT_SECURITY_TYPE); end: GRPC_CLOSURE_SCHED(on_peer_checked, error); tsi_peer_destruct(&peer); } -static void fake_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - fake_secure_name_check(c->target, c->expected_targets, c->is_lb_channel); +void grpc_fake_channel_security_connector::check_peer( + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) { + fake_check_peer(this, peer, auth_context, on_peer_checked); + fake_secure_name_check(); } -static void fake_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); -} +class grpc_fake_server_security_connector + : public grpc_server_security_connector { + public: + grpc_fake_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(server_creds)) {} + ~grpc_fake_server_security_connector() override = default; -static int fake_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_fake_channel_security_connector* c1 = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc1); - grpc_fake_channel_security_connector* c2 = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target, c2->target); - if (c != 0) return c; - if (c1->expected_targets == nullptr || c2->expected_targets == nullptr) { - c = GPR_ICMP(c1->expected_targets, c2->expected_targets); - } else { - c = strcmp(c1->expected_targets, c2->expected_targets); + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + fake_check_peer(this, peer, auth_context, on_peer_checked); } - if (c != 0) return c; - return GPR_ICMP(c1->is_lb_channel, c2->is_lb_channel); -} -static int fake_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast<grpc_server_security_connector*>(sc1), - reinterpret_cast<grpc_server_security_connector*>(sc2)); -} - -static bool fake_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_fake_channel_security_connector* c = - reinterpret_cast<grpc_fake_channel_security_connector*>(sc); - char* authority_hostname = nullptr; - char* authority_ignored_port = nullptr; - char* target_hostname = nullptr; - char* target_ignored_port = nullptr; - gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); - gpr_split_host_port(c->target, &target_hostname, &target_ignored_port); - if (c->target_name_override != nullptr) { - char* fake_security_target_name_override_hostname = nullptr; - char* fake_security_target_name_override_ignored_port = nullptr; - gpr_split_host_port(c->target_name_override, - &fake_security_target_name_override_hostname, - &fake_security_target_name_override_ignored_port); - if (strcmp(authority_hostname, - fake_security_target_name_override_hostname) != 0) { - gpr_log(GPR_ERROR, - "Authority (host) '%s' != Fake Security Target override '%s'", - host, fake_security_target_name_override_hostname); - abort(); - } - gpr_free(fake_security_target_name_override_hostname); - gpr_free(fake_security_target_name_override_ignored_port); - } else if (strcmp(authority_hostname, target_hostname) != 0) { - gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", - authority_hostname, target_hostname); - abort(); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*=is_client*/ false), this)); } - gpr_free(authority_hostname); - gpr_free(authority_ignored_port); - gpr_free(target_hostname); - gpr_free(target_ignored_port); - return true; -} -static void fake_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -static void fake_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(true /* is_client */), &sc->base)); -} - -static void fake_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(false /* is_client */), &sc->base)); -} - -static grpc_security_connector_vtable fake_channel_vtable = { - fake_channel_destroy, fake_channel_check_peer, fake_channel_cmp}; - -static grpc_security_connector_vtable fake_server_vtable = { - fake_server_destroy, fake_server_check_peer, fake_server_cmp}; - -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args) { - grpc_fake_channel_security_connector* c = - static_cast<grpc_fake_channel_security_connector*>( - gpr_zalloc(sizeof(*c))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->base.base.vtable = &fake_channel_vtable; - c->base.channel_creds = channel_creds; - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = fake_channel_check_call_host; - c->base.cancel_check_call_host = fake_channel_cancel_check_call_host; - c->base.add_handshakers = fake_channel_add_handshakers; - c->target = gpr_strdup(target); - const char* expected_targets = grpc_fake_transport_get_expected_targets(args); - c->expected_targets = gpr_strdup(expected_targets); - c->is_lb_channel = grpc_core::FindTargetAuthorityTableInArgs(args) != nullptr; - const grpc_arg* target_name_override_arg = - grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); - if (target_name_override_arg != nullptr) { - c->target_name_override = - gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); } - return &c->base; +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args) { + return grpc_core::MakeRefCounted<grpc_fake_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), target, + args); } -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds) { - grpc_server_security_connector* c = - static_cast<grpc_server_security_connector*>( - gpr_zalloc(sizeof(grpc_server_security_connector))); - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &fake_server_vtable; - c->base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->server_creds = server_creds; - c->add_handshakers = fake_server_add_handshakers; - return c; +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + return grpc_core::MakeRefCounted<grpc_fake_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.h b/src/core/lib/security/security_connector/fake/fake_security_connector.h index fdfe048c6e..344a2349a4 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.h +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.h @@ -24,19 +24,22 @@ #include <grpc/grpc_security.h> #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/security_connector/security_connector.h" #define GRPC_FAKE_SECURITY_URL_SCHEME "http+fake_security" /* Creates a fake connector that emulates real channel security. */ -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target, const grpc_channel_args* args); /* Creates a fake connector that emulates real server security. */ -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_FAKE_FAKE_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc index 008a98df28..7cc482c16c 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.cc +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -30,217 +30,224 @@ #include "src/core/ext/filters/client_channel/client_channel.h" #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/iomgr/resolve_address.h" +#include "src/core/lib/iomgr/sockaddr.h" +#include "src/core/lib/iomgr/sockaddr_utils.h" +#include "src/core/lib/iomgr/socket_utils.h" +#include "src/core/lib/iomgr/unix_sockets_posix.h" #include "src/core/lib/security/credentials/local/local_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/tsi/local_transport_security.h" #define GRPC_UDS_URI_PATTERN "unix:" -#define GRPC_UDS_URL_SCHEME "unix" #define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_local_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_local_server_security_connector; - -static void local_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_local_channel_security_connector*>(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void local_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast<grpc_local_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void local_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void local_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static int local_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_channel_security_connector* c1 = - reinterpret_cast<grpc_local_channel_security_connector*>(sc1); - grpc_local_channel_security_connector* c2 = - reinterpret_cast<grpc_local_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int local_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_server_security_connector* c1 = - reinterpret_cast<grpc_local_server_security_connector*>(sc1); - grpc_local_server_security_connector* c2 = - reinterpret_cast<grpc_local_server_security_connector*>(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_status local_auth_context_create(grpc_auth_context** ctx) { - if (ctx == nullptr) { - gpr_log(GPR_ERROR, "Invalid arguments to local_auth_context_create()"); - return GRPC_SECURITY_ERROR; - } +grpc_core::RefCountedPtr<grpc_auth_context> local_auth_context_create() { /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr<grpc_auth_context> ctx = + grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_LOCAL_TRANSPORT_SECURITY_TYPE); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); - return GRPC_SECURITY_OK; + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); + return ctx; } -static void local_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; +void local_check_peer(grpc_security_connector* sc, tsi_peer peer, + grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked, + grpc_local_connect_type type) { + int fd = grpc_endpoint_get_fd(ep); + grpc_resolved_address resolved_addr; + memset(&resolved_addr, 0, sizeof(resolved_addr)); + resolved_addr.len = GRPC_MAX_SOCKADDR_SIZE; + bool is_endpoint_local = false; + if (getsockname(fd, reinterpret_cast<grpc_sockaddr*>(resolved_addr.addr), + &resolved_addr.len) == 0) { + grpc_resolved_address addr_normalized; + grpc_resolved_address* addr = + grpc_sockaddr_is_v4mapped(&resolved_addr, &addr_normalized) + ? &addr_normalized + : &resolved_addr; + grpc_sockaddr* sock_addr = reinterpret_cast<grpc_sockaddr*>(&addr->addr); + // UDS + if (type == UDS && grpc_is_unix_socket(addr)) { + is_endpoint_local = true; + // IPV4 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET) { + const grpc_sockaddr_in* addr4 = + reinterpret_cast<const grpc_sockaddr_in*>(sock_addr); + if (grpc_htonl(addr4->sin_addr.s_addr) == INADDR_LOOPBACK) { + is_endpoint_local = true; + } + // IPv6 + } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET6) { + const grpc_sockaddr_in6* addr6 = + reinterpret_cast<const grpc_sockaddr_in6*>(addr); + if (memcmp(&addr6->sin6_addr, &in6addr_loopback, + sizeof(in6addr_loopback)) == 0) { + is_endpoint_local = true; + } + } + } + grpc_error* error = GRPC_ERROR_NONE; + if (!is_endpoint_local) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Endpoint is neither UDS or TCP loopback address."); + GRPC_CLOSURE_SCHED(on_peer_checked, error); + return; + } /* Create an auth context which is necessary to pass the santiy check in * {client, server}_auth_filter that verifies if the peer's auth context is * obtained during handshakes. The auth context is only checked for its * existence and not actually used. */ - status = local_auth_context_create(auth_context); - grpc_error* error = status == GRPC_SECURITY_OK - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not create local auth context"); + *auth_context = local_auth_context_create(); + error = *auth_context != nullptr ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not create local auth context"); GRPC_CLOSURE_SCHED(on_peer_checked, error); } -static grpc_security_connector_vtable local_channel_vtable = { - local_channel_destroy, local_check_peer, local_channel_cmp}; - -static grpc_security_connector_vtable local_server_vtable = { - local_server_destroy, local_check_peer, local_server_cmp}; - -static bool local_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_local_channel_security_connector* local_sc = - reinterpret_cast<grpc_local_channel_security_connector*>(sc); - if (host == nullptr || local_sc == nullptr || - strcmp(host, local_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "local call host does not match target name"); +class grpc_local_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_local_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(nullptr, std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) {} + + ~grpc_local_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == + TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - return true; -} -static void local_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_local_channel_security_connector*>( + other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_credentials* creds = + reinterpret_cast<grpc_local_credentials*>(mutable_channel_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "local call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { + const char* target_name() const { return target_name_; } + + private: + char* target_name_; +}; + +class grpc_local_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_local_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(nullptr, std::move(server_creds)) {} + ~grpc_local_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_local_server_credentials* creds = + static_cast<grpc_local_server_credentials*>(mutable_server_creds()); + local_check_peer(this, peer, ep, auth_context, on_peer_checked, + creds->connect_type()); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_channel_args* args, const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - // Check if local_connect_type is UDS. Only UDS is supported for now. + // Perform sanity check on UDS address. For TCP local connection, the check + // will be done during check_peer procedure. grpc_local_credentials* creds = - reinterpret_cast<grpc_local_credentials*>(channel_creds); - if (creds->connect_type != UDS) { - gpr_log(GPR_ERROR, - "Invalid local channel type to " - "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; - } - // Check if target_name is a valid UDS address. + static_cast<grpc_local_credentials*>(channel_creds.get()); const grpc_arg* server_uri_arg = grpc_channel_args_find(args, GRPC_ARG_SERVER_URI); const char* server_uri_str = grpc_channel_arg_get_string(server_uri_arg); - if (strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, + if (creds->connect_type() == UDS && + strncmp(GRPC_UDS_URI_PATTERN, server_uri_str, strlen(GRPC_UDS_URI_PATTERN)) != 0) { gpr_log(GPR_ERROR, - "Invalid target_name to " + "Invalid UDS target name to " "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_local_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_local_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_channel_vtable; - c->base.add_handshakers = local_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = local_check_call_host; - c->base.cancel_check_call_host = local_cancel_check_call_host; - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_local_channel_security_connector>( + channel_creds, request_metadata_creds, target_name); } -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; - } - // Check if local_connect_type is UDS. Only UDS is supported for now. - grpc_local_server_credentials* creds = - reinterpret_cast<grpc_local_server_credentials*>(server_creds); - if (creds->connect_type != UDS) { - gpr_log(GPR_ERROR, - "Invalid local server type to " - "grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast<grpc_local_server_security_connector*>( - gpr_zalloc(sizeof(grpc_local_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->base.add_handshakers = local_server_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted<grpc_local_server_security_connector>( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/local/local_security_connector.h b/src/core/lib/security/security_connector/local/local_security_connector.h index 5369a2127a..6eee0ca9a6 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.h +++ b/src/core/lib/security/security_connector/local/local_security_connector.h @@ -34,13 +34,13 @@ * - sc: address of local channel security connector instance to be returned * from the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_channel_args* args, const char* target_name); /** * This method creates a local server security connector. @@ -49,10 +49,11 @@ grpc_security_status grpc_local_channel_security_connector_create( * - sc: address of local server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_LOCAL_LOCAL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index 02cecb0eb1..96a1960546 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -35,150 +35,67 @@ #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/security/transport/security_handshaker.h" grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount( false, "security_connector_refcount"); -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - if (sc == nullptr) { - GRPC_CLOSURE_SCHED(on_peer_checked, - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check peer -- no security connector")); - tsi_peer_destruct(&peer); - } else { - sc->vtable->check_peer(sc, peer, auth_context, on_peer_checked); - } -} - -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other) { +grpc_server_security_connector::grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_security_connector(url_scheme), + server_creds_(std::move(server_creds)) {} + +grpc_channel_security_connector::grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds) + : grpc_security_connector(url_scheme), + channel_creds_(std::move(channel_creds)), + request_metadata_creds_(std::move(request_metadata_creds)) {} +grpc_channel_security_connector::~grpc_channel_security_connector() {} + +int grpc_security_connector_cmp(const grpc_security_connector* sc, + const grpc_security_connector* other) { if (sc == nullptr || other == nullptr) return GPR_ICMP(sc, other); - int c = GPR_ICMP(sc->vtable, other->vtable); - if (c != 0) return c; - return sc->vtable->cmp(sc, other); + return sc->cmp(other); } -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2) { - GPR_ASSERT(sc1->channel_creds != nullptr); - GPR_ASSERT(sc2->channel_creds != nullptr); - int c = GPR_ICMP(sc1->channel_creds, sc2->channel_creds); - if (c != 0) return c; - c = GPR_ICMP(sc1->request_metadata_creds, sc2->request_metadata_creds); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->check_call_host, (void*)sc2->check_call_host); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->cancel_check_call_host, - (void*)sc2->cancel_check_call_host); +int grpc_channel_security_connector::channel_security_connector_cmp( + const grpc_channel_security_connector* other) const { + const grpc_channel_security_connector* other_sc = + static_cast<const grpc_channel_security_connector*>(other); + GPR_ASSERT(channel_creds() != nullptr); + GPR_ASSERT(other_sc->channel_creds() != nullptr); + int c = GPR_ICMP(channel_creds(), other_sc->channel_creds()); if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); + return GPR_ICMP(request_metadata_creds(), other_sc->request_metadata_creds()); } -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2) { - GPR_ASSERT(sc1->server_creds != nullptr); - GPR_ASSERT(sc2->server_creds != nullptr); - int c = GPR_ICMP(sc1->server_creds, sc2->server_creds); - if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); -} - -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error) { - if (sc == nullptr || sc->check_call_host == nullptr) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check call host -- no security connector"); - return true; - } - return sc->check_call_host(sc, host, auth_context, on_call_host_checked, - error); -} - -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - if (sc == nullptr || sc->cancel_check_call_host == nullptr) { - GRPC_ERROR_UNREF(error); - return; - } - sc->cancel_check_call_host(sc, on_call_host_checked, error); -} - -#ifndef NDEBUG -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc, const char* file, int line, - const char* reason) { - if (sc == nullptr) return nullptr; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val + 1, reason); - } -#else -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc) { - if (sc == nullptr) return nullptr; -#endif - gpr_ref(&sc->refcount); - return sc; -} - -#ifndef NDEBUG -void grpc_security_connector_unref(grpc_security_connector* sc, - const char* file, int line, - const char* reason) { - if (sc == nullptr) return; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val - 1, reason); - } -#else -void grpc_security_connector_unref(grpc_security_connector* sc) { - if (sc == nullptr) return; -#endif - if (gpr_unref(&sc->refcount)) sc->vtable->destroy(sc); +int grpc_server_security_connector::server_security_connector_cmp( + const grpc_server_security_connector* other) const { + const grpc_server_security_connector* other_sc = + static_cast<const grpc_server_security_connector*>(other); + GPR_ASSERT(server_creds() != nullptr); + GPR_ASSERT(other_sc->server_creds() != nullptr); + return GPR_ICMP(server_creds(), other_sc->server_creds()); } static void connector_arg_destroy(void* p) { - GRPC_SECURITY_CONNECTOR_UNREF((grpc_security_connector*)p, - "connector_arg_destroy"); + static_cast<grpc_security_connector*>(p)->Unref(DEBUG_LOCATION, + "connector_arg_destroy"); } static void* connector_arg_copy(void* p) { - return GRPC_SECURITY_CONNECTOR_REF((grpc_security_connector*)p, - "connector_arg_copy"); + return static_cast<grpc_security_connector*>(p) + ->Ref(DEBUG_LOCATION, "connector_arg_copy") + .release(); } static int connector_cmp(void* a, void* b) { - return grpc_security_connector_cmp(static_cast<grpc_security_connector*>(a), - static_cast<grpc_security_connector*>(b)); + return static_cast<grpc_security_connector*>(a)->cmp( + static_cast<grpc_security_connector*>(b)); } static const grpc_arg_pointer_vtable connector_arg_vtable = { diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 4c921a8793..74b0ef21a6 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -26,6 +26,7 @@ #include <grpc/grpc_security.h> #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/iomgr/tcp_server.h" @@ -34,8 +35,6 @@ extern grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount; -/* --- status enum. --- */ - typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; /* --- security_connector object. --- @@ -43,54 +42,34 @@ typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; A security connector object represents away to configure the underlying transport security mechanism and check the resulting trusted peer. */ -typedef struct grpc_security_connector grpc_security_connector; - #define GRPC_ARG_SECURITY_CONNECTOR "grpc.security_connector" -typedef struct { - void (*destroy)(grpc_security_connector* sc); - void (*check_peer)(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - int (*cmp)(grpc_security_connector* sc, grpc_security_connector* other); -} grpc_security_connector_vtable; - -struct grpc_security_connector { - const grpc_security_connector_vtable* vtable; - gpr_refcount refcount; - const char* url_scheme; -}; +class grpc_security_connector + : public grpc_core::RefCounted<grpc_security_connector> { + public: + explicit grpc_security_connector(const char* url_scheme) + : grpc_core::RefCounted<grpc_security_connector>( + &grpc_trace_security_connector_refcount), + url_scheme_(url_scheme) {} + virtual ~grpc_security_connector() = default; + + /* Check the peer. Callee takes ownership of the peer object. + When done, sets *auth_context and invokes on_peer_checked. */ + virtual void check_peer( + tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) GRPC_ABSTRACT; + + /* Compares two security connectors. */ + virtual int cmp(const grpc_security_connector* other) const GRPC_ABSTRACT; + + const char* url_scheme() const { return url_scheme_; } -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_SECURITY_CONNECTOR_REF(p, r) \ - grpc_security_connector_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) \ - grpc_security_connector_unref((p), __FILE__, __LINE__, (r)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy, const char* file, int line, - const char* reason); -void grpc_security_connector_unref(grpc_security_connector* policy, - const char* file, int line, - const char* reason); -#else -#define GRPC_SECURITY_CONNECTOR_REF(p, r) grpc_security_connector_ref((p)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) grpc_security_connector_unref((p)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy); -void grpc_security_connector_unref(grpc_security_connector* policy); -#endif - -/* Check the peer. Callee takes ownership of the peer object. - When done, sets *auth_context and invokes on_peer_checked. */ -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - -/* Compares two security connectors. */ -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other); + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* url_scheme_; +}; /* Util to encapsulate the connector in a channel arg. */ grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc); @@ -107,71 +86,89 @@ grpc_security_connector* grpc_security_connector_find_in_args( A channel security connector object represents a way to configure the underlying transport security mechanism on the client side. */ -typedef struct grpc_channel_security_connector grpc_channel_security_connector; - -struct grpc_channel_security_connector { - grpc_security_connector base; - grpc_channel_credentials* channel_creds; - grpc_call_credentials* request_metadata_creds; - bool (*check_call_host)(grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error); - void (*cancel_check_call_host)(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error); - void (*add_handshakers)(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_channel_security_connector : public grpc_security_connector { + public: + grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds); + ~grpc_channel_security_connector() override; + + /// Checks that the host that will be set for a call is acceptable. + /// Returns true if completed synchronously, in which case \a error will + /// be set to indicate the result. Otherwise, \a on_call_host_checked + /// will be invoked when complete. + virtual bool check_call_host(const char* host, + grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) GRPC_ABSTRACT; + /// Cancels a pending asychronous call to + /// grpc_channel_security_connector_check_call_host() with + /// \a on_call_host_checked as its callback. + virtual void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) GRPC_ABSTRACT; + /// Registers handshakers with \a handshake_mgr. + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_channel_credentials* channel_creds() const { + return channel_creds_.get(); + } + grpc_channel_credentials* mutable_channel_creds() { + return channel_creds_.get(); + } + const grpc_call_credentials* request_metadata_creds() const { + return request_metadata_creds_.get(); + } + grpc_call_credentials* mutable_request_metadata_creds() { + return request_metadata_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int channel_security_connector_cmp( + const grpc_channel_security_connector* other) const; + + private: + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_; + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2); - -/// Checks that the host that will be set for a call is acceptable. -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_call_host_checked -/// will be invoked when complete. -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error); - -/// Cancels a pending asychronous call to -/// grpc_channel_security_connector_check_call_host() with -/// \a on_call_host_checked as its callback. -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error); - -/* Registers handshakers with \a handshake_mgr. */ -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - /* --- server_security_connector object. --- A server security connector object represents a way to configure the underlying transport security mechanism on the server side. */ -typedef struct grpc_server_security_connector grpc_server_security_connector; - -struct grpc_server_security_connector { - grpc_security_connector base; - grpc_server_credentials* server_creds; - void (*add_handshakers)(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_server_security_connector : public grpc_security_connector { + public: + grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds); + ~grpc_server_security_connector() override = default; + + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_server_credentials* server_creds() const { + return server_creds_.get(); + } + grpc_server_credentials* mutable_server_creds() { + return server_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int server_security_connector_cmp( + const grpc_server_security_connector* other) const; + + private: + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2); - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */ diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 20a9533dd1..7414ab1a37 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/ssl/ssl_credentials.h" @@ -39,172 +40,10 @@ #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* client_handshaker_factory; - char* target_name; - char* overridden_target_name; - const verify_peer_options* verify_options; -} grpc_ssl_channel_security_connector; - -typedef struct { - grpc_server_security_connector base; - tsi_ssl_server_handshaker_factory* server_handshaker_factory; -} grpc_ssl_server_security_connector; - -static bool server_connector_has_cert_config_fetcher( - grpc_ssl_server_security_connector* c) { - GPR_ASSERT(c != nullptr); - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(c->base.server_creds); - GPR_ASSERT(server_creds != nullptr); - return server_creds->certificate_config_fetcher.cb != nullptr; -} - -static void ssl_channel_destroy(grpc_security_connector* sc) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - grpc_channel_credentials_unref(c->base.channel_creds); - grpc_call_credentials_unref(c->base.request_metadata_creds); - tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory); - c->client_handshaker_factory = nullptr; - if (c->target_name != nullptr) gpr_free(c->target_name); - if (c->overridden_target_name != nullptr) gpr_free(c->overridden_target_name); - gpr_free(sc); -} - -static void ssl_server_destroy(grpc_security_connector* sc) { - grpc_ssl_server_security_connector* c = - reinterpret_cast<grpc_ssl_server_security_connector*>(sc); - grpc_server_credentials_unref(c->base.server_creds); - tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory); - c->server_handshaker_factory = nullptr; - gpr_free(sc); -} - -static void ssl_channel_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - // Instantiate TSI handshaker. - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->client_handshaker_factory, - c->overridden_target_name != nullptr ? c->overridden_target_name - : c->target_name, - &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -/* Attempts to replace the server_handshaker_factory with a new factory using - * the provided grpc_ssl_server_certificate_config. Should new factory creation - * fail, the existing factory will not be replaced. Returns true on success (new - * factory created). */ -static bool try_replace_server_handshaker_factory( - grpc_ssl_server_security_connector* sc, - const grpc_ssl_server_certificate_config* config) { - if (config == nullptr) { - gpr_log(GPR_ERROR, - "Server certificate config callback returned invalid (NULL) " - "config."); - return false; - } - gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); - - size_t num_alpn_protocols = 0; - const char** alpn_protocol_strings = - grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( - config->pem_key_cert_pairs, config->num_key_cert_pairs); - tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds); - tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( - cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, - grpc_get_tsi_client_certificate_request_type( - server_creds->config.client_certificate_request), - grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory); - gpr_free(cert_pairs); - gpr_free((void*)alpn_protocol_strings); - - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - return false; - } - tsi_ssl_server_handshaker_factory_unref(sc->server_handshaker_factory); - sc->server_handshaker_factory = new_handshaker_factory; - return true; -} - -/* Attempts to fetch the server certificate config if a callback is available. - * Current certificate config will continue to be used if the callback returns - * an error. Returns true if new credentials were sucessfully loaded. */ -static bool try_fetch_ssl_server_credentials( - grpc_ssl_server_security_connector* sc) { - grpc_ssl_server_certificate_config* certificate_config = nullptr; - bool status; - - GPR_ASSERT(sc != nullptr); - if (!server_connector_has_cert_config_fetcher(sc)) return false; - - grpc_ssl_server_credentials* server_creds = - reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds); - grpc_ssl_certificate_config_reload_status cb_result = - server_creds->certificate_config_fetcher.cb( - server_creds->certificate_config_fetcher.user_data, - &certificate_config); - if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { - gpr_log(GPR_DEBUG, "No change in SSL server credentials."); - status = false; - } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { - status = try_replace_server_handshaker_factory(sc, certificate_config); - } else { - // Log error, continue using previously-loaded credentials. - gpr_log(GPR_ERROR, - "Failed fetching new server credentials, continuing to " - "use previously-loaded credentials."); - status = false; - } - - if (certificate_config != nullptr) { - grpc_ssl_server_certificate_config_destroy(certificate_config); - } - return status; -} - -static void ssl_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_server_security_connector* c = - reinterpret_cast<grpc_ssl_server_security_connector*>(sc); - // Instantiate TSI handshaker. - try_fetch_ssl_server_credentials(c); - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( - c->server_handshaker_factory, &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -static grpc_error* ssl_check_peer(grpc_security_connector* sc, - const char* peer_name, const tsi_peer* peer, - grpc_auth_context** auth_context) { +namespace { +grpc_error* ssl_check_peer( + const char* peer_name, const tsi_peer* peer, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context) { #if TSI_OPENSSL_ALPN_SUPPORT /* Check the ALPN if ALPN is supported. */ const tsi_peer_property* p = @@ -230,245 +69,384 @@ static grpc_error* ssl_check_peer(grpc_security_connector* sc, return GRPC_ERROR_NONE; } -static void ssl_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - const char* target_name = c->overridden_target_name != nullptr - ? c->overridden_target_name - : c->target_name; - grpc_error* error = ssl_check_peer(sc, target_name, &peer, auth_context); - if (error == GRPC_ERROR_NONE && - c->verify_options->verify_peer_callback != nullptr) { - const tsi_peer_property* p = - tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); - if (p == nullptr) { - error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Cannot check peer: missing pem cert property."); - } else { - char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1)); - memcpy(peer_pem, p->value.data, p->value.length); - peer_pem[p->value.length] = '\0'; - int callback_status = c->verify_options->verify_peer_callback( - target_name, peer_pem, - c->verify_options->verify_peer_callback_userdata); - gpr_free(peer_pem); - if (callback_status) { - char* msg; - gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", - callback_status); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); - } - } +class grpc_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_ssl_channel_security_connector( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name) + : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + overridden_target_name_(overridden_target_name == nullptr + ? nullptr + : gpr_strdup(overridden_target_name)), + verify_options_(&config->verify_options) { + char* port; + gpr_split_host_port(target_name, &target_name_, &port); + gpr_free(port); } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static void ssl_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_error* error = ssl_check_peer(sc, nullptr, &peer, auth_context); - tsi_peer_destruct(&peer); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} + ~grpc_ssl_channel_security_connector() override { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + if (target_name_ != nullptr) gpr_free(target_name_); + if (overridden_target_name_ != nullptr) gpr_free(overridden_target_name_); + } -static int ssl_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_ssl_channel_security_connector* c1 = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc1); - grpc_ssl_channel_security_connector* c2 = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target_name, c2->target_name); - if (c != 0) return c; - return (c1->overridden_target_name == nullptr || - c2->overridden_target_name == nullptr) - ? GPR_ICMP(c1->overridden_target_name, c2->overridden_target_name) - : strcmp(c1->overridden_target_name, c2->overridden_target_name); -} + grpc_security_status InitializeHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache) { + bool has_key_cert_pair = + config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + GPR_DEBUG_ASSERT(pem_root_certs != nullptr); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory_); + gpr_free((void*)options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; + } -static int ssl_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast<grpc_server_security_connector*>(sc1), - reinterpret_cast<grpc_server_security_connector*>(sc2)); -} + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory_, + overridden_target_name_ != nullptr ? overridden_target_name_ + : target_name_, + &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); + } -static bool ssl_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast<grpc_ssl_channel_security_connector*>(sc); - grpc_security_status status = GRPC_SECURITY_ERROR; - tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); - if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; - /* If the target name was overridden, then the original target_name was - 'checked' transitively during the previous peer check at the end of the - handshake. */ - if (c->overridden_target_name != nullptr && - strcmp(host, c->target_name) == 0) { - status = GRPC_SECURITY_OK; + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + const char* target_name = overridden_target_name_ != nullptr + ? overridden_target_name_ + : target_name_; + grpc_error* error = ssl_check_peer(target_name, &peer, auth_context); + if (error == GRPC_ERROR_NONE && + verify_options_->verify_peer_callback != nullptr) { + const tsi_peer_property* p = + tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + if (p == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing pem cert property."); + } else { + char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1)); + memcpy(peer_pem, p->value.data, p->value.length); + peer_pem[p->value.length] = '\0'; + int callback_status = verify_options_->verify_peer_callback( + target_name, peer_pem, + verify_options_->verify_peer_callback_userdata); + gpr_free(peer_pem); + if (callback_status) { + char* msg; + gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", + callback_status); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + } + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); } - if (status != GRPC_SECURITY_OK) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "call host does not match SSL server name"); + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast<const grpc_ssl_channel_security_connector*>(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_name_, other->target_name_); + if (c != 0) return c; + return (overridden_target_name_ == nullptr || + other->overridden_target_name_ == nullptr) + ? GPR_ICMP(overridden_target_name_, + other->overridden_target_name_) + : strcmp(overridden_target_name_, + other->overridden_target_name_); } - grpc_shallow_peer_destruct(&peer); - return true; -} -static void ssl_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + grpc_security_status status = GRPC_SECURITY_ERROR; + tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); + if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; + /* If the target name was overridden, then the original target_name was + 'checked' transitively during the previous peer check at the end of the + handshake. */ + if (overridden_target_name_ != nullptr && strcmp(host, target_name_) == 0) { + status = GRPC_SECURITY_OK; + } + if (status != GRPC_SECURITY_OK) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "call host does not match SSL server name"); + } + grpc_shallow_peer_destruct(&peer); + return true; + } -static grpc_security_connector_vtable ssl_channel_vtable = { - ssl_channel_destroy, ssl_channel_check_peer, ssl_channel_cmp}; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } -static grpc_security_connector_vtable ssl_server_vtable = { - ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp}; + private: + tsi_ssl_client_handshaker_factory* client_handshaker_factory_; + char* target_name_; + char* overridden_target_name_; + const verify_peer_options* verify_options_; +}; + +class grpc_ssl_server_security_connector + : public grpc_server_security_connector { + public: + grpc_ssl_server_security_connector( + grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) + : grpc_server_security_connector(GRPC_SSL_URL_SCHEME, + std::move(server_creds)) {} + + ~grpc_ssl_server_security_connector() override { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_channel_security_connector* c; - char* port; - bool has_key_cert_pair; - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.alpn_protocols = - grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + bool has_cert_config_fetcher() const { + return static_cast<const grpc_ssl_server_credentials*>(server_creds()) + ->has_cert_config_fetcher(); + } - if (config == nullptr || target_name == nullptr) { - gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); - goto error; + const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const { + return server_handshaker_factory_; } - if (config->pem_root_certs == nullptr) { - // Use default root certificates. - options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); - options.root_store = grpc_core::DefaultSslRootStore::GetRootStore(); - if (options.pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, "Could not get default pem root certs."); - goto error; + + grpc_security_status InitializeHandshakerFactory() { + if (has_cert_config_fetcher()) { + // Load initial credentials from certificate_config_fetcher: + if (!try_fetch_ssl_server_credentials()) { + gpr_log(GPR_ERROR, + "Failed loading SSL server credentials from fetcher."); + return GRPC_SECURITY_ERROR; + } + } else { + auto* server_credentials = + static_cast<const grpc_ssl_server_credentials*>(server_creds()); + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + const tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + server_credentials->config().pem_key_cert_pairs, + server_credentials->config().num_key_cert_pairs, + server_credentials->config().pem_root_certs, + grpc_get_tsi_client_certificate_request_type( + server_credentials->config().client_certificate_request), + grpc_get_ssl_cipher_suites(), alpn_protocol_strings, + static_cast<uint16_t>(num_alpn_protocols), + &server_handshaker_factory_); + gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } } - } else { - options.pem_root_certs = config->pem_root_certs; - } - c = static_cast<grpc_ssl_channel_security_connector*>( - gpr_zalloc(sizeof(grpc_ssl_channel_security_connector))); - - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &ssl_channel_vtable; - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = ssl_channel_check_call_host; - c->base.cancel_check_call_host = ssl_channel_cancel_check_call_host; - c->base.add_handshakers = ssl_channel_add_handshakers; - gpr_split_host_port(target_name, &c->target_name, &port); - gpr_free(port); - if (overridden_target_name != nullptr) { - c->overridden_target_name = gpr_strdup(overridden_target_name); + return GRPC_SECURITY_OK; } - c->verify_options = &config->verify_options; - has_key_cert_pair = config->pem_key_cert_pair != nullptr && - config->pem_key_cert_pair->private_key != nullptr && - config->pem_key_cert_pair->cert_chain != nullptr; - if (has_key_cert_pair) { - options.pem_key_cert_pair = config->pem_key_cert_pair; + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + try_fetch_ssl_server_credentials(); + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory_, &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); } - options.cipher_suites = grpc_get_ssl_cipher_suites(); - options.session_cache = ssl_session_cache; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->client_handshaker_factory); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - ssl_channel_destroy(&c->base.base); - *sc = nullptr; - goto error; + + void check_peer(tsi_peer peer, grpc_endpoint* ep, + grpc_core::RefCountedPtr<grpc_auth_context>* auth_context, + grpc_closure* on_peer_checked) override { + grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context); + tsi_peer_destruct(&peer); + GRPC_CLOSURE_SCHED(on_peer_checked, error); } - *sc = &c->base; - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_OK; -error: - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_ERROR; -} + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast<const grpc_server_security_connector*>(other)); + } -static grpc_ssl_server_security_connector* -grpc_ssl_server_security_connector_initialize( - grpc_server_credentials* server_creds) { - grpc_ssl_server_security_connector* c = - static_cast<grpc_ssl_server_security_connector*>( - gpr_zalloc(sizeof(grpc_ssl_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.base.vtable = &ssl_server_vtable; - c->base.add_handshakers = ssl_server_add_handshakers; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - return c; -} + private: + /* Attempts to fetch the server certificate config if a callback is available. + * Current certificate config will continue to be used if the callback returns + * an error. Returns true if new credentials were sucessfully loaded. */ + bool try_fetch_ssl_server_credentials() { + grpc_ssl_server_certificate_config* certificate_config = nullptr; + bool status; + + if (!has_cert_config_fetcher()) return false; + + grpc_ssl_server_credentials* server_creds = + static_cast<grpc_ssl_server_credentials*>(this->mutable_server_creds()); + grpc_ssl_certificate_config_reload_status cb_result = + server_creds->FetchCertConfig(&certificate_config); + if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { + gpr_log(GPR_DEBUG, "No change in SSL server credentials."); + status = false; + } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { + status = try_replace_server_handshaker_factory(certificate_config); + } else { + // Log error, continue using previously-loaded credentials. + gpr_log(GPR_ERROR, + "Failed fetching new server credentials, continuing to " + "use previously-loaded credentials."); + status = false; + } -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* gsc, grpc_server_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_server_credentials* server_credentials = - reinterpret_cast<grpc_ssl_server_credentials*>(gsc); - grpc_security_status retval = GRPC_SECURITY_OK; + if (certificate_config != nullptr) { + grpc_ssl_server_certificate_config_destroy(certificate_config); + } + return status; + } - GPR_ASSERT(server_credentials != nullptr); - GPR_ASSERT(sc != nullptr); - - grpc_ssl_server_security_connector* c = - grpc_ssl_server_security_connector_initialize(gsc); - if (server_connector_has_cert_config_fetcher(c)) { - // Load initial credentials from certificate_config_fetcher: - if (!try_fetch_ssl_server_credentials(c)) { - gpr_log(GPR_ERROR, "Failed loading SSL server credentials from fetcher."); - retval = GRPC_SECURITY_ERROR; + /* Attempts to replace the server_handshaker_factory with a new factory using + * the provided grpc_ssl_server_certificate_config. Should new factory + * creation fail, the existing factory will not be replaced. Returns true on + * success (new factory created). */ + bool try_replace_server_handshaker_factory( + const grpc_ssl_server_certificate_config* config) { + if (config == nullptr) { + gpr_log(GPR_ERROR, + "Server certificate config callback returned invalid (NULL) " + "config."); + return false; } - } else { + gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); + size_t num_alpn_protocols = 0; const char** alpn_protocol_strings = grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - result = tsi_create_ssl_server_handshaker_factory_ex( - server_credentials->config.pem_key_cert_pairs, - server_credentials->config.num_key_cert_pairs, - server_credentials->config.pem_root_certs, + tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + config->pem_key_cert_pairs, config->num_key_cert_pairs); + tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; + const grpc_ssl_server_credentials* server_creds = + static_cast<const grpc_ssl_server_credentials*>(this->server_creds()); + GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr); + tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, grpc_get_tsi_client_certificate_request_type( - server_credentials->config.client_certificate_request), + server_creds->config().client_certificate_request), grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast<uint16_t>(num_alpn_protocols), - &c->server_handshaker_factory); + static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory); + gpr_free(cert_pairs); gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - retval = GRPC_SECURITY_ERROR; + return false; } + set_server_handshaker_factory(new_handshaker_factory); + return true; + } + + void set_server_handshaker_factory( + tsi_ssl_server_handshaker_factory* new_factory) { + if (server_handshaker_factory_) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + server_handshaker_factory_ = new_factory; + } + + tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr; +}; +} // namespace + +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) { + if (config == nullptr || target_name == nullptr) { + gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); + return nullptr; } - if (retval == GRPC_SECURITY_OK) { - *sc = &c->base; + const char* pem_root_certs; + const tsi_ssl_root_certs_store* root_store; + if (config->pem_root_certs == nullptr) { + // Use default root certificates. + pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return nullptr; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); } else { - if (c != nullptr) ssl_server_destroy(&c->base.base); - if (sc != nullptr) *sc = nullptr; + pem_root_certs = config->pem_root_certs; + root_store = nullptr; + } + + grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c = + grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>( + std::move(channel_creds), std::move(request_metadata_creds), config, + target_name, overridden_target_name); + const grpc_security_status result = c->InitializeHandshakerFactory( + config, pem_root_certs, root_store, ssl_session_cache); + if (result != GRPC_SECURITY_OK) { + return nullptr; } - return retval; + return c; +} + +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials) { + GPR_ASSERT(server_credentials != nullptr); + grpc_core::RefCountedPtr<grpc_ssl_server_security_connector> c = + grpc_core::MakeRefCounted<grpc_ssl_server_security_connector>( + std::move(server_credentials)); + const grpc_security_status retval = c->InitializeHandshakerFactory(); + if (retval != GRPC_SECURITY_OK) { + return nullptr; + } + return c; } diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h index 9b80590606..70e26e338a 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h @@ -25,6 +25,7 @@ #include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,20 +48,21 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, +grpc_core::RefCountedPtr<grpc_channel_security_connector> +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds, + grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc); + tsi_ssl_session_cache* ssl_session_cache); /* Config for ssl servers. */ typedef struct { - tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; - grpc_ssl_client_certificate_request_type client_certificate_request; + tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; + grpc_ssl_client_certificate_request_type client_certificate_request = + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE; } grpc_ssl_server_config; /* Creates an SSL server_security_connector. @@ -69,9 +71,9 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* server_credentials, - grpc_server_security_connector** sc); +grpc_core::RefCountedPtr<grpc_server_security_connector> +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SSL_SSL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/ssl_utils.cc b/src/core/lib/security/security_connector/ssl_utils.cc index fbf41cfbc7..29030f07ad 100644 --- a/src/core/lib/security/security_connector/ssl_utils.cc +++ b/src/core/lib/security/security_connector/ssl_utils.cc @@ -30,6 +30,7 @@ #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/load_file.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/security_connector/load_system_roots.h" @@ -141,16 +142,17 @@ int grpc_ssl_host_matches_name(const tsi_peer* peer, const char* peer_name) { return r; } -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { +grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context( + const tsi_peer* peer) { size_t i; - grpc_auth_context* ctx = nullptr; const char* peer_identity_property_name = nullptr; /* The caller has checked the certificate type property. */ GPR_ASSERT(peer->property_count >= 1); - ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr<grpc_auth_context> ctx = + grpc_core::MakeRefCounted<grpc_auth_context>(nullptr); grpc_auth_context_add_cstring_property( - ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_SSL_TRANSPORT_SECURITY_TYPE); for (i = 0; i < peer->property_count; i++) { const tsi_peer_property* prop = &peer->properties[i]; @@ -160,24 +162,26 @@ grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { if (peer_identity_property_name == nullptr) { peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME; } - grpc_auth_context_add_property(ctx, GRPC_X509_CN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME; - grpc_auth_context_add_property(ctx, GRPC_X509_SAN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), + GRPC_X509_PEM_CERT_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY, + grpc_auth_context_add_property(ctx.get(), + GRPC_SSL_SESSION_REUSED_PROPERTY, prop->value.data, prop->value.length); } } if (peer_identity_property_name != nullptr) { GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, peer_identity_property_name) == 1); + ctx.get(), peer_identity_property_name) == 1); } return ctx; } diff --git a/src/core/lib/security/security_connector/ssl_utils.h b/src/core/lib/security/security_connector/ssl_utils.h index 6f6d473311..c9cd1a1d9c 100644 --- a/src/core/lib/security/security_connector/ssl_utils.h +++ b/src/core/lib/security/security_connector/ssl_utils.h @@ -26,6 +26,7 @@ #include <grpc/grpc_security.h> #include <grpc/slice_buffer.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,7 +48,8 @@ grpc_get_tsi_client_certificate_request_type( const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols); /* Exposed for testing only. */ -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer); +grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context( + const tsi_peer* peer); tsi_peer grpc_shallow_peer_from_ssl_auth_context( const grpc_auth_context* auth_context); void grpc_shallow_peer_destruct(tsi_peer* peer); diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 6955e8698e..66f86b8bc5 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -55,7 +55,7 @@ struct call_data { // that the memory is not initialized. void destroy() { grpc_credentials_mdelem_array_destroy(&md_array); - grpc_call_credentials_unref(creds); + creds.reset(); grpc_slice_unref_internal(host); grpc_slice_unref_internal(method); grpc_auth_metadata_context_reset(&auth_md_context); @@ -64,7 +64,7 @@ struct call_data { gpr_arena* arena; grpc_call_stack* owning_call; grpc_call_combiner* call_combiner; - grpc_call_credentials* creds = nullptr; + grpc_core::RefCountedPtr<grpc_call_credentials> creds; grpc_slice host = grpc_empty_slice(); grpc_slice method = grpc_empty_slice(); /* pollset{_set} bound to this call; if we need to make external @@ -83,8 +83,18 @@ struct call_data { /* We can have a per-channel credentials. */ struct channel_data { - grpc_channel_security_connector* security_connector; - grpc_auth_context* auth_context; + channel_data(grpc_channel_security_connector* security_connector, + grpc_auth_context* auth_context) + : security_connector( + security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")), + auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {} + ~channel_data() { + security_connector.reset(DEBUG_LOCATION, "client_auth_filter"); + auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + } + + grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; }; } // namespace @@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset( gpr_free(const_cast<char*>(auth_md_context->method_name)); auth_md_context->method_name = nullptr; } - GRPC_AUTH_CONTEXT_UNREF( - (grpc_auth_context*)auth_md_context->channel_auth_context, - "grpc_auth_metadata_context"); - auth_md_context->channel_auth_context = nullptr; + if (auth_md_context->channel_auth_context != nullptr) { + const_cast<grpc_auth_context*>(auth_md_context->channel_auth_context) + ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context"); + auth_md_context->channel_auth_context = nullptr; + } } static void add_error(grpc_error** combined, grpc_error* error) { @@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build( auth_md_context->service_url = service_url; auth_md_context->method_name = method_name; auth_md_context->channel_auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context"); + auth_context == nullptr + ? nullptr + : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context") + .release(); gpr_free(service); gpr_free(host_and_port); } @@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast<grpc_call_element*>(arg); call_data* calld = static_cast<call_data*>(elem->call_data); if (error != GRPC_ERROR_NONE) { - grpc_call_credentials_cancel_get_request_metadata( - calld->creds, &calld->md_array, GRPC_ERROR_REF(error)); + calld->creds->cancel_get_request_metadata(&calld->md_array, + GRPC_ERROR_REF(error)); } } @@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem, static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); grpc_call_credentials* channel_call_creds = - chand->security_connector->request_metadata_creds; + chand->security_connector->mutable_request_metadata_creds(); int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); if (channel_call_creds == nullptr && !call_creds_has_md) { @@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem, } if (channel_call_creds != nullptr && call_creds_has_md) { - calld->creds = grpc_composite_call_credentials_create(channel_call_creds, - ctx->creds, nullptr); + calld->creds = grpc_core::RefCountedPtr<grpc_call_credentials>( + grpc_composite_call_credentials_create(channel_call_creds, + ctx->creds.get(), nullptr)); if (calld->creds == nullptr) { grpc_transport_stream_op_batch_finish_with_failure( batch, @@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem, return; } } else { - calld->creds = grpc_call_credentials_ref( - call_creds_has_md ? ctx->creds : channel_call_creds); + calld->creds = + call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref(); } grpc_auth_metadata_context_build( - chand->security_connector->base.url_scheme, calld->host, calld->method, - chand->auth_context, &calld->auth_md_context); + chand->security_connector->url_scheme(), calld->host, calld->method, + chand->auth_context.get(), &calld->auth_md_context); GPR_ASSERT(calld->pollent != nullptr); GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata"); GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata, batch, grpc_schedule_on_exec_ctx); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - calld->creds, calld->pollent, calld->auth_md_context, - &calld->md_array, &calld->async_result_closure, &error)) { + if (calld->creds->get_request_metadata( + calld->pollent, calld->auth_md_context, &calld->md_array, + &calld->async_result_closure, &error)) { // Synchronous return; invoke on_credentials_metadata() directly. on_credentials_metadata(batch, error); GRPC_ERROR_UNREF(error); @@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) { call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); if (error != GRPC_ERROR_NONE) { - grpc_channel_security_connector_cancel_check_call_host( - chand->security_connector, &calld->async_result_closure, - GRPC_ERROR_REF(error)); + chand->security_connector->cancel_check_call_host( + &calld->async_result_closure, GRPC_ERROR_REF(error)); } } @@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch( GPR_ASSERT(batch->payload->context != nullptr); if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) { batch->payload->context[GRPC_CONTEXT_SECURITY].value = - grpc_client_security_context_create(calld->arena); + grpc_client_security_context_create(calld->arena, /*creds=*/nullptr); batch->payload->context[GRPC_CONTEXT_SECURITY].destroy = grpc_client_security_context_destroy; } grpc_client_security_context* sec_ctx = static_cast<grpc_client_security_context*>( batch->payload->context[GRPC_CONTEXT_SECURITY].value); - GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter"); + sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); sec_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter"); } if (batch->send_initial_metadata) { @@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch( grpc_schedule_on_exec_ctx); char* call_host = grpc_slice_to_c_string(calld->host); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_channel_security_connector_check_call_host( - chand->security_connector, call_host, chand->auth_context, + if (chand->security_connector->check_call_host( + call_host, chand->auth_context.get(), &calld->async_result_closure, &error)) { // Synchronous return; invoke on_host_checked() directly. on_host_checked(batch, error); @@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem, /* Constructor for channel_data */ static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { + /* The first and the last filters tend to be implemented differently to + handle the case that there's no 'next' filter to call on the up or down + path */ + GPR_ASSERT(!args->is_last); grpc_security_connector* sc = grpc_security_connector_find_in_args(args->channel_args); if (sc == nullptr) { @@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Auth context missing from client auth filter args"); } - - /* grab pointers to our data from the channel element */ - channel_data* chand = static_cast<channel_data*>(elem->channel_data); - - /* The first and the last filters tend to be implemented differently to - handle the case that there's no 'next' filter to call on the up or down - path */ - GPR_ASSERT(!args->is_last); - - /* initialize members */ - chand->security_connector = - reinterpret_cast<grpc_channel_security_connector*>( - GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter")); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter"); + new (elem->channel_data) channel_data( + static_cast<grpc_channel_security_connector*>(sc), auth_context); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { - /* grab pointers to our data from the channel element */ channel_data* chand = static_cast<channel_data*>(elem->channel_data); - grpc_channel_security_connector* sc = chand->security_connector; - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter"); - } - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter"); + chand->~channel_data(); } const grpc_channel_filter grpc_client_auth_filter = { diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc index 34d8435907..14fb55884f 100644 --- a/src/core/lib/security/transport/secure_endpoint.cc +++ b/src/core/lib/security/transport/secure_endpoint.cc @@ -416,6 +416,11 @@ static grpc_resource_user* endpoint_get_resource_user( return grpc_endpoint_get_resource_user(ep->wrapped_ep); } +static bool endpoint_can_track_err(grpc_endpoint* secure_ep) { + secure_endpoint* ep = reinterpret_cast<secure_endpoint*>(secure_ep); + return grpc_endpoint_can_track_err(ep->wrapped_ep); +} + static const grpc_endpoint_vtable vtable = {endpoint_read, endpoint_write, endpoint_add_to_pollset, @@ -425,7 +430,8 @@ static const grpc_endpoint_vtable vtable = {endpoint_read, endpoint_destroy, endpoint_get_resource_user, endpoint_get_peer, - endpoint_get_fd}; + endpoint_get_fd, + endpoint_can_track_err}; grpc_endpoint* grpc_secure_endpoint_create( struct tsi_frame_protector* protector, diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 854a1c4af9..01831dab10 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/transport/secure_endpoint.h" #include "src/core/lib/security/transport/tsi_error.h" @@ -38,34 +39,62 @@ #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 -typedef struct { +namespace { +struct security_handshaker { + security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector); + ~security_handshaker() { + gpr_mu_destroy(&mu); + tsi_handshaker_destroy(handshaker); + tsi_handshaker_result_destroy(handshaker_result); + if (endpoint_to_destroy != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy); + } + if (read_buffer_to_destroy != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy); + gpr_free(read_buffer_to_destroy); + } + gpr_free(handshake_buffer); + grpc_slice_buffer_destroy_internal(&outgoing); + auth_context.reset(DEBUG_LOCATION, "handshake"); + connector.reset(DEBUG_LOCATION, "handshake"); + } + + void Ref() { refs.Ref(); } + void Unref() { + if (refs.Unref()) { + grpc_core::Delete(this); + } + } + grpc_handshaker base; // State set at creation time. tsi_handshaker* handshaker; - grpc_security_connector* connector; + grpc_core::RefCountedPtr<grpc_security_connector> connector; gpr_mu mu; - gpr_refcount refs; + grpc_core::RefCount refs; - bool shutdown; + bool shutdown = false; // Endpoint and read buffer to destroy after a shutdown. - grpc_endpoint* endpoint_to_destroy; - grpc_slice_buffer* read_buffer_to_destroy; + grpc_endpoint* endpoint_to_destroy = nullptr; + grpc_slice_buffer* read_buffer_to_destroy = nullptr; // State saved while performing the handshake. - grpc_handshaker_args* args; - grpc_closure* on_handshake_done; + grpc_handshaker_args* args = nullptr; + grpc_closure* on_handshake_done = nullptr; - unsigned char* handshake_buffer; size_t handshake_buffer_size; + unsigned char* handshake_buffer; grpc_slice_buffer outgoing; grpc_closure on_handshake_data_sent_to_peer; grpc_closure on_handshake_data_received_from_peer; grpc_closure on_peer_checked; - grpc_auth_context* auth_context; - tsi_handshaker_result* handshaker_result; -} security_handshaker; + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; + tsi_handshaker_result* handshaker_result = nullptr; +}; +} // namespace static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { size_t bytes_in_read_buffer = h->args->read_buffer->length; @@ -85,26 +114,6 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { return bytes_in_read_buffer; } -static void security_handshaker_unref(security_handshaker* h) { - if (gpr_unref(&h->refs)) { - gpr_mu_destroy(&h->mu); - tsi_handshaker_destroy(h->handshaker); - tsi_handshaker_result_destroy(h->handshaker_result); - if (h->endpoint_to_destroy != nullptr) { - grpc_endpoint_destroy(h->endpoint_to_destroy); - } - if (h->read_buffer_to_destroy != nullptr) { - grpc_slice_buffer_destroy_internal(h->read_buffer_to_destroy); - gpr_free(h->read_buffer_to_destroy); - } - gpr_free(h->handshake_buffer); - grpc_slice_buffer_destroy_internal(&h->outgoing); - GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake"); - GRPC_SECURITY_CONNECTOR_UNREF(h->connector, "handshake"); - gpr_free(h); - } -} - // Set args fields to NULL, saving the endpoint and read buffer for // later destruction. static void cleanup_args_for_failure_locked(security_handshaker* h) { @@ -194,7 +203,7 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { tsi_handshaker_result_destroy(h->handshaker_result); h->handshaker_result = nullptr; // Add auth context to channel args. - grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context); + grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get()); grpc_channel_args* tmp_args = h->args->args; h->args->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1); @@ -211,7 +220,7 @@ static void on_peer_checked(void* arg, grpc_error* error) { gpr_mu_lock(&h->mu); on_peer_checked_inner(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } static grpc_error* check_peer_locked(security_handshaker* h) { @@ -222,8 +231,8 @@ static grpc_error* check_peer_locked(security_handshaker* h) { return grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); } - grpc_security_connector_check_peer(h->connector, peer, &h->auth_context, - &h->on_peer_checked); + h->connector->check_peer(peer, h->args->endpoint, &h->auth_context, + &h->on_peer_checked); return GRPC_ERROR_NONE; } @@ -281,7 +290,7 @@ static void on_handshake_next_done_grpc_wrapper( if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -317,7 +326,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake read failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // Copy all slices received. @@ -329,7 +338,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -343,7 +352,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake write failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // We may be done. @@ -355,7 +364,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } } @@ -368,7 +377,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { static void security_handshaker_destroy(grpc_handshaker* handshaker) { security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker); - security_handshaker_unref(h); + h->Unref(); } static void security_handshaker_shutdown(grpc_handshaker* handshaker, @@ -393,14 +402,14 @@ static void security_handshaker_do_handshake(grpc_handshaker* handshaker, gpr_mu_lock(&h->mu); h->args = args; h->on_handshake_done = on_handshake_done; - gpr_ref(&h->refs); + h->Ref(); size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); grpc_error* error = do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } gpr_mu_unlock(&h->mu); @@ -410,27 +419,32 @@ static const grpc_handshaker_vtable security_handshaker_vtable = { security_handshaker_destroy, security_handshaker_shutdown, security_handshaker_do_handshake, "security"}; -static grpc_handshaker* security_handshaker_create( - tsi_handshaker* handshaker, grpc_security_connector* connector) { - security_handshaker* h = static_cast<security_handshaker*>( - gpr_zalloc(sizeof(security_handshaker))); - grpc_handshaker_init(&security_handshaker_vtable, &h->base); - h->handshaker = handshaker; - h->connector = GRPC_SECURITY_CONNECTOR_REF(connector, "handshake"); - gpr_mu_init(&h->mu); - gpr_ref_init(&h->refs, 1); - h->handshake_buffer_size = GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE; - h->handshake_buffer = - static_cast<uint8_t*>(gpr_malloc(h->handshake_buffer_size)); - GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer, - on_handshake_data_sent_to_peer, h, +namespace { +security_handshaker::security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector) + : handshaker(handshaker), + connector(connector->Ref(DEBUG_LOCATION, "handshake")), + handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), + handshake_buffer( + static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size))) { + grpc_handshaker_init(&security_handshaker_vtable, &base); + gpr_mu_init(&mu); + grpc_slice_buffer_init(&outgoing); + GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer, + ::on_handshake_data_sent_to_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer, - on_handshake_data_received_from_peer, h, + GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer, + ::on_handshake_data_received_from_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_peer_checked, on_peer_checked, h, + GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this, grpc_schedule_on_exec_ctx); - grpc_slice_buffer_init(&h->outgoing); +} +} // namespace + +static grpc_handshaker* security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector) { + security_handshaker* h = + grpc_core::New<security_handshaker>(handshaker, connector); return &h->base; } @@ -477,8 +491,9 @@ static void client_handshaker_factory_add_handshakers( grpc_channel_security_connector* security_connector = reinterpret_cast<grpc_channel_security_connector*>( grpc_security_connector_find_in_args(args)); - grpc_channel_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void server_handshaker_factory_add_handshakers( @@ -488,8 +503,9 @@ static void server_handshaker_factory_add_handshakers( grpc_server_security_connector* security_connector = reinterpret_cast<grpc_server_security_connector*>( grpc_security_connector_find_in_args(args)); - grpc_server_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void handshaker_factory_destroy( diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 362f49a584..f93eb4275e 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -39,8 +39,12 @@ enum async_state { }; struct channel_data { - grpc_auth_context* auth_context; - grpc_server_credentials* creds; + channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds) + : auth_context(auth_context->Ref()), creds(creds->Ref()) {} + ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); } + + grpc_core::RefCountedPtr<grpc_auth_context> auth_context; + grpc_core::RefCountedPtr<grpc_server_credentials> creds; }; struct call_data { @@ -58,7 +62,7 @@ struct call_data { grpc_server_security_context_create(args.arena); channel_data* chand = static_cast<channel_data*>(elem->channel_data); server_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter"); if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) { args.context[GRPC_CONTEXT_SECURITY].destroy( args.context[GRPC_CONTEXT_SECURITY].value); @@ -208,7 +212,8 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { call_data* calld = static_cast<call_data*>(elem->call_data); grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; if (error == GRPC_ERROR_NONE) { - if (chand->creds != nullptr && chand->creds->processor.process != nullptr) { + if (chand->creds != nullptr && + chand->creds->auth_metadata_processor().process != nullptr) { // We're calling out to the application, so we need to make sure // to drop the call combiner early if we get cancelled. GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem, @@ -218,9 +223,10 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata"); calld->md = metadata_batch_to_md_array( batch->payload->recv_initial_metadata.recv_initial_metadata); - chand->creds->processor.process( - chand->creds->processor.state, chand->auth_context, - calld->md.metadata, calld->md.count, on_md_processing_done, elem); + chand->creds->auth_metadata_processor().process( + chand->creds->auth_metadata_processor().state, + chand->auth_context.get(), calld->md.metadata, calld->md.count, + on_md_processing_done, elem); return; } } @@ -290,23 +296,19 @@ static void destroy_call_elem(grpc_call_element* elem, static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { GPR_ASSERT(!args->is_last); - channel_data* chand = static_cast<channel_data*>(elem->channel_data); grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args->channel_args); GPR_ASSERT(auth_context != nullptr); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter"); grpc_server_credentials* creds = grpc_find_server_credentials_in_args(args->channel_args); - chand->creds = grpc_server_credentials_ref(creds); + new (elem->channel_data) channel_data(auth_context, creds); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { channel_data* chand = static_cast<channel_data*>(elem->channel_data); - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter"); - grpc_server_credentials_unref(chand->creds); + chand->~channel_data(); } const grpc_channel_filter grpc_server_auth_filter = { diff --git a/src/core/lib/surface/call.cc b/src/core/lib/surface/call.cc index 735a78ad08..89b3f77822 100644 --- a/src/core/lib/surface/call.cc +++ b/src/core/lib/surface/call.cc @@ -694,6 +694,10 @@ static void cancel_with_error(grpc_call* c, grpc_error* error) { execute_batch(c, op, &state->start_batch); } +void grpc_call_cancel_internal(grpc_call* call) { + cancel_with_error(call, GRPC_ERROR_CANCELLED); +} + static grpc_error* error_from_status(grpc_status_code status, const char* description) { // copying 'description' is needed to ensure the grpc_call_cancel_with_status diff --git a/src/core/lib/surface/call.h b/src/core/lib/surface/call.h index b34260505a..bd7295fe11 100644 --- a/src/core/lib/surface/call.h +++ b/src/core/lib/surface/call.h @@ -81,6 +81,10 @@ grpc_call_error grpc_call_start_batch_and_execute(grpc_call* call, size_t nops, grpc_closure* closure); +/* gRPC core internal version of grpc_call_cancel that does not create + * exec_ctx. */ +void grpc_call_cancel_internal(grpc_call* call); + /* Given the top call_element, get the call object. */ grpc_call* grpc_call_from_top_element(grpc_call_element* surface_element); diff --git a/src/core/lib/surface/init.cc b/src/core/lib/surface/init.cc index c6198b8ae7..67cf5d89bf 100644 --- a/src/core/lib/surface/init.cc +++ b/src/core/lib/surface/init.cc @@ -161,6 +161,7 @@ void grpc_shutdown(void) { if (--g_initializations == 0) { { grpc_core::ExecCtx exec_ctx(0); + grpc_iomgr_shutdown_background_closure(); { grpc_timer_manager_set_threading( false); // shutdown timer_manager thread diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc index 93b1933809..67b38e6f0c 100644 --- a/src/core/lib/surface/server.cc +++ b/src/core/lib/surface/server.cc @@ -28,6 +28,8 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> +#include <utility> + #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/connected_channel.h" #include "src/core/lib/debug/stats.h" @@ -47,6 +49,10 @@ grpc_core::TraceFlag grpc_server_channel_trace(false, "server_channel"); +static void server_on_recv_initial_metadata(void* ptr, grpc_error* error); +static void server_recv_trailing_metadata_ready(void* user_data, + grpc_error* error); + namespace { struct listener { void* arg; @@ -105,7 +111,7 @@ struct channel_data { uint32_t registered_method_max_probes; grpc_closure finish_destroy_channel_closure; grpc_closure channel_connectivity_changed; - intptr_t socket_uuid; + grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> socket_node; }; typedef struct shutdown_tag { @@ -128,46 +134,73 @@ typedef enum { typedef struct request_matcher request_matcher; struct call_data { + call_data(grpc_call_element* elem, const grpc_call_element_args& args) + : call(grpc_call_from_top_element(elem)), + call_combiner(args.call_combiner) { + GRPC_CLOSURE_INIT(&server_on_recv_initial_metadata, + ::server_on_recv_initial_metadata, elem, + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready, + server_recv_trailing_metadata_ready, elem, + grpc_schedule_on_exec_ctx); + } + ~call_data() { + GPR_ASSERT(state != PENDING); + GRPC_ERROR_UNREF(recv_initial_metadata_error); + if (host_set) { + grpc_slice_unref_internal(host); + } + if (path_set) { + grpc_slice_unref_internal(path); + } + grpc_metadata_array_destroy(&initial_metadata); + grpc_byte_buffer_destroy(payload); + } + grpc_call* call; - gpr_atm state; + gpr_atm state = NOT_STARTED; - bool path_set; - bool host_set; + bool path_set = false; + bool host_set = false; grpc_slice path; grpc_slice host; - grpc_millis deadline; + grpc_millis deadline = GRPC_MILLIS_INF_FUTURE; - grpc_completion_queue* cq_new; + grpc_completion_queue* cq_new = nullptr; - grpc_metadata_batch* recv_initial_metadata; - uint32_t recv_initial_metadata_flags; - grpc_metadata_array initial_metadata; + grpc_metadata_batch* recv_initial_metadata = nullptr; + uint32_t recv_initial_metadata_flags = 0; + grpc_metadata_array initial_metadata = + grpc_metadata_array(); // Zero-initialize the C struct. - request_matcher* matcher; - grpc_byte_buffer* payload; + request_matcher* matcher = nullptr; + grpc_byte_buffer* payload = nullptr; grpc_closure got_initial_metadata; grpc_closure server_on_recv_initial_metadata; grpc_closure kill_zombie_closure; grpc_closure* on_done_recv_initial_metadata; grpc_closure recv_trailing_metadata_ready; - grpc_error* recv_initial_metadata_error; + grpc_error* recv_initial_metadata_error = GRPC_ERROR_NONE; grpc_closure* original_recv_trailing_metadata_ready; - grpc_error* recv_trailing_metadata_error; - bool seen_recv_trailing_metadata_ready; + grpc_error* recv_trailing_metadata_error = GRPC_ERROR_NONE; + bool seen_recv_trailing_metadata_ready = false; grpc_closure publish; - call_data* pending_next; + call_data* pending_next = nullptr; grpc_call_combiner* call_combiner; }; struct request_matcher { + request_matcher(grpc_server* server); + ~request_matcher(); + grpc_server* server; - call_data* pending_head; - call_data* pending_tail; - gpr_locked_mpscq* requests_per_cq; + std::atomic<call_data*> pending_head{nullptr}; + call_data* pending_tail = nullptr; + gpr_locked_mpscq* requests_per_cq = nullptr; }; struct registered_method { @@ -316,22 +349,30 @@ static void channel_broadcaster_shutdown(channel_broadcaster* cb, * request_matcher */ -static void request_matcher_init(request_matcher* rm, grpc_server* server) { - memset(rm, 0, sizeof(*rm)); - rm->server = server; - rm->requests_per_cq = static_cast<gpr_locked_mpscq*>( - gpr_malloc(sizeof(*rm->requests_per_cq) * server->cq_count)); +namespace { +request_matcher::request_matcher(grpc_server* server) : server(server) { + requests_per_cq = static_cast<gpr_locked_mpscq*>( + gpr_malloc(sizeof(*requests_per_cq) * server->cq_count)); for (size_t i = 0; i < server->cq_count; i++) { - gpr_locked_mpscq_init(&rm->requests_per_cq[i]); + gpr_locked_mpscq_init(&requests_per_cq[i]); } } -static void request_matcher_destroy(request_matcher* rm) { - for (size_t i = 0; i < rm->server->cq_count; i++) { - GPR_ASSERT(gpr_locked_mpscq_pop(&rm->requests_per_cq[i]) == nullptr); - gpr_locked_mpscq_destroy(&rm->requests_per_cq[i]); +request_matcher::~request_matcher() { + for (size_t i = 0; i < server->cq_count; i++) { + GPR_ASSERT(gpr_locked_mpscq_pop(&requests_per_cq[i]) == nullptr); + gpr_locked_mpscq_destroy(&requests_per_cq[i]); } - gpr_free(rm->requests_per_cq); + gpr_free(requests_per_cq); +} +} // namespace + +static void request_matcher_init(request_matcher* rm, grpc_server* server) { + new (rm) request_matcher(server); +} + +static void request_matcher_destroy(request_matcher* rm) { + rm->~request_matcher(); } static void kill_zombie(void* elem, grpc_error* error) { @@ -340,9 +381,10 @@ static void kill_zombie(void* elem, grpc_error* error) { } static void request_matcher_zombify_all_pending_calls(request_matcher* rm) { - while (rm->pending_head) { - call_data* calld = rm->pending_head; - rm->pending_head = calld->pending_next; + call_data* calld; + while ((calld = rm->pending_head.load(std::memory_order_relaxed)) != + nullptr) { + rm->pending_head.store(calld->pending_next, std::memory_order_relaxed); gpr_atm_no_barrier_store(&calld->state, ZOMBIED); GRPC_CLOSURE_INIT( &calld->kill_zombie_closure, kill_zombie, @@ -540,8 +582,9 @@ static void publish_new_rpc(void* arg, grpc_error* error) { } gpr_atm_no_barrier_store(&calld->state, PENDING); - if (rm->pending_head == nullptr) { - rm->pending_tail = rm->pending_head = calld; + if (rm->pending_head.load(std::memory_order_relaxed) == nullptr) { + rm->pending_head.store(calld, std::memory_order_relaxed); + rm->pending_tail = calld; } else { rm->pending_tail->pending_next = calld; rm->pending_tail = calld; @@ -875,40 +918,18 @@ static void channel_connectivity_changed(void* cd, grpc_error* error) { static grpc_error* init_call_elem(grpc_call_element* elem, const grpc_call_element_args* args) { - call_data* calld = static_cast<call_data*>(elem->call_data); channel_data* chand = static_cast<channel_data*>(elem->channel_data); - memset(calld, 0, sizeof(call_data)); - calld->deadline = GRPC_MILLIS_INF_FUTURE; - calld->call = grpc_call_from_top_element(elem); - calld->call_combiner = args->call_combiner; - - GRPC_CLOSURE_INIT(&calld->server_on_recv_initial_metadata, - server_on_recv_initial_metadata, elem, - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready, - server_recv_trailing_metadata_ready, elem, - grpc_schedule_on_exec_ctx); server_ref(chand->server); + new (elem->call_data) call_data(elem, *args); return GRPC_ERROR_NONE; } static void destroy_call_elem(grpc_call_element* elem, const grpc_call_final_info* final_info, grpc_closure* ignored) { - channel_data* chand = static_cast<channel_data*>(elem->channel_data); call_data* calld = static_cast<call_data*>(elem->call_data); - - GPR_ASSERT(calld->state != PENDING); - GRPC_ERROR_UNREF(calld->recv_initial_metadata_error); - if (calld->host_set) { - grpc_slice_unref_internal(calld->host); - } - if (calld->path_set) { - grpc_slice_unref_internal(calld->path); - } - grpc_metadata_array_destroy(&calld->initial_metadata); - grpc_byte_buffer_destroy(calld->payload); - + calld->~call_data(); + channel_data* chand = static_cast<channel_data*>(elem->channel_data); server_unref(chand->server); } @@ -931,6 +952,7 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem, static void destroy_channel_elem(grpc_channel_element* elem) { size_t i; channel_data* chand = static_cast<channel_data*>(elem->channel_data); + chand->socket_node.reset(); if (chand->registered_methods) { for (i = 0; i < chand->registered_method_slots; i++) { grpc_slice_unref_internal(chand->registered_methods[i].method); @@ -1136,11 +1158,11 @@ void grpc_server_get_pollsets(grpc_server* server, grpc_pollset*** pollsets, *pollsets = server->pollsets; } -void grpc_server_setup_transport(grpc_server* s, grpc_transport* transport, - grpc_pollset* accepting_pollset, - const grpc_channel_args* args, - intptr_t socket_uuid, - grpc_resource_user* resource_user) { +void grpc_server_setup_transport( + grpc_server* s, grpc_transport* transport, grpc_pollset* accepting_pollset, + const grpc_channel_args* args, + grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> socket_node, + grpc_resource_user* resource_user) { size_t num_registered_methods; size_t alloc; registered_method* rm; @@ -1161,7 +1183,7 @@ void grpc_server_setup_transport(grpc_server* s, grpc_transport* transport, chand->server = s; server_ref(s); chand->channel = channel; - chand->socket_uuid = socket_uuid; + chand->socket_node = std::move(socket_node); size_t cq_idx; for (cq_idx = 0; cq_idx < s->cq_count; cq_idx++) { @@ -1237,14 +1259,13 @@ void grpc_server_setup_transport(grpc_server* s, grpc_transport* transport, } void grpc_server_populate_server_sockets( - grpc_server* s, grpc_core::channelz::ChildRefsList* server_sockets, + grpc_server* s, grpc_core::channelz::ChildSocketsList* server_sockets, intptr_t start_idx) { gpr_mu_lock(&s->mu_global); channel_data* c = nullptr; for (c = s->root_channel_data.next; c != &s->root_channel_data; c = c->next) { - intptr_t socket_uuid = c->socket_uuid; - if (socket_uuid >= start_idx) { - server_sockets->push_back(socket_uuid); + if (c->socket_node != nullptr && c->socket_node->uuid() >= start_idx) { + server_sockets->push_back(c->socket_node.get()); } } gpr_mu_unlock(&s->mu_global); @@ -1427,30 +1448,39 @@ static grpc_call_error queue_call_request(grpc_server* server, size_t cq_idx, rm = &rc->data.registered.method->matcher; break; } - if (gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link)) { - /* this was the first queued request: we need to lock and start - matching calls */ - gpr_mu_lock(&server->mu_call); - while ((calld = rm->pending_head) != nullptr) { - rc = reinterpret_cast<requested_call*>( - gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx])); - if (rc == nullptr) break; - rm->pending_head = calld->pending_next; - gpr_mu_unlock(&server->mu_call); - if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) { - // Zombied Call - GRPC_CLOSURE_INIT( - &calld->kill_zombie_closure, kill_zombie, - grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0), - grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE); - } else { - publish_call(server, calld, cq_idx, rc); - } - gpr_mu_lock(&server->mu_call); - } + + // Fast path: if there is no pending request to be processed, immediately + // return. + if (!gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link) || + // Note: We are reading the pending_head without holding the server's call + // mutex. Even if we read a non-null value here due to reordering, + // we will check it below again after grabbing the lock. + rm->pending_head.load(std::memory_order_relaxed) == nullptr) { + return GRPC_CALL_OK; + } + // Slow path: This was the first queued request and there are pendings: + // We need to lock and start matching calls. + gpr_mu_lock(&server->mu_call); + while ((calld = rm->pending_head.load(std::memory_order_relaxed)) != + nullptr) { + rc = reinterpret_cast<requested_call*>( + gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx])); + if (rc == nullptr) break; + rm->pending_head.store(calld->pending_next, std::memory_order_relaxed); gpr_mu_unlock(&server->mu_call); + if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) { + // Zombied Call + GRPC_CLOSURE_INIT( + &calld->kill_zombie_closure, kill_zombie, + grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0), + grpc_schedule_on_exec_ctx); + GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE); + } else { + publish_call(server, calld, cq_idx, rc); + } + gpr_mu_lock(&server->mu_call); } + gpr_mu_unlock(&server->mu_call); return GRPC_CALL_OK; } diff --git a/src/core/lib/surface/server.h b/src/core/lib/surface/server.h index 27038fdb7a..393bb24214 100644 --- a/src/core/lib/surface/server.h +++ b/src/core/lib/surface/server.h @@ -44,15 +44,15 @@ void grpc_server_add_listener(grpc_server* server, void* listener, /* Setup a transport - creates a channel stack, binds the transport to the server */ -void grpc_server_setup_transport(grpc_server* server, grpc_transport* transport, - grpc_pollset* accepting_pollset, - const grpc_channel_args* args, - intptr_t socket_uuid, - grpc_resource_user* resource_user = nullptr); +void grpc_server_setup_transport( + grpc_server* server, grpc_transport* transport, + grpc_pollset* accepting_pollset, const grpc_channel_args* args, + grpc_core::RefCountedPtr<grpc_core::channelz::SocketNode> socket_node, + grpc_resource_user* resource_user = nullptr); /* fills in the uuids of all sockets used for connections on this server */ void grpc_server_populate_server_sockets( - grpc_server* server, grpc_core::channelz::ChildRefsList* server_sockets, + grpc_server* server, grpc_core::channelz::ChildSocketsList* server_sockets, intptr_t start_idx); /* fills in the uuids of all listen sockets on this server */ diff --git a/src/core/lib/surface/version.cc b/src/core/lib/surface/version.cc index 66890ce65a..4829cc80a5 100644 --- a/src/core/lib/surface/version.cc +++ b/src/core/lib/surface/version.cc @@ -25,4 +25,4 @@ const char* grpc_version_string(void) { return "7.0.0-dev"; } -const char* grpc_g_stands_for(void) { return "gizmo"; } +const char* grpc_g_stands_for(void) { return "goose"; } diff --git a/src/core/lib/transport/static_metadata.cc b/src/core/lib/transport/static_metadata.cc index 4ebe73f82a..3dfaaaad5c 100644 --- a/src/core/lib/transport/static_metadata.cc +++ b/src/core/lib/transport/static_metadata.cc @@ -65,51 +65,56 @@ static uint8_t g_bytes[] = { 97, 110, 99, 101, 114, 47, 66, 97, 108, 97, 110, 99, 101, 76, 111, 97, 100, 47, 103, 114, 112, 99, 46, 104, 101, 97, 108, 116, 104, 46, 118, 49, 46, 72, 101, 97, 108, 116, 104, 47, 87, 97, 116, 99, 104, - 100, 101, 102, 108, 97, 116, 101, 103, 122, 105, 112, 115, 116, 114, 101, - 97, 109, 47, 103, 122, 105, 112, 71, 69, 84, 80, 79, 83, 84, 47, - 47, 105, 110, 100, 101, 120, 46, 104, 116, 109, 108, 104, 116, 116, 112, - 104, 116, 116, 112, 115, 50, 48, 48, 50, 48, 52, 50, 48, 54, 51, - 48, 52, 52, 48, 48, 52, 48, 52, 53, 48, 48, 97, 99, 99, 101, - 112, 116, 45, 99, 104, 97, 114, 115, 101, 116, 103, 122, 105, 112, 44, - 32, 100, 101, 102, 108, 97, 116, 101, 97, 99, 99, 101, 112, 116, 45, - 108, 97, 110, 103, 117, 97, 103, 101, 97, 99, 99, 101, 112, 116, 45, - 114, 97, 110, 103, 101, 115, 97, 99, 99, 101, 112, 116, 97, 99, 99, - 101, 115, 115, 45, 99, 111, 110, 116, 114, 111, 108, 45, 97, 108, 108, - 111, 119, 45, 111, 114, 105, 103, 105, 110, 97, 103, 101, 97, 108, 108, - 111, 119, 97, 117, 116, 104, 111, 114, 105, 122, 97, 116, 105, 111, 110, - 99, 97, 99, 104, 101, 45, 99, 111, 110, 116, 114, 111, 108, 99, 111, - 110, 116, 101, 110, 116, 45, 100, 105, 115, 112, 111, 115, 105, 116, 105, - 111, 110, 99, 111, 110, 116, 101, 110, 116, 45, 108, 97, 110, 103, 117, - 97, 103, 101, 99, 111, 110, 116, 101, 110, 116, 45, 108, 101, 110, 103, - 116, 104, 99, 111, 110, 116, 101, 110, 116, 45, 108, 111, 99, 97, 116, - 105, 111, 110, 99, 111, 110, 116, 101, 110, 116, 45, 114, 97, 110, 103, - 101, 99, 111, 111, 107, 105, 101, 100, 97, 116, 101, 101, 116, 97, 103, - 101, 120, 112, 101, 99, 116, 101, 120, 112, 105, 114, 101, 115, 102, 114, - 111, 109, 105, 102, 45, 109, 97, 116, 99, 104, 105, 102, 45, 109, 111, - 100, 105, 102, 105, 101, 100, 45, 115, 105, 110, 99, 101, 105, 102, 45, - 110, 111, 110, 101, 45, 109, 97, 116, 99, 104, 105, 102, 45, 114, 97, - 110, 103, 101, 105, 102, 45, 117, 110, 109, 111, 100, 105, 102, 105, 101, - 100, 45, 115, 105, 110, 99, 101, 108, 97, 115, 116, 45, 109, 111, 100, - 105, 102, 105, 101, 100, 108, 105, 110, 107, 108, 111, 99, 97, 116, 105, - 111, 110, 109, 97, 120, 45, 102, 111, 114, 119, 97, 114, 100, 115, 112, - 114, 111, 120, 121, 45, 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, - 116, 101, 112, 114, 111, 120, 121, 45, 97, 117, 116, 104, 111, 114, 105, - 122, 97, 116, 105, 111, 110, 114, 97, 110, 103, 101, 114, 101, 102, 101, - 114, 101, 114, 114, 101, 102, 114, 101, 115, 104, 114, 101, 116, 114, 121, - 45, 97, 102, 116, 101, 114, 115, 101, 114, 118, 101, 114, 115, 101, 116, - 45, 99, 111, 111, 107, 105, 101, 115, 116, 114, 105, 99, 116, 45, 116, - 114, 97, 110, 115, 112, 111, 114, 116, 45, 115, 101, 99, 117, 114, 105, - 116, 121, 116, 114, 97, 110, 115, 102, 101, 114, 45, 101, 110, 99, 111, - 100, 105, 110, 103, 118, 97, 114, 121, 118, 105, 97, 119, 119, 119, 45, - 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 101, 48, 105, 100, - 101, 110, 116, 105, 116, 121, 116, 114, 97, 105, 108, 101, 114, 115, 97, - 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 103, 114, 112, 99, - 103, 114, 112, 99, 80, 85, 84, 108, 98, 45, 99, 111, 115, 116, 45, - 98, 105, 110, 105, 100, 101, 110, 116, 105, 116, 121, 44, 100, 101, 102, - 108, 97, 116, 101, 105, 100, 101, 110, 116, 105, 116, 121, 44, 103, 122, - 105, 112, 100, 101, 102, 108, 97, 116, 101, 44, 103, 122, 105, 112, 105, - 100, 101, 110, 116, 105, 116, 121, 44, 100, 101, 102, 108, 97, 116, 101, - 44, 103, 122, 105, 112}; + 47, 101, 110, 118, 111, 121, 46, 115, 101, 114, 118, 105, 99, 101, 46, + 100, 105, 115, 99, 111, 118, 101, 114, 121, 46, 118, 50, 46, 65, 103, + 103, 114, 101, 103, 97, 116, 101, 100, 68, 105, 115, 99, 111, 118, 101, + 114, 121, 83, 101, 114, 118, 105, 99, 101, 47, 83, 116, 114, 101, 97, + 109, 65, 103, 103, 114, 101, 103, 97, 116, 101, 100, 82, 101, 115, 111, + 117, 114, 99, 101, 115, 100, 101, 102, 108, 97, 116, 101, 103, 122, 105, + 112, 115, 116, 114, 101, 97, 109, 47, 103, 122, 105, 112, 71, 69, 84, + 80, 79, 83, 84, 47, 47, 105, 110, 100, 101, 120, 46, 104, 116, 109, + 108, 104, 116, 116, 112, 104, 116, 116, 112, 115, 50, 48, 48, 50, 48, + 52, 50, 48, 54, 51, 48, 52, 52, 48, 48, 52, 48, 52, 53, 48, + 48, 97, 99, 99, 101, 112, 116, 45, 99, 104, 97, 114, 115, 101, 116, + 103, 122, 105, 112, 44, 32, 100, 101, 102, 108, 97, 116, 101, 97, 99, + 99, 101, 112, 116, 45, 108, 97, 110, 103, 117, 97, 103, 101, 97, 99, + 99, 101, 112, 116, 45, 114, 97, 110, 103, 101, 115, 97, 99, 99, 101, + 112, 116, 97, 99, 99, 101, 115, 115, 45, 99, 111, 110, 116, 114, 111, + 108, 45, 97, 108, 108, 111, 119, 45, 111, 114, 105, 103, 105, 110, 97, + 103, 101, 97, 108, 108, 111, 119, 97, 117, 116, 104, 111, 114, 105, 122, + 97, 116, 105, 111, 110, 99, 97, 99, 104, 101, 45, 99, 111, 110, 116, + 114, 111, 108, 99, 111, 110, 116, 101, 110, 116, 45, 100, 105, 115, 112, + 111, 115, 105, 116, 105, 111, 110, 99, 111, 110, 116, 101, 110, 116, 45, + 108, 97, 110, 103, 117, 97, 103, 101, 99, 111, 110, 116, 101, 110, 116, + 45, 108, 101, 110, 103, 116, 104, 99, 111, 110, 116, 101, 110, 116, 45, + 108, 111, 99, 97, 116, 105, 111, 110, 99, 111, 110, 116, 101, 110, 116, + 45, 114, 97, 110, 103, 101, 99, 111, 111, 107, 105, 101, 100, 97, 116, + 101, 101, 116, 97, 103, 101, 120, 112, 101, 99, 116, 101, 120, 112, 105, + 114, 101, 115, 102, 114, 111, 109, 105, 102, 45, 109, 97, 116, 99, 104, + 105, 102, 45, 109, 111, 100, 105, 102, 105, 101, 100, 45, 115, 105, 110, + 99, 101, 105, 102, 45, 110, 111, 110, 101, 45, 109, 97, 116, 99, 104, + 105, 102, 45, 114, 97, 110, 103, 101, 105, 102, 45, 117, 110, 109, 111, + 100, 105, 102, 105, 101, 100, 45, 115, 105, 110, 99, 101, 108, 97, 115, + 116, 45, 109, 111, 100, 105, 102, 105, 101, 100, 108, 105, 110, 107, 108, + 111, 99, 97, 116, 105, 111, 110, 109, 97, 120, 45, 102, 111, 114, 119, + 97, 114, 100, 115, 112, 114, 111, 120, 121, 45, 97, 117, 116, 104, 101, + 110, 116, 105, 99, 97, 116, 101, 112, 114, 111, 120, 121, 45, 97, 117, + 116, 104, 111, 114, 105, 122, 97, 116, 105, 111, 110, 114, 97, 110, 103, + 101, 114, 101, 102, 101, 114, 101, 114, 114, 101, 102, 114, 101, 115, 104, + 114, 101, 116, 114, 121, 45, 97, 102, 116, 101, 114, 115, 101, 114, 118, + 101, 114, 115, 101, 116, 45, 99, 111, 111, 107, 105, 101, 115, 116, 114, + 105, 99, 116, 45, 116, 114, 97, 110, 115, 112, 111, 114, 116, 45, 115, + 101, 99, 117, 114, 105, 116, 121, 116, 114, 97, 110, 115, 102, 101, 114, + 45, 101, 110, 99, 111, 100, 105, 110, 103, 118, 97, 114, 121, 118, 105, + 97, 119, 119, 119, 45, 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, + 116, 101, 48, 105, 100, 101, 110, 116, 105, 116, 121, 116, 114, 97, 105, + 108, 101, 114, 115, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, + 47, 103, 114, 112, 99, 103, 114, 112, 99, 80, 85, 84, 108, 98, 45, + 99, 111, 115, 116, 45, 98, 105, 110, 105, 100, 101, 110, 116, 105, 116, + 121, 44, 100, 101, 102, 108, 97, 116, 101, 105, 100, 101, 110, 116, 105, + 116, 121, 44, 103, 122, 105, 112, 100, 101, 102, 108, 97, 116, 101, 44, + 103, 122, 105, 112, 105, 100, 101, 110, 116, 105, 116, 121, 44, 100, 101, + 102, 108, 97, 116, 101, 44, 103, 122, 105, 112}; static void static_ref(void* unused) {} static void static_unref(void* unused) {} @@ -227,6 +232,7 @@ grpc_slice_refcount grpc_static_metadata_refcounts[GRPC_STATIC_MDSTR_COUNT] = { {&grpc_static_metadata_vtable, &static_sub_refcnt}, {&grpc_static_metadata_vtable, &static_sub_refcnt}, {&grpc_static_metadata_vtable, &static_sub_refcnt}, + {&grpc_static_metadata_vtable, &static_sub_refcnt}, }; const grpc_slice grpc_static_slice_table[GRPC_STATIC_MDSTR_COUNT] = { @@ -266,76 +272,77 @@ const grpc_slice grpc_static_slice_table[GRPC_STATIC_MDSTR_COUNT] = { {&grpc_static_metadata_refcounts[33], {{g_bytes + 415, 31}}}, {&grpc_static_metadata_refcounts[34], {{g_bytes + 446, 36}}}, {&grpc_static_metadata_refcounts[35], {{g_bytes + 482, 28}}}, - {&grpc_static_metadata_refcounts[36], {{g_bytes + 510, 7}}}, - {&grpc_static_metadata_refcounts[37], {{g_bytes + 517, 4}}}, - {&grpc_static_metadata_refcounts[38], {{g_bytes + 521, 11}}}, - {&grpc_static_metadata_refcounts[39], {{g_bytes + 532, 3}}}, - {&grpc_static_metadata_refcounts[40], {{g_bytes + 535, 4}}}, - {&grpc_static_metadata_refcounts[41], {{g_bytes + 539, 1}}}, - {&grpc_static_metadata_refcounts[42], {{g_bytes + 540, 11}}}, - {&grpc_static_metadata_refcounts[43], {{g_bytes + 551, 4}}}, - {&grpc_static_metadata_refcounts[44], {{g_bytes + 555, 5}}}, - {&grpc_static_metadata_refcounts[45], {{g_bytes + 560, 3}}}, - {&grpc_static_metadata_refcounts[46], {{g_bytes + 563, 3}}}, - {&grpc_static_metadata_refcounts[47], {{g_bytes + 566, 3}}}, - {&grpc_static_metadata_refcounts[48], {{g_bytes + 569, 3}}}, - {&grpc_static_metadata_refcounts[49], {{g_bytes + 572, 3}}}, - {&grpc_static_metadata_refcounts[50], {{g_bytes + 575, 3}}}, - {&grpc_static_metadata_refcounts[51], {{g_bytes + 578, 3}}}, - {&grpc_static_metadata_refcounts[52], {{g_bytes + 581, 14}}}, - {&grpc_static_metadata_refcounts[53], {{g_bytes + 595, 13}}}, - {&grpc_static_metadata_refcounts[54], {{g_bytes + 608, 15}}}, - {&grpc_static_metadata_refcounts[55], {{g_bytes + 623, 13}}}, - {&grpc_static_metadata_refcounts[56], {{g_bytes + 636, 6}}}, - {&grpc_static_metadata_refcounts[57], {{g_bytes + 642, 27}}}, - {&grpc_static_metadata_refcounts[58], {{g_bytes + 669, 3}}}, - {&grpc_static_metadata_refcounts[59], {{g_bytes + 672, 5}}}, - {&grpc_static_metadata_refcounts[60], {{g_bytes + 677, 13}}}, - {&grpc_static_metadata_refcounts[61], {{g_bytes + 690, 13}}}, - {&grpc_static_metadata_refcounts[62], {{g_bytes + 703, 19}}}, - {&grpc_static_metadata_refcounts[63], {{g_bytes + 722, 16}}}, - {&grpc_static_metadata_refcounts[64], {{g_bytes + 738, 14}}}, - {&grpc_static_metadata_refcounts[65], {{g_bytes + 752, 16}}}, - {&grpc_static_metadata_refcounts[66], {{g_bytes + 768, 13}}}, - {&grpc_static_metadata_refcounts[67], {{g_bytes + 781, 6}}}, - {&grpc_static_metadata_refcounts[68], {{g_bytes + 787, 4}}}, - {&grpc_static_metadata_refcounts[69], {{g_bytes + 791, 4}}}, - {&grpc_static_metadata_refcounts[70], {{g_bytes + 795, 6}}}, - {&grpc_static_metadata_refcounts[71], {{g_bytes + 801, 7}}}, - {&grpc_static_metadata_refcounts[72], {{g_bytes + 808, 4}}}, - {&grpc_static_metadata_refcounts[73], {{g_bytes + 812, 8}}}, - {&grpc_static_metadata_refcounts[74], {{g_bytes + 820, 17}}}, - {&grpc_static_metadata_refcounts[75], {{g_bytes + 837, 13}}}, - {&grpc_static_metadata_refcounts[76], {{g_bytes + 850, 8}}}, - {&grpc_static_metadata_refcounts[77], {{g_bytes + 858, 19}}}, - {&grpc_static_metadata_refcounts[78], {{g_bytes + 877, 13}}}, - {&grpc_static_metadata_refcounts[79], {{g_bytes + 890, 4}}}, - {&grpc_static_metadata_refcounts[80], {{g_bytes + 894, 8}}}, - {&grpc_static_metadata_refcounts[81], {{g_bytes + 902, 12}}}, - {&grpc_static_metadata_refcounts[82], {{g_bytes + 914, 18}}}, - {&grpc_static_metadata_refcounts[83], {{g_bytes + 932, 19}}}, - {&grpc_static_metadata_refcounts[84], {{g_bytes + 951, 5}}}, - {&grpc_static_metadata_refcounts[85], {{g_bytes + 956, 7}}}, - {&grpc_static_metadata_refcounts[86], {{g_bytes + 963, 7}}}, - {&grpc_static_metadata_refcounts[87], {{g_bytes + 970, 11}}}, - {&grpc_static_metadata_refcounts[88], {{g_bytes + 981, 6}}}, - {&grpc_static_metadata_refcounts[89], {{g_bytes + 987, 10}}}, - {&grpc_static_metadata_refcounts[90], {{g_bytes + 997, 25}}}, - {&grpc_static_metadata_refcounts[91], {{g_bytes + 1022, 17}}}, - {&grpc_static_metadata_refcounts[92], {{g_bytes + 1039, 4}}}, - {&grpc_static_metadata_refcounts[93], {{g_bytes + 1043, 3}}}, - {&grpc_static_metadata_refcounts[94], {{g_bytes + 1046, 16}}}, - {&grpc_static_metadata_refcounts[95], {{g_bytes + 1062, 1}}}, - {&grpc_static_metadata_refcounts[96], {{g_bytes + 1063, 8}}}, - {&grpc_static_metadata_refcounts[97], {{g_bytes + 1071, 8}}}, - {&grpc_static_metadata_refcounts[98], {{g_bytes + 1079, 16}}}, - {&grpc_static_metadata_refcounts[99], {{g_bytes + 1095, 4}}}, - {&grpc_static_metadata_refcounts[100], {{g_bytes + 1099, 3}}}, - {&grpc_static_metadata_refcounts[101], {{g_bytes + 1102, 11}}}, - {&grpc_static_metadata_refcounts[102], {{g_bytes + 1113, 16}}}, - {&grpc_static_metadata_refcounts[103], {{g_bytes + 1129, 13}}}, - {&grpc_static_metadata_refcounts[104], {{g_bytes + 1142, 12}}}, - {&grpc_static_metadata_refcounts[105], {{g_bytes + 1154, 21}}}, + {&grpc_static_metadata_refcounts[36], {{g_bytes + 510, 80}}}, + {&grpc_static_metadata_refcounts[37], {{g_bytes + 590, 7}}}, + {&grpc_static_metadata_refcounts[38], {{g_bytes + 597, 4}}}, + {&grpc_static_metadata_refcounts[39], {{g_bytes + 601, 11}}}, + {&grpc_static_metadata_refcounts[40], {{g_bytes + 612, 3}}}, + {&grpc_static_metadata_refcounts[41], {{g_bytes + 615, 4}}}, + {&grpc_static_metadata_refcounts[42], {{g_bytes + 619, 1}}}, + {&grpc_static_metadata_refcounts[43], {{g_bytes + 620, 11}}}, + {&grpc_static_metadata_refcounts[44], {{g_bytes + 631, 4}}}, + {&grpc_static_metadata_refcounts[45], {{g_bytes + 635, 5}}}, + {&grpc_static_metadata_refcounts[46], {{g_bytes + 640, 3}}}, + {&grpc_static_metadata_refcounts[47], {{g_bytes + 643, 3}}}, + {&grpc_static_metadata_refcounts[48], {{g_bytes + 646, 3}}}, + {&grpc_static_metadata_refcounts[49], {{g_bytes + 649, 3}}}, + {&grpc_static_metadata_refcounts[50], {{g_bytes + 652, 3}}}, + {&grpc_static_metadata_refcounts[51], {{g_bytes + 655, 3}}}, + {&grpc_static_metadata_refcounts[52], {{g_bytes + 658, 3}}}, + {&grpc_static_metadata_refcounts[53], {{g_bytes + 661, 14}}}, + {&grpc_static_metadata_refcounts[54], {{g_bytes + 675, 13}}}, + {&grpc_static_metadata_refcounts[55], {{g_bytes + 688, 15}}}, + {&grpc_static_metadata_refcounts[56], {{g_bytes + 703, 13}}}, + {&grpc_static_metadata_refcounts[57], {{g_bytes + 716, 6}}}, + {&grpc_static_metadata_refcounts[58], {{g_bytes + 722, 27}}}, + {&grpc_static_metadata_refcounts[59], {{g_bytes + 749, 3}}}, + {&grpc_static_metadata_refcounts[60], {{g_bytes + 752, 5}}}, + {&grpc_static_metadata_refcounts[61], {{g_bytes + 757, 13}}}, + {&grpc_static_metadata_refcounts[62], {{g_bytes + 770, 13}}}, + {&grpc_static_metadata_refcounts[63], {{g_bytes + 783, 19}}}, + {&grpc_static_metadata_refcounts[64], {{g_bytes + 802, 16}}}, + {&grpc_static_metadata_refcounts[65], {{g_bytes + 818, 14}}}, + {&grpc_static_metadata_refcounts[66], {{g_bytes + 832, 16}}}, + {&grpc_static_metadata_refcounts[67], {{g_bytes + 848, 13}}}, + {&grpc_static_metadata_refcounts[68], {{g_bytes + 861, 6}}}, + {&grpc_static_metadata_refcounts[69], {{g_bytes + 867, 4}}}, + {&grpc_static_metadata_refcounts[70], {{g_bytes + 871, 4}}}, + {&grpc_static_metadata_refcounts[71], {{g_bytes + 875, 6}}}, + {&grpc_static_metadata_refcounts[72], {{g_bytes + 881, 7}}}, + {&grpc_static_metadata_refcounts[73], {{g_bytes + 888, 4}}}, + {&grpc_static_metadata_refcounts[74], {{g_bytes + 892, 8}}}, + {&grpc_static_metadata_refcounts[75], {{g_bytes + 900, 17}}}, + {&grpc_static_metadata_refcounts[76], {{g_bytes + 917, 13}}}, + {&grpc_static_metadata_refcounts[77], {{g_bytes + 930, 8}}}, + {&grpc_static_metadata_refcounts[78], {{g_bytes + 938, 19}}}, + {&grpc_static_metadata_refcounts[79], {{g_bytes + 957, 13}}}, + {&grpc_static_metadata_refcounts[80], {{g_bytes + 970, 4}}}, + {&grpc_static_metadata_refcounts[81], {{g_bytes + 974, 8}}}, + {&grpc_static_metadata_refcounts[82], {{g_bytes + 982, 12}}}, + {&grpc_static_metadata_refcounts[83], {{g_bytes + 994, 18}}}, + {&grpc_static_metadata_refcounts[84], {{g_bytes + 1012, 19}}}, + {&grpc_static_metadata_refcounts[85], {{g_bytes + 1031, 5}}}, + {&grpc_static_metadata_refcounts[86], {{g_bytes + 1036, 7}}}, + {&grpc_static_metadata_refcounts[87], {{g_bytes + 1043, 7}}}, + {&grpc_static_metadata_refcounts[88], {{g_bytes + 1050, 11}}}, + {&grpc_static_metadata_refcounts[89], {{g_bytes + 1061, 6}}}, + {&grpc_static_metadata_refcounts[90], {{g_bytes + 1067, 10}}}, + {&grpc_static_metadata_refcounts[91], {{g_bytes + 1077, 25}}}, + {&grpc_static_metadata_refcounts[92], {{g_bytes + 1102, 17}}}, + {&grpc_static_metadata_refcounts[93], {{g_bytes + 1119, 4}}}, + {&grpc_static_metadata_refcounts[94], {{g_bytes + 1123, 3}}}, + {&grpc_static_metadata_refcounts[95], {{g_bytes + 1126, 16}}}, + {&grpc_static_metadata_refcounts[96], {{g_bytes + 1142, 1}}}, + {&grpc_static_metadata_refcounts[97], {{g_bytes + 1143, 8}}}, + {&grpc_static_metadata_refcounts[98], {{g_bytes + 1151, 8}}}, + {&grpc_static_metadata_refcounts[99], {{g_bytes + 1159, 16}}}, + {&grpc_static_metadata_refcounts[100], {{g_bytes + 1175, 4}}}, + {&grpc_static_metadata_refcounts[101], {{g_bytes + 1179, 3}}}, + {&grpc_static_metadata_refcounts[102], {{g_bytes + 1182, 11}}}, + {&grpc_static_metadata_refcounts[103], {{g_bytes + 1193, 16}}}, + {&grpc_static_metadata_refcounts[104], {{g_bytes + 1209, 13}}}, + {&grpc_static_metadata_refcounts[105], {{g_bytes + 1222, 12}}}, + {&grpc_static_metadata_refcounts[106], {{g_bytes + 1234, 21}}}, }; uintptr_t grpc_static_mdelem_user_data[GRPC_STATIC_MDELEM_COUNT] = { @@ -345,17 +352,17 @@ uintptr_t grpc_static_mdelem_user_data[GRPC_STATIC_MDELEM_COUNT] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 4, 6, 6, 8, 8, 2, 4, 4}; static const int8_t elems_r[] = { - 16, 11, -8, 0, 3, -42, -81, -43, 0, 6, -8, 0, 0, 0, -7, - -3, -10, 0, 0, 0, -1, -2, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, -63, 0, -47, -68, -69, -70, 0, 33, - 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 20, - 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, - 4, 4, 4, 3, 10, 9, 0, 0, 0, 0, 0, 0, -3, 0}; + 15, 10, -8, 0, 2, -42, -81, -43, 0, 6, -8, 0, 0, 0, 2, + -3, -10, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, -64, 0, -67, -68, -69, -70, 0, + 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, + 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 5, 4, 5, 4, 4, 8, 8, 0, 0, 0, 0, 0, 0, -5, 0}; static uint32_t elems_phash(uint32_t i) { - i -= 41; - uint32_t x = i % 104; - uint32_t y = i / 104; + i -= 42; + uint32_t x = i % 105; + uint32_t y = i / 105; uint32_t h = x; if (y < GPR_ARRAY_SIZE(elems_r)) { uint32_t delta = (uint32_t)elems_r[y]; @@ -365,29 +372,29 @@ static uint32_t elems_phash(uint32_t i) { } static const uint16_t elem_keys[] = { - 257, 258, 259, 260, 261, 262, 263, 1096, 1097, 1513, 1725, 145, - 146, 467, 468, 1619, 41, 42, 1733, 990, 991, 767, 768, 1627, - 627, 837, 2043, 2149, 2255, 5541, 5859, 5965, 6071, 6177, 1749, 6283, - 6389, 6495, 6601, 6707, 6813, 6919, 7025, 7131, 7237, 7343, 7449, 7555, - 7661, 5753, 7767, 7873, 7979, 8085, 8191, 8297, 8403, 8509, 8615, 8721, - 8827, 8933, 9039, 9145, 9251, 9357, 9463, 1156, 9569, 523, 9675, 9781, - 206, 1162, 1163, 1164, 1165, 1792, 1582, 1050, 9887, 9993, 1686, 10735, - 1799, 0, 0, 0, 0, 0, 347, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0}; + 260, 261, 262, 263, 264, 265, 266, 1107, 1108, 1741, 147, 148, + 472, 473, 1634, 42, 43, 1527, 1750, 1000, 1001, 774, 775, 1643, + 633, 845, 2062, 2169, 2276, 5700, 5914, 6021, 6128, 6235, 1766, 6342, + 6449, 6556, 6663, 6770, 6877, 6984, 7091, 7198, 7305, 7412, 7519, 7626, + 7733, 7840, 7947, 8054, 8161, 8268, 8375, 8482, 8589, 8696, 8803, 8910, + 9017, 9124, 9231, 9338, 9445, 9552, 9659, 1167, 528, 9766, 9873, 208, + 9980, 1173, 1174, 1175, 1176, 1809, 10087, 1060, 10194, 10943, 1702, 0, + 1816, 0, 0, 1597, 0, 0, 350, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0}; static const uint8_t elem_idxs[] = { - 7, 8, 9, 10, 11, 12, 13, 77, 79, 30, 71, 1, 2, 5, 6, 25, - 3, 4, 84, 66, 65, 62, 63, 73, 67, 61, 57, 37, 74, 14, 17, 18, - 19, 20, 15, 21, 22, 23, 24, 26, 27, 28, 29, 31, 32, 33, 34, 35, - 36, 16, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, - 52, 53, 54, 76, 55, 69, 56, 58, 70, 78, 80, 81, 82, 83, 68, 64, - 59, 60, 72, 75, 85, 255, 255, 255, 255, 255, 0}; + 7, 8, 9, 10, 11, 12, 13, 77, 79, 71, 1, 2, 5, 6, 25, 3, + 4, 30, 84, 66, 65, 62, 63, 73, 67, 61, 57, 37, 74, 14, 16, 17, + 18, 19, 15, 20, 21, 22, 23, 24, 26, 27, 28, 29, 31, 32, 33, 34, + 35, 36, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, + 52, 53, 54, 76, 69, 55, 56, 70, 58, 78, 80, 81, 82, 83, 59, 64, + 60, 75, 72, 255, 85, 255, 255, 68, 255, 255, 0}; grpc_mdelem grpc_static_mdelem_for_static_strings(int a, int b) { if (a == -1 || b == -1) return GRPC_MDNULL; - uint32_t k = (uint32_t)(a * 106 + b); + uint32_t k = (uint32_t)(a * 107 + b); uint32_t h = elems_phash(k); return h < GPR_ARRAY_SIZE(elem_keys) && elem_keys[h] == k && elem_idxs[h] != 255 @@ -400,175 +407,175 @@ grpc_mdelem_data grpc_static_mdelem_table[GRPC_STATIC_MDELEM_COUNT] = { {{&grpc_static_metadata_refcounts[3], {{g_bytes + 19, 10}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[1], {{g_bytes + 5, 7}}}, - {&grpc_static_metadata_refcounts[39], {{g_bytes + 532, 3}}}}, + {&grpc_static_metadata_refcounts[40], {{g_bytes + 612, 3}}}}, {{&grpc_static_metadata_refcounts[1], {{g_bytes + 5, 7}}}, - {&grpc_static_metadata_refcounts[40], {{g_bytes + 535, 4}}}}, + {&grpc_static_metadata_refcounts[41], {{g_bytes + 615, 4}}}}, {{&grpc_static_metadata_refcounts[0], {{g_bytes + 0, 5}}}, - {&grpc_static_metadata_refcounts[41], {{g_bytes + 539, 1}}}}, + {&grpc_static_metadata_refcounts[42], {{g_bytes + 619, 1}}}}, {{&grpc_static_metadata_refcounts[0], {{g_bytes + 0, 5}}}, - {&grpc_static_metadata_refcounts[42], {{g_bytes + 540, 11}}}}, + {&grpc_static_metadata_refcounts[43], {{g_bytes + 620, 11}}}}, {{&grpc_static_metadata_refcounts[4], {{g_bytes + 29, 7}}}, - {&grpc_static_metadata_refcounts[43], {{g_bytes + 551, 4}}}}, + {&grpc_static_metadata_refcounts[44], {{g_bytes + 631, 4}}}}, {{&grpc_static_metadata_refcounts[4], {{g_bytes + 29, 7}}}, - {&grpc_static_metadata_refcounts[44], {{g_bytes + 555, 5}}}}, + {&grpc_static_metadata_refcounts[45], {{g_bytes + 635, 5}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[45], {{g_bytes + 560, 3}}}}, + {&grpc_static_metadata_refcounts[46], {{g_bytes + 640, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[46], {{g_bytes + 563, 3}}}}, + {&grpc_static_metadata_refcounts[47], {{g_bytes + 643, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[47], {{g_bytes + 566, 3}}}}, + {&grpc_static_metadata_refcounts[48], {{g_bytes + 646, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[48], {{g_bytes + 569, 3}}}}, + {&grpc_static_metadata_refcounts[49], {{g_bytes + 649, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[49], {{g_bytes + 572, 3}}}}, + {&grpc_static_metadata_refcounts[50], {{g_bytes + 652, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[50], {{g_bytes + 575, 3}}}}, + {&grpc_static_metadata_refcounts[51], {{g_bytes + 655, 3}}}}, {{&grpc_static_metadata_refcounts[2], {{g_bytes + 12, 7}}}, - {&grpc_static_metadata_refcounts[51], {{g_bytes + 578, 3}}}}, - {{&grpc_static_metadata_refcounts[52], {{g_bytes + 581, 14}}}, + {&grpc_static_metadata_refcounts[52], {{g_bytes + 658, 3}}}}, + {{&grpc_static_metadata_refcounts[53], {{g_bytes + 661, 14}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[16], {{g_bytes + 186, 15}}}, - {&grpc_static_metadata_refcounts[53], {{g_bytes + 595, 13}}}}, - {{&grpc_static_metadata_refcounts[54], {{g_bytes + 608, 15}}}, + {&grpc_static_metadata_refcounts[54], {{g_bytes + 675, 13}}}}, + {{&grpc_static_metadata_refcounts[55], {{g_bytes + 688, 15}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[55], {{g_bytes + 623, 13}}}, + {{&grpc_static_metadata_refcounts[56], {{g_bytes + 703, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[56], {{g_bytes + 636, 6}}}, + {{&grpc_static_metadata_refcounts[57], {{g_bytes + 716, 6}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[57], {{g_bytes + 642, 27}}}, + {{&grpc_static_metadata_refcounts[58], {{g_bytes + 722, 27}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[58], {{g_bytes + 669, 3}}}, + {{&grpc_static_metadata_refcounts[59], {{g_bytes + 749, 3}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[59], {{g_bytes + 672, 5}}}, + {{&grpc_static_metadata_refcounts[60], {{g_bytes + 752, 5}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[60], {{g_bytes + 677, 13}}}, + {{&grpc_static_metadata_refcounts[61], {{g_bytes + 757, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[61], {{g_bytes + 690, 13}}}, + {{&grpc_static_metadata_refcounts[62], {{g_bytes + 770, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[62], {{g_bytes + 703, 19}}}, + {{&grpc_static_metadata_refcounts[63], {{g_bytes + 783, 19}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[15], {{g_bytes + 170, 16}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[63], {{g_bytes + 722, 16}}}, + {{&grpc_static_metadata_refcounts[64], {{g_bytes + 802, 16}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[64], {{g_bytes + 738, 14}}}, + {{&grpc_static_metadata_refcounts[65], {{g_bytes + 818, 14}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[65], {{g_bytes + 752, 16}}}, + {{&grpc_static_metadata_refcounts[66], {{g_bytes + 832, 16}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[66], {{g_bytes + 768, 13}}}, + {{&grpc_static_metadata_refcounts[67], {{g_bytes + 848, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[14], {{g_bytes + 158, 12}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[67], {{g_bytes + 781, 6}}}, + {{&grpc_static_metadata_refcounts[68], {{g_bytes + 861, 6}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[68], {{g_bytes + 787, 4}}}, + {{&grpc_static_metadata_refcounts[69], {{g_bytes + 867, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[69], {{g_bytes + 791, 4}}}, + {{&grpc_static_metadata_refcounts[70], {{g_bytes + 871, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[70], {{g_bytes + 795, 6}}}, + {{&grpc_static_metadata_refcounts[71], {{g_bytes + 875, 6}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[71], {{g_bytes + 801, 7}}}, + {{&grpc_static_metadata_refcounts[72], {{g_bytes + 881, 7}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[72], {{g_bytes + 808, 4}}}, + {{&grpc_static_metadata_refcounts[73], {{g_bytes + 888, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[20], {{g_bytes + 278, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[73], {{g_bytes + 812, 8}}}, + {{&grpc_static_metadata_refcounts[74], {{g_bytes + 892, 8}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[74], {{g_bytes + 820, 17}}}, + {{&grpc_static_metadata_refcounts[75], {{g_bytes + 900, 17}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[75], {{g_bytes + 837, 13}}}, + {{&grpc_static_metadata_refcounts[76], {{g_bytes + 917, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[76], {{g_bytes + 850, 8}}}, + {{&grpc_static_metadata_refcounts[77], {{g_bytes + 930, 8}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[77], {{g_bytes + 858, 19}}}, + {{&grpc_static_metadata_refcounts[78], {{g_bytes + 938, 19}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[78], {{g_bytes + 877, 13}}}, + {{&grpc_static_metadata_refcounts[79], {{g_bytes + 957, 13}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[79], {{g_bytes + 890, 4}}}, + {{&grpc_static_metadata_refcounts[80], {{g_bytes + 970, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[80], {{g_bytes + 894, 8}}}, + {{&grpc_static_metadata_refcounts[81], {{g_bytes + 974, 8}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[81], {{g_bytes + 902, 12}}}, + {{&grpc_static_metadata_refcounts[82], {{g_bytes + 982, 12}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[82], {{g_bytes + 914, 18}}}, + {{&grpc_static_metadata_refcounts[83], {{g_bytes + 994, 18}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[83], {{g_bytes + 932, 19}}}, + {{&grpc_static_metadata_refcounts[84], {{g_bytes + 1012, 19}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[84], {{g_bytes + 951, 5}}}, + {{&grpc_static_metadata_refcounts[85], {{g_bytes + 1031, 5}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[85], {{g_bytes + 956, 7}}}, + {{&grpc_static_metadata_refcounts[86], {{g_bytes + 1036, 7}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[86], {{g_bytes + 963, 7}}}, + {{&grpc_static_metadata_refcounts[87], {{g_bytes + 1043, 7}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[87], {{g_bytes + 970, 11}}}, + {{&grpc_static_metadata_refcounts[88], {{g_bytes + 1050, 11}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[88], {{g_bytes + 981, 6}}}, + {{&grpc_static_metadata_refcounts[89], {{g_bytes + 1061, 6}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[89], {{g_bytes + 987, 10}}}, + {{&grpc_static_metadata_refcounts[90], {{g_bytes + 1067, 10}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[90], {{g_bytes + 997, 25}}}, + {{&grpc_static_metadata_refcounts[91], {{g_bytes + 1077, 25}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[91], {{g_bytes + 1022, 17}}}, + {{&grpc_static_metadata_refcounts[92], {{g_bytes + 1102, 17}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[19], {{g_bytes + 268, 10}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[92], {{g_bytes + 1039, 4}}}, + {{&grpc_static_metadata_refcounts[93], {{g_bytes + 1119, 4}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[93], {{g_bytes + 1043, 3}}}, + {{&grpc_static_metadata_refcounts[94], {{g_bytes + 1123, 3}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[94], {{g_bytes + 1046, 16}}}, + {{&grpc_static_metadata_refcounts[95], {{g_bytes + 1126, 16}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[7], {{g_bytes + 50, 11}}}, - {&grpc_static_metadata_refcounts[95], {{g_bytes + 1062, 1}}}}, + {&grpc_static_metadata_refcounts[96], {{g_bytes + 1142, 1}}}}, {{&grpc_static_metadata_refcounts[7], {{g_bytes + 50, 11}}}, {&grpc_static_metadata_refcounts[25], {{g_bytes + 350, 1}}}}, {{&grpc_static_metadata_refcounts[7], {{g_bytes + 50, 11}}}, {&grpc_static_metadata_refcounts[26], {{g_bytes + 351, 1}}}}, {{&grpc_static_metadata_refcounts[9], {{g_bytes + 77, 13}}}, - {&grpc_static_metadata_refcounts[96], {{g_bytes + 1063, 8}}}}, + {&grpc_static_metadata_refcounts[97], {{g_bytes + 1143, 8}}}}, {{&grpc_static_metadata_refcounts[9], {{g_bytes + 77, 13}}}, - {&grpc_static_metadata_refcounts[37], {{g_bytes + 517, 4}}}}, + {&grpc_static_metadata_refcounts[38], {{g_bytes + 597, 4}}}}, {{&grpc_static_metadata_refcounts[9], {{g_bytes + 77, 13}}}, - {&grpc_static_metadata_refcounts[36], {{g_bytes + 510, 7}}}}, + {&grpc_static_metadata_refcounts[37], {{g_bytes + 590, 7}}}}, {{&grpc_static_metadata_refcounts[5], {{g_bytes + 36, 2}}}, - {&grpc_static_metadata_refcounts[97], {{g_bytes + 1071, 8}}}}, + {&grpc_static_metadata_refcounts[98], {{g_bytes + 1151, 8}}}}, {{&grpc_static_metadata_refcounts[14], {{g_bytes + 158, 12}}}, - {&grpc_static_metadata_refcounts[98], {{g_bytes + 1079, 16}}}}, + {&grpc_static_metadata_refcounts[99], {{g_bytes + 1159, 16}}}}, {{&grpc_static_metadata_refcounts[4], {{g_bytes + 29, 7}}}, - {&grpc_static_metadata_refcounts[99], {{g_bytes + 1095, 4}}}}, + {&grpc_static_metadata_refcounts[100], {{g_bytes + 1175, 4}}}}, {{&grpc_static_metadata_refcounts[1], {{g_bytes + 5, 7}}}, - {&grpc_static_metadata_refcounts[100], {{g_bytes + 1099, 3}}}}, + {&grpc_static_metadata_refcounts[101], {{g_bytes + 1179, 3}}}}, {{&grpc_static_metadata_refcounts[16], {{g_bytes + 186, 15}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[15], {{g_bytes + 170, 16}}}, - {&grpc_static_metadata_refcounts[96], {{g_bytes + 1063, 8}}}}, + {&grpc_static_metadata_refcounts[97], {{g_bytes + 1143, 8}}}}, {{&grpc_static_metadata_refcounts[15], {{g_bytes + 170, 16}}}, - {&grpc_static_metadata_refcounts[37], {{g_bytes + 517, 4}}}}, + {&grpc_static_metadata_refcounts[38], {{g_bytes + 597, 4}}}}, {{&grpc_static_metadata_refcounts[21], {{g_bytes + 282, 8}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, - {{&grpc_static_metadata_refcounts[101], {{g_bytes + 1102, 11}}}, + {{&grpc_static_metadata_refcounts[102], {{g_bytes + 1182, 11}}}, {&grpc_static_metadata_refcounts[29], {{g_bytes + 354, 0}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[96], {{g_bytes + 1063, 8}}}}, + {&grpc_static_metadata_refcounts[97], {{g_bytes + 1143, 8}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[36], {{g_bytes + 510, 7}}}}, + {&grpc_static_metadata_refcounts[37], {{g_bytes + 590, 7}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[102], {{g_bytes + 1113, 16}}}}, + {&grpc_static_metadata_refcounts[103], {{g_bytes + 1193, 16}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[37], {{g_bytes + 517, 4}}}}, + {&grpc_static_metadata_refcounts[38], {{g_bytes + 597, 4}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[103], {{g_bytes + 1129, 13}}}}, + {&grpc_static_metadata_refcounts[104], {{g_bytes + 1209, 13}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[104], {{g_bytes + 1142, 12}}}}, + {&grpc_static_metadata_refcounts[105], {{g_bytes + 1222, 12}}}}, {{&grpc_static_metadata_refcounts[10], {{g_bytes + 90, 20}}}, - {&grpc_static_metadata_refcounts[105], {{g_bytes + 1154, 21}}}}, + {&grpc_static_metadata_refcounts[106], {{g_bytes + 1234, 21}}}}, {{&grpc_static_metadata_refcounts[16], {{g_bytes + 186, 15}}}, - {&grpc_static_metadata_refcounts[96], {{g_bytes + 1063, 8}}}}, + {&grpc_static_metadata_refcounts[97], {{g_bytes + 1143, 8}}}}, {{&grpc_static_metadata_refcounts[16], {{g_bytes + 186, 15}}}, - {&grpc_static_metadata_refcounts[37], {{g_bytes + 517, 4}}}}, + {&grpc_static_metadata_refcounts[38], {{g_bytes + 597, 4}}}}, {{&grpc_static_metadata_refcounts[16], {{g_bytes + 186, 15}}}, - {&grpc_static_metadata_refcounts[103], {{g_bytes + 1129, 13}}}}, + {&grpc_static_metadata_refcounts[104], {{g_bytes + 1209, 13}}}}, }; const uint8_t grpc_static_accept_encoding_metadata[8] = {0, 76, 77, 78, 79, 80, 81, 82}; diff --git a/src/core/lib/transport/static_metadata.h b/src/core/lib/transport/static_metadata.h index 2bb9f72838..4f9670232c 100644 --- a/src/core/lib/transport/static_metadata.h +++ b/src/core/lib/transport/static_metadata.h @@ -31,7 +31,7 @@ #include "src/core/lib/transport/metadata.h" -#define GRPC_STATIC_MDSTR_COUNT 106 +#define GRPC_STATIC_MDSTR_COUNT 107 extern const grpc_slice grpc_static_slice_table[GRPC_STATIC_MDSTR_COUNT]; /* ":path" */ #define GRPC_MDSTR_PATH (grpc_static_slice_table[0]) @@ -110,147 +110,151 @@ extern const grpc_slice grpc_static_slice_table[GRPC_STATIC_MDSTR_COUNT]; /* "/grpc.health.v1.Health/Watch" */ #define GRPC_MDSTR_SLASH_GRPC_DOT_HEALTH_DOT_V1_DOT_HEALTH_SLASH_WATCH \ (grpc_static_slice_table[35]) +/* "/envoy.service.discovery.v2.AggregatedDiscoveryService/StreamAggregatedResources" + */ +#define GRPC_MDSTR_SLASH_ENVOY_DOT_SERVICE_DOT_DISCOVERY_DOT_V2_DOT_AGGREGATEDDISCOVERYSERVICE_SLASH_STREAMAGGREGATEDRESOURCES \ + (grpc_static_slice_table[36]) /* "deflate" */ -#define GRPC_MDSTR_DEFLATE (grpc_static_slice_table[36]) +#define GRPC_MDSTR_DEFLATE (grpc_static_slice_table[37]) /* "gzip" */ -#define GRPC_MDSTR_GZIP (grpc_static_slice_table[37]) +#define GRPC_MDSTR_GZIP (grpc_static_slice_table[38]) /* "stream/gzip" */ -#define GRPC_MDSTR_STREAM_SLASH_GZIP (grpc_static_slice_table[38]) +#define GRPC_MDSTR_STREAM_SLASH_GZIP (grpc_static_slice_table[39]) /* "GET" */ -#define GRPC_MDSTR_GET (grpc_static_slice_table[39]) +#define GRPC_MDSTR_GET (grpc_static_slice_table[40]) /* "POST" */ -#define GRPC_MDSTR_POST (grpc_static_slice_table[40]) +#define GRPC_MDSTR_POST (grpc_static_slice_table[41]) /* "/" */ -#define GRPC_MDSTR_SLASH (grpc_static_slice_table[41]) +#define GRPC_MDSTR_SLASH (grpc_static_slice_table[42]) /* "/index.html" */ -#define GRPC_MDSTR_SLASH_INDEX_DOT_HTML (grpc_static_slice_table[42]) +#define GRPC_MDSTR_SLASH_INDEX_DOT_HTML (grpc_static_slice_table[43]) /* "http" */ -#define GRPC_MDSTR_HTTP (grpc_static_slice_table[43]) +#define GRPC_MDSTR_HTTP (grpc_static_slice_table[44]) /* "https" */ -#define GRPC_MDSTR_HTTPS (grpc_static_slice_table[44]) +#define GRPC_MDSTR_HTTPS (grpc_static_slice_table[45]) /* "200" */ -#define GRPC_MDSTR_200 (grpc_static_slice_table[45]) +#define GRPC_MDSTR_200 (grpc_static_slice_table[46]) /* "204" */ -#define GRPC_MDSTR_204 (grpc_static_slice_table[46]) +#define GRPC_MDSTR_204 (grpc_static_slice_table[47]) /* "206" */ -#define GRPC_MDSTR_206 (grpc_static_slice_table[47]) +#define GRPC_MDSTR_206 (grpc_static_slice_table[48]) /* "304" */ -#define GRPC_MDSTR_304 (grpc_static_slice_table[48]) +#define GRPC_MDSTR_304 (grpc_static_slice_table[49]) /* "400" */ -#define GRPC_MDSTR_400 (grpc_static_slice_table[49]) +#define GRPC_MDSTR_400 (grpc_static_slice_table[50]) /* "404" */ -#define GRPC_MDSTR_404 (grpc_static_slice_table[50]) +#define GRPC_MDSTR_404 (grpc_static_slice_table[51]) /* "500" */ -#define GRPC_MDSTR_500 (grpc_static_slice_table[51]) +#define GRPC_MDSTR_500 (grpc_static_slice_table[52]) /* "accept-charset" */ -#define GRPC_MDSTR_ACCEPT_CHARSET (grpc_static_slice_table[52]) +#define GRPC_MDSTR_ACCEPT_CHARSET (grpc_static_slice_table[53]) /* "gzip, deflate" */ -#define GRPC_MDSTR_GZIP_COMMA_DEFLATE (grpc_static_slice_table[53]) +#define GRPC_MDSTR_GZIP_COMMA_DEFLATE (grpc_static_slice_table[54]) /* "accept-language" */ -#define GRPC_MDSTR_ACCEPT_LANGUAGE (grpc_static_slice_table[54]) +#define GRPC_MDSTR_ACCEPT_LANGUAGE (grpc_static_slice_table[55]) /* "accept-ranges" */ -#define GRPC_MDSTR_ACCEPT_RANGES (grpc_static_slice_table[55]) +#define GRPC_MDSTR_ACCEPT_RANGES (grpc_static_slice_table[56]) /* "accept" */ -#define GRPC_MDSTR_ACCEPT (grpc_static_slice_table[56]) +#define GRPC_MDSTR_ACCEPT (grpc_static_slice_table[57]) /* "access-control-allow-origin" */ -#define GRPC_MDSTR_ACCESS_CONTROL_ALLOW_ORIGIN (grpc_static_slice_table[57]) +#define GRPC_MDSTR_ACCESS_CONTROL_ALLOW_ORIGIN (grpc_static_slice_table[58]) /* "age" */ -#define GRPC_MDSTR_AGE (grpc_static_slice_table[58]) +#define GRPC_MDSTR_AGE (grpc_static_slice_table[59]) /* "allow" */ -#define GRPC_MDSTR_ALLOW (grpc_static_slice_table[59]) +#define GRPC_MDSTR_ALLOW (grpc_static_slice_table[60]) /* "authorization" */ -#define GRPC_MDSTR_AUTHORIZATION (grpc_static_slice_table[60]) +#define GRPC_MDSTR_AUTHORIZATION (grpc_static_slice_table[61]) /* "cache-control" */ -#define GRPC_MDSTR_CACHE_CONTROL (grpc_static_slice_table[61]) +#define GRPC_MDSTR_CACHE_CONTROL (grpc_static_slice_table[62]) /* "content-disposition" */ -#define GRPC_MDSTR_CONTENT_DISPOSITION (grpc_static_slice_table[62]) +#define GRPC_MDSTR_CONTENT_DISPOSITION (grpc_static_slice_table[63]) /* "content-language" */ -#define GRPC_MDSTR_CONTENT_LANGUAGE (grpc_static_slice_table[63]) +#define GRPC_MDSTR_CONTENT_LANGUAGE (grpc_static_slice_table[64]) /* "content-length" */ -#define GRPC_MDSTR_CONTENT_LENGTH (grpc_static_slice_table[64]) +#define GRPC_MDSTR_CONTENT_LENGTH (grpc_static_slice_table[65]) /* "content-location" */ -#define GRPC_MDSTR_CONTENT_LOCATION (grpc_static_slice_table[65]) +#define GRPC_MDSTR_CONTENT_LOCATION (grpc_static_slice_table[66]) /* "content-range" */ -#define GRPC_MDSTR_CONTENT_RANGE (grpc_static_slice_table[66]) +#define GRPC_MDSTR_CONTENT_RANGE (grpc_static_slice_table[67]) /* "cookie" */ -#define GRPC_MDSTR_COOKIE (grpc_static_slice_table[67]) +#define GRPC_MDSTR_COOKIE (grpc_static_slice_table[68]) /* "date" */ -#define GRPC_MDSTR_DATE (grpc_static_slice_table[68]) +#define GRPC_MDSTR_DATE (grpc_static_slice_table[69]) /* "etag" */ -#define GRPC_MDSTR_ETAG (grpc_static_slice_table[69]) +#define GRPC_MDSTR_ETAG (grpc_static_slice_table[70]) /* "expect" */ -#define GRPC_MDSTR_EXPECT (grpc_static_slice_table[70]) +#define GRPC_MDSTR_EXPECT (grpc_static_slice_table[71]) /* "expires" */ -#define GRPC_MDSTR_EXPIRES (grpc_static_slice_table[71]) +#define GRPC_MDSTR_EXPIRES (grpc_static_slice_table[72]) /* "from" */ -#define GRPC_MDSTR_FROM (grpc_static_slice_table[72]) +#define GRPC_MDSTR_FROM (grpc_static_slice_table[73]) /* "if-match" */ -#define GRPC_MDSTR_IF_MATCH (grpc_static_slice_table[73]) +#define GRPC_MDSTR_IF_MATCH (grpc_static_slice_table[74]) /* "if-modified-since" */ -#define GRPC_MDSTR_IF_MODIFIED_SINCE (grpc_static_slice_table[74]) +#define GRPC_MDSTR_IF_MODIFIED_SINCE (grpc_static_slice_table[75]) /* "if-none-match" */ -#define GRPC_MDSTR_IF_NONE_MATCH (grpc_static_slice_table[75]) +#define GRPC_MDSTR_IF_NONE_MATCH (grpc_static_slice_table[76]) /* "if-range" */ -#define GRPC_MDSTR_IF_RANGE (grpc_static_slice_table[76]) +#define GRPC_MDSTR_IF_RANGE (grpc_static_slice_table[77]) /* "if-unmodified-since" */ -#define GRPC_MDSTR_IF_UNMODIFIED_SINCE (grpc_static_slice_table[77]) +#define GRPC_MDSTR_IF_UNMODIFIED_SINCE (grpc_static_slice_table[78]) /* "last-modified" */ -#define GRPC_MDSTR_LAST_MODIFIED (grpc_static_slice_table[78]) +#define GRPC_MDSTR_LAST_MODIFIED (grpc_static_slice_table[79]) /* "link" */ -#define GRPC_MDSTR_LINK (grpc_static_slice_table[79]) +#define GRPC_MDSTR_LINK (grpc_static_slice_table[80]) /* "location" */ -#define GRPC_MDSTR_LOCATION (grpc_static_slice_table[80]) +#define GRPC_MDSTR_LOCATION (grpc_static_slice_table[81]) /* "max-forwards" */ -#define GRPC_MDSTR_MAX_FORWARDS (grpc_static_slice_table[81]) +#define GRPC_MDSTR_MAX_FORWARDS (grpc_static_slice_table[82]) /* "proxy-authenticate" */ -#define GRPC_MDSTR_PROXY_AUTHENTICATE (grpc_static_slice_table[82]) +#define GRPC_MDSTR_PROXY_AUTHENTICATE (grpc_static_slice_table[83]) /* "proxy-authorization" */ -#define GRPC_MDSTR_PROXY_AUTHORIZATION (grpc_static_slice_table[83]) +#define GRPC_MDSTR_PROXY_AUTHORIZATION (grpc_static_slice_table[84]) /* "range" */ -#define GRPC_MDSTR_RANGE (grpc_static_slice_table[84]) +#define GRPC_MDSTR_RANGE (grpc_static_slice_table[85]) /* "referer" */ -#define GRPC_MDSTR_REFERER (grpc_static_slice_table[85]) +#define GRPC_MDSTR_REFERER (grpc_static_slice_table[86]) /* "refresh" */ -#define GRPC_MDSTR_REFRESH (grpc_static_slice_table[86]) +#define GRPC_MDSTR_REFRESH (grpc_static_slice_table[87]) /* "retry-after" */ -#define GRPC_MDSTR_RETRY_AFTER (grpc_static_slice_table[87]) +#define GRPC_MDSTR_RETRY_AFTER (grpc_static_slice_table[88]) /* "server" */ -#define GRPC_MDSTR_SERVER (grpc_static_slice_table[88]) +#define GRPC_MDSTR_SERVER (grpc_static_slice_table[89]) /* "set-cookie" */ -#define GRPC_MDSTR_SET_COOKIE (grpc_static_slice_table[89]) +#define GRPC_MDSTR_SET_COOKIE (grpc_static_slice_table[90]) /* "strict-transport-security" */ -#define GRPC_MDSTR_STRICT_TRANSPORT_SECURITY (grpc_static_slice_table[90]) +#define GRPC_MDSTR_STRICT_TRANSPORT_SECURITY (grpc_static_slice_table[91]) /* "transfer-encoding" */ -#define GRPC_MDSTR_TRANSFER_ENCODING (grpc_static_slice_table[91]) +#define GRPC_MDSTR_TRANSFER_ENCODING (grpc_static_slice_table[92]) /* "vary" */ -#define GRPC_MDSTR_VARY (grpc_static_slice_table[92]) +#define GRPC_MDSTR_VARY (grpc_static_slice_table[93]) /* "via" */ -#define GRPC_MDSTR_VIA (grpc_static_slice_table[93]) +#define GRPC_MDSTR_VIA (grpc_static_slice_table[94]) /* "www-authenticate" */ -#define GRPC_MDSTR_WWW_AUTHENTICATE (grpc_static_slice_table[94]) +#define GRPC_MDSTR_WWW_AUTHENTICATE (grpc_static_slice_table[95]) /* "0" */ -#define GRPC_MDSTR_0 (grpc_static_slice_table[95]) +#define GRPC_MDSTR_0 (grpc_static_slice_table[96]) /* "identity" */ -#define GRPC_MDSTR_IDENTITY (grpc_static_slice_table[96]) +#define GRPC_MDSTR_IDENTITY (grpc_static_slice_table[97]) /* "trailers" */ -#define GRPC_MDSTR_TRAILERS (grpc_static_slice_table[97]) +#define GRPC_MDSTR_TRAILERS (grpc_static_slice_table[98]) /* "application/grpc" */ -#define GRPC_MDSTR_APPLICATION_SLASH_GRPC (grpc_static_slice_table[98]) +#define GRPC_MDSTR_APPLICATION_SLASH_GRPC (grpc_static_slice_table[99]) /* "grpc" */ -#define GRPC_MDSTR_GRPC (grpc_static_slice_table[99]) +#define GRPC_MDSTR_GRPC (grpc_static_slice_table[100]) /* "PUT" */ -#define GRPC_MDSTR_PUT (grpc_static_slice_table[100]) +#define GRPC_MDSTR_PUT (grpc_static_slice_table[101]) /* "lb-cost-bin" */ -#define GRPC_MDSTR_LB_COST_BIN (grpc_static_slice_table[101]) +#define GRPC_MDSTR_LB_COST_BIN (grpc_static_slice_table[102]) /* "identity,deflate" */ -#define GRPC_MDSTR_IDENTITY_COMMA_DEFLATE (grpc_static_slice_table[102]) +#define GRPC_MDSTR_IDENTITY_COMMA_DEFLATE (grpc_static_slice_table[103]) /* "identity,gzip" */ -#define GRPC_MDSTR_IDENTITY_COMMA_GZIP (grpc_static_slice_table[103]) +#define GRPC_MDSTR_IDENTITY_COMMA_GZIP (grpc_static_slice_table[104]) /* "deflate,gzip" */ -#define GRPC_MDSTR_DEFLATE_COMMA_GZIP (grpc_static_slice_table[104]) +#define GRPC_MDSTR_DEFLATE_COMMA_GZIP (grpc_static_slice_table[105]) /* "identity,deflate,gzip" */ #define GRPC_MDSTR_IDENTITY_COMMA_DEFLATE_COMMA_GZIP \ - (grpc_static_slice_table[105]) + (grpc_static_slice_table[106]) extern const grpc_slice_refcount_vtable grpc_static_metadata_vtable; extern grpc_slice_refcount diff --git a/src/core/lib/transport/transport.cc b/src/core/lib/transport/transport.cc index cbdb77c844..b32f9c6ec1 100644 --- a/src/core/lib/transport/transport.cc +++ b/src/core/lib/transport/transport.cc @@ -27,6 +27,7 @@ #include <grpc/support/log.h> #include <grpc/support/sync.h> +#include "src/core/lib/gpr/alloc.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/iomgr/executor.h" #include "src/core/lib/slice/slice_internal.h" @@ -149,7 +150,7 @@ void grpc_transport_move_stats(grpc_transport_stream_stats* from, } size_t grpc_transport_stream_size(grpc_transport* transport) { - return transport->vtable->sizeof_stream; + return GPR_ROUND_UP_TO_ALIGNMENT_SIZE(transport->vtable->sizeof_stream); } void grpc_transport_destroy(grpc_transport* transport) { diff --git a/src/core/lib/transport/transport.h b/src/core/lib/transport/transport.h index edfa7030d1..5ce568834e 100644 --- a/src/core/lib/transport/transport.h +++ b/src/core/lib/transport/transport.h @@ -129,7 +129,8 @@ struct grpc_transport_stream_op_batch { recv_initial_metadata(false), recv_message(false), recv_trailing_metadata(false), - cancel_stream(false) {} + cancel_stream(false), + is_traced(false) {} /** Should be scheduled when all of the non-recv operations in the batch are complete. @@ -167,6 +168,9 @@ struct grpc_transport_stream_op_batch { /** Cancel this stream with the provided error */ bool cancel_stream : 1; + /** Is this stream traced */ + bool is_traced : 1; + /*************************************************************************** * remaining fields are initialized and used at the discretion of the * current handler of the op */ diff --git a/src/core/plugin_registry/grpc_cronet_plugin_registry.cc b/src/core/plugin_registry/grpc_cronet_plugin_registry.cc index c0c17b0a4b..92085d31d7 100644 --- a/src/core/plugin_registry/grpc_cronet_plugin_registry.cc +++ b/src/core/plugin_registry/grpc_cronet_plugin_registry.cc @@ -28,8 +28,6 @@ void grpc_deadline_filter_init(void); void grpc_deadline_filter_shutdown(void); void grpc_client_channel_init(void); void grpc_client_channel_shutdown(void); -void grpc_tsi_alts_init(void); -void grpc_tsi_alts_shutdown(void); void grpc_register_built_in_plugins(void) { grpc_register_plugin(grpc_http_filters_init, @@ -40,6 +38,4 @@ void grpc_register_built_in_plugins(void) { grpc_deadline_filter_shutdown); grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); - grpc_register_plugin(grpc_tsi_alts_init, - grpc_tsi_alts_shutdown); } diff --git a/src/core/plugin_registry/grpc_plugin_registry.cc b/src/core/plugin_registry/grpc_plugin_registry.cc index 94c2493d5e..cde40ef65c 100644 --- a/src/core/plugin_registry/grpc_plugin_registry.cc +++ b/src/core/plugin_registry/grpc_plugin_registry.cc @@ -28,8 +28,6 @@ void grpc_deadline_filter_init(void); void grpc_deadline_filter_shutdown(void); void grpc_client_channel_init(void); void grpc_client_channel_shutdown(void); -void grpc_tsi_alts_init(void); -void grpc_tsi_alts_shutdown(void); void grpc_inproc_plugin_init(void); void grpc_inproc_plugin_shutdown(void); void grpc_resolver_fake_init(void); @@ -66,8 +64,6 @@ void grpc_register_built_in_plugins(void) { grpc_deadline_filter_shutdown); grpc_register_plugin(grpc_client_channel_init, grpc_client_channel_shutdown); - grpc_register_plugin(grpc_tsi_alts_init, - grpc_tsi_alts_shutdown); grpc_register_plugin(grpc_inproc_plugin_init, grpc_inproc_plugin_shutdown); grpc_register_plugin(grpc_resolver_fake_init, diff --git a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc index 941ca13114..43d0979f4b 100644 --- a/src/core/tsi/alts/handshaker/alts_handshaker_client.cc +++ b/src/core/tsi/alts/handshaker/alts_handshaker_client.cc @@ -116,12 +116,13 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c, "cb is nullptr in alts_tsi_handshaker_handle_response()"); return; } - if (handshaker == nullptr || recv_buffer == nullptr) { + if (handshaker == nullptr) { gpr_log(GPR_ERROR, - "Invalid arguments to alts_tsi_handshaker_handle_response()"); + "handshaker is nullptr in alts_tsi_handshaker_handle_response()"); cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); return; } + /* TSI handshake has been shutdown. */ if (alts_tsi_handshaker_has_shutdown(handshaker)) { gpr_log(GPR_ERROR, "TSI handshake shutdown"); cb(TSI_HANDSHAKE_SHUTDOWN, user_data, nullptr, 0, nullptr); @@ -133,6 +134,12 @@ void alts_handshaker_client_handle_response(alts_handshaker_client* c, cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); return; } + if (recv_buffer == nullptr) { + gpr_log(GPR_ERROR, + "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()"); + cb(TSI_INTERNAL_ERROR, user_data, nullptr, 0, nullptr); + return; + } grpc_gcp_handshaker_resp* resp = alts_tsi_utils_deserialize_response(recv_buffer); grpc_byte_buffer_destroy(client->recv_buffer); @@ -376,7 +383,7 @@ static void handshaker_client_shutdown(alts_handshaker_client* c) { alts_grpc_handshaker_client* client = reinterpret_cast<alts_grpc_handshaker_client*>(c); if (client->call != nullptr) { - GPR_ASSERT(grpc_call_cancel(client->call, nullptr) == GRPC_CALL_OK); + grpc_call_cancel_internal(client->call); } } diff --git a/src/core/tsi/alts/handshaker/alts_shared_resource.cc b/src/core/tsi/alts/handshaker/alts_shared_resource.cc index ffb5e1c655..3501257f05 100644 --- a/src/core/tsi/alts/handshaker/alts_shared_resource.cc +++ b/src/core/tsi/alts/handshaker/alts_shared_resource.cc @@ -25,7 +25,6 @@ #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" static alts_shared_resource_dedicated g_alts_resource_dedicated; -static alts_shared_resource* g_shared_resources = alts_get_shared_resource(); alts_shared_resource_dedicated* grpc_alts_get_shared_resource_dedicated(void) { return &g_alts_resource_dedicated; @@ -49,16 +48,25 @@ static void thread_worker(void* arg) { void grpc_alts_shared_resource_dedicated_init() { g_alts_resource_dedicated.cq = nullptr; + gpr_mu_init(&g_alts_resource_dedicated.mu); } -void grpc_alts_shared_resource_dedicated_start() { - g_alts_resource_dedicated.cq = grpc_completion_queue_create_for_next(nullptr); - g_alts_resource_dedicated.thread = - grpc_core::Thread("alts_tsi_handshaker", &thread_worker, nullptr); - g_alts_resource_dedicated.interested_parties = grpc_pollset_set_create(); - grpc_pollset_set_add_pollset(g_alts_resource_dedicated.interested_parties, - grpc_cq_pollset(g_alts_resource_dedicated.cq)); - g_alts_resource_dedicated.thread.Start(); +void grpc_alts_shared_resource_dedicated_start( + const char* handshaker_service_url) { + gpr_mu_lock(&g_alts_resource_dedicated.mu); + if (g_alts_resource_dedicated.cq == nullptr) { + g_alts_resource_dedicated.channel = + grpc_insecure_channel_create(handshaker_service_url, nullptr, nullptr); + g_alts_resource_dedicated.cq = + grpc_completion_queue_create_for_next(nullptr); + g_alts_resource_dedicated.thread = + grpc_core::Thread("alts_tsi_handshaker", &thread_worker, nullptr); + g_alts_resource_dedicated.interested_parties = grpc_pollset_set_create(); + grpc_pollset_set_add_pollset(g_alts_resource_dedicated.interested_parties, + grpc_cq_pollset(g_alts_resource_dedicated.cq)); + g_alts_resource_dedicated.thread.Start(); + } + gpr_mu_unlock(&g_alts_resource_dedicated.mu); } void grpc_alts_shared_resource_dedicated_shutdown() { @@ -69,5 +77,7 @@ void grpc_alts_shared_resource_dedicated_shutdown() { g_alts_resource_dedicated.thread.Join(); grpc_pollset_set_destroy(g_alts_resource_dedicated.interested_parties); grpc_completion_queue_destroy(g_alts_resource_dedicated.cq); + grpc_channel_destroy(g_alts_resource_dedicated.channel); } + gpr_mu_destroy(&g_alts_resource_dedicated.mu); } diff --git a/src/core/tsi/alts/handshaker/alts_shared_resource.h b/src/core/tsi/alts/handshaker/alts_shared_resource.h index a07da305a8..8ae0089a11 100644 --- a/src/core/tsi/alts/handshaker/alts_shared_resource.h +++ b/src/core/tsi/alts/handshaker/alts_shared_resource.h @@ -37,6 +37,8 @@ typedef struct alts_shared_resource_dedicated { grpc_completion_queue* cq; grpc_pollset_set* interested_parties; grpc_cq_completion storage; + gpr_mu mu; + grpc_channel* channel; } alts_shared_resource_dedicated; /* This method returns the address of alts_shared_resource_dedicated @@ -64,7 +66,8 @@ void grpc_alts_shared_resource_dedicated_init(); * The API will be invoked by the caller in a lazy manner. That is, * it will get invoked when ALTS TSI handshake occurs for the first time. */ -void grpc_alts_shared_resource_dedicated_start(); +void grpc_alts_shared_resource_dedicated_start( + const char* handshaker_service_url); #endif /* GRPC_CORE_TSI_ALTS_HANDSHAKER_ALTS_SHARED_RESOURCE_H \ */ diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc index 1639252883..1b7e58d3ce 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.cc @@ -39,9 +39,6 @@ #include "src/core/tsi/alts/handshaker/alts_shared_resource.h" #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" -#include "src/core/tsi/alts_transport_security.h" - -static alts_shared_resource* g_shared_resources = alts_get_shared_resource(); /* Main struct for ALTS TSI handshaker. */ struct alts_tsi_handshaker { @@ -51,11 +48,11 @@ struct alts_tsi_handshaker { bool is_client; bool has_sent_start_message; bool has_created_handshaker_client; - bool use_dedicated_cq; char* handshaker_service_url; grpc_pollset_set* interested_parties; grpc_alts_credentials_options* options; alts_handshaker_client_vtable* client_vtable_for_testing; + grpc_channel* channel; }; /* Main struct for ALTS TSI handshaker result. */ @@ -237,20 +234,6 @@ tsi_result alts_tsi_handshaker_result_create(grpc_gcp_handshaker_resp* resp, return TSI_OK; } -static void init_shared_resources(const char* handshaker_service_url, - bool use_dedicated_cq) { - GPR_ASSERT(handshaker_service_url != nullptr); - gpr_mu_lock(&g_shared_resources->mu); - if (g_shared_resources->channel == nullptr) { - g_shared_resources->channel = - grpc_insecure_channel_create(handshaker_service_url, nullptr, nullptr); - if (use_dedicated_cq) { - grpc_alts_shared_resource_dedicated_start(); - } - } - gpr_mu_unlock(&g_shared_resources->mu); -} - /* gRPC provided callback used when gRPC thread model is applied. */ static void on_handshaker_service_resp_recv(void* arg, grpc_error* error) { alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg); @@ -289,20 +272,24 @@ static tsi_result handshaker_next( reinterpret_cast<alts_tsi_handshaker*>(self); tsi_result ok = TSI_OK; if (!handshaker->has_created_handshaker_client) { - init_shared_resources(handshaker->handshaker_service_url, - handshaker->use_dedicated_cq); - if (handshaker->use_dedicated_cq) { + if (handshaker->channel == nullptr) { + grpc_alts_shared_resource_dedicated_start( + handshaker->handshaker_service_url); handshaker->interested_parties = grpc_alts_get_shared_resource_dedicated()->interested_parties; GPR_ASSERT(handshaker->interested_parties != nullptr); } - grpc_iomgr_cb_func grpc_cb = handshaker->use_dedicated_cq + grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr ? on_handshaker_service_resp_recv_dedicated : on_handshaker_service_resp_recv; + grpc_channel* channel = + handshaker->channel == nullptr + ? grpc_alts_get_shared_resource_dedicated()->channel + : handshaker->channel; handshaker->client = alts_grpc_handshaker_client_create( - handshaker, g_shared_resources->channel, - handshaker->handshaker_service_url, handshaker->interested_parties, - handshaker->options, handshaker->target_name, grpc_cb, cb, user_data, + handshaker, channel, handshaker->handshaker_service_url, + handshaker->interested_parties, handshaker->options, + handshaker->target_name, grpc_cb, cb, user_data, handshaker->client_vtable_for_testing, handshaker->is_client); if (handshaker->client == nullptr) { gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client"); @@ -310,7 +297,7 @@ static tsi_result handshaker_next( } handshaker->has_created_handshaker_client = true; } - if (handshaker->use_dedicated_cq && + if (handshaker->channel == nullptr && handshaker->client_vtable_for_testing == nullptr) { GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq, handshaker->client)); @@ -371,6 +358,9 @@ static void handshaker_destroy(tsi_handshaker* self) { alts_handshaker_client_destroy(handshaker->client); grpc_slice_unref_internal(handshaker->target_name); grpc_alts_credentials_options_destroy(handshaker->options); + if (handshaker->channel != nullptr) { + grpc_channel_destroy(handshaker->channel); + } gpr_free(handshaker->handshaker_service_url); gpr_free(handshaker); } @@ -407,7 +397,7 @@ tsi_result alts_tsi_handshaker_create( } alts_tsi_handshaker* handshaker = static_cast<alts_tsi_handshaker*>(gpr_zalloc(sizeof(*handshaker))); - handshaker->use_dedicated_cq = interested_parties == nullptr; + bool use_dedicated_cq = interested_parties == nullptr; handshaker->client = nullptr; handshaker->is_client = is_client; handshaker->has_sent_start_message = false; @@ -418,9 +408,13 @@ tsi_result alts_tsi_handshaker_create( handshaker->has_created_handshaker_client = false; handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url); handshaker->options = grpc_alts_credentials_options_copy(options); - handshaker->base.vtable = handshaker->use_dedicated_cq - ? &handshaker_vtable_dedicated - : &handshaker_vtable; + handshaker->base.vtable = + use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable; + handshaker->channel = + use_dedicated_cq + ? nullptr + : grpc_insecure_channel_create(handshaker->handshaker_service_url, + nullptr, nullptr); *self = &handshaker->base; return TSI_OK; } diff --git a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h index c961eae498..32f94bc9d3 100644 --- a/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h +++ b/src/core/tsi/alts/handshaker/alts_tsi_handshaker.h @@ -27,7 +27,6 @@ #include "src/core/lib/security/credentials/alts/grpc_alts_credentials_options.h" #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" #include "src/core/tsi/alts/handshaker/alts_handshaker_service_api_util.h" -#include "src/core/tsi/alts_transport_security.h" #include "src/core/tsi/transport_security.h" #include "src/core/tsi/transport_security_interface.h" diff --git a/src/core/tsi/alts_transport_security.cc b/src/core/tsi/alts_transport_security.cc deleted file mode 100644 index 5a1494ae5c..0000000000 --- a/src/core/tsi/alts_transport_security.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* - * - * Copyright 2017 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#include <grpc/support/port_platform.h> - -#include <string.h> - -#include "src/core/tsi/alts_transport_security.h" - -static alts_shared_resource g_alts_resource; - -alts_shared_resource* alts_get_shared_resource(void) { - return &g_alts_resource; -} - -void grpc_tsi_alts_init() { - g_alts_resource.channel = nullptr; - gpr_mu_init(&g_alts_resource.mu); -} - -void grpc_tsi_alts_shutdown() { - if (g_alts_resource.channel != nullptr) { - grpc_channel_destroy(g_alts_resource.channel); - } - gpr_mu_destroy(&g_alts_resource.mu); -} diff --git a/src/core/tsi/alts_transport_security.h b/src/core/tsi/alts_transport_security.h deleted file mode 100644 index f4319d10d2..0000000000 --- a/src/core/tsi/alts_transport_security.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * - * Copyright 2017 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#ifndef GRPC_CORE_TSI_ALTS_TRANSPORT_SECURITY_H -#define GRPC_CORE_TSI_ALTS_TRANSPORT_SECURITY_H - -#include <grpc/support/port_platform.h> - -#include <grpc/grpc.h> -#include <grpc/support/sync.h> - -#include "src/core/lib/gprpp/thd.h" - -typedef struct alts_shared_resource { - grpc_channel* channel; - gpr_mu mu; -} alts_shared_resource; - -/* This method returns the address of alts_shared_resource object shared by all - * TSI handshakes. */ -alts_shared_resource* alts_get_shared_resource(void); - -#endif /* GRPC_CORE_TSI_ALTS_TRANSPORT_SECURITY_H */ diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc index d6a72ada0d..fb6ea19210 100644 --- a/src/core/tsi/ssl_transport_security.cc +++ b/src/core/tsi/ssl_transport_security.cc @@ -156,9 +156,13 @@ static unsigned long openssl_thread_id_cb(void) { #endif static void init_openssl(void) { +#if OPENSSL_API_COMPAT >= 0x10100000L + OPENSSL_init_ssl(0, NULL); +#else SSL_library_init(); SSL_load_error_strings(); OpenSSL_add_all_algorithms(); +#endif #if OPENSSL_VERSION_NUMBER < 0x10100000 if (!CRYPTO_get_locking_callback()) { int num_locks = CRYPTO_num_locks(); @@ -1649,7 +1653,11 @@ tsi_result tsi_create_ssl_client_handshaker_factory_with_options( return TSI_INVALID_ARGUMENT; } +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + ssl_context = SSL_CTX_new(TLS_method()); +#else ssl_context = SSL_CTX_new(TLSv1_2_method()); +#endif if (ssl_context == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); return TSI_INVALID_ARGUMENT; @@ -1806,7 +1814,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options( for (i = 0; i < options->num_key_cert_pairs; i++) { do { +#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L + impl->ssl_contexts[i] = SSL_CTX_new(TLS_method()); +#else impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method()); +#endif if (impl->ssl_contexts[i] == nullptr) { gpr_log(GPR_ERROR, "Could not create ssl context."); result = TSI_OUT_OF_RESOURCES; @@ -1850,31 +1862,30 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options( break; } SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names); - switch (options->client_certificate_request) { - case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); - break; - case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, - NullVerifyCallback); - break; - case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: - SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); - break; - case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: - SSL_CTX_set_verify( - impl->ssl_contexts[i], - SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, - NullVerifyCallback); - break; - case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: - SSL_CTX_set_verify( - impl->ssl_contexts[i], - SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - break; - } - /* TODO(jboeuf): Add revocation verification. */ } + switch (options->client_certificate_request) { + case TSI_DONT_REQUEST_CLIENT_CERTIFICATE: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, + NullVerifyCallback); + break; + case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + NullVerifyCallback); + break; + case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY: + SSL_CTX_set_verify(impl->ssl_contexts[i], + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + break; + } + /* TODO(jboeuf): Add revocation verification. */ result = extract_x509_subject_names_from_pem_cert( options->pem_key_cert_pairs[i].cert_chain, diff --git a/src/core/tsi/transport_security.cc b/src/core/tsi/transport_security.cc index ca861b52de..078a917bba 100644 --- a/src/core/tsi/transport_security.cc +++ b/src/core/tsi/transport_security.cc @@ -213,10 +213,10 @@ tsi_result tsi_handshaker_next( void tsi_handshaker_shutdown(tsi_handshaker* self) { if (self == nullptr || self->vtable == nullptr) return; - self->handshake_shutdown = true; if (self->vtable->shutdown != nullptr) { self->vtable->shutdown(self); } + self->handshake_shutdown = true; } void tsi_handshaker_destroy(tsi_handshaker* self) { diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index 8e1cea0269..a31d0b30b1 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -54,13 +54,11 @@ namespace grpc { static internal::GrpcLibraryInitializer g_gli_initializer; Channel::Channel( const grpc::string& host, grpc_channel* channel, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) : host_(host), c_channel_(channel) { - if (interceptor_creators != nullptr) { - interceptor_creators_ = std::move(*interceptor_creators); - } + interceptor_creators_ = std::move(interceptor_creators); g_gli_initializer.summon(); } @@ -151,8 +149,9 @@ internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method, // ClientRpcInfo should be set before call because set_call also checks // whether the call has been cancelled, and if the call was cancelled, we // should notify the interceptors too/ - auto* info = context->set_client_rpc_info( - method.name(), this, interceptor_creators_, interceptor_pos); + auto* info = + context->set_client_rpc_info(method.name(), method.method_type(), this, + interceptor_creators_, interceptor_pos); context->set_call(c_call, shared_from_this()); return internal::Call(c_call, this, cq, info); diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index 50da75f09c..c9ea3e5f83 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -41,9 +41,10 @@ class DefaultGlobalClientCallbacks final }; static internal::GrpcLibraryInitializer g_gli_initializer; -static DefaultGlobalClientCallbacks g_default_client_callbacks; +static DefaultGlobalClientCallbacks* g_default_client_callbacks = + new DefaultGlobalClientCallbacks(); static ClientContext::GlobalCallbacks* g_client_callbacks = - &g_default_client_callbacks; + g_default_client_callbacks; ClientContext::ClientContext() : initial_metadata_received_(false), @@ -139,9 +140,9 @@ grpc::string ClientContext::peer() const { } void ClientContext::SetGlobalCallbacks(GlobalCallbacks* client_callbacks) { - GPR_ASSERT(g_client_callbacks == &g_default_client_callbacks); + GPR_ASSERT(g_client_callbacks == g_default_client_callbacks); GPR_ASSERT(client_callbacks != nullptr); - GPR_ASSERT(client_callbacks != &g_default_client_callbacks); + GPR_ASSERT(client_callbacks != g_default_client_callbacks); g_client_callbacks = client_callbacks; } diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc index efdff6c265..457daa674c 100644 --- a/src/cpp/client/create_channel.cc +++ b/src/cpp/client/create_channel.cc @@ -39,13 +39,14 @@ std::shared_ptr<Channel> CreateCustomChannel( const std::shared_ptr<ChannelCredentials>& creds, const ChannelArguments& args) { GrpcLibraryCodegen init_lib; // We need to call init in case of a bad creds. - return creds - ? creds->CreateChannel(target, args) - : CreateChannelInternal("", - grpc_lame_client_channel_create( - nullptr, GRPC_STATUS_INVALID_ARGUMENT, - "Invalid credentials."), - nullptr); + return creds ? creds->CreateChannel(target, args) + : CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>()); } namespace experimental { @@ -64,17 +65,18 @@ std::shared_ptr<Channel> CreateCustomChannelWithInterceptors( const grpc::string& target, const std::shared_ptr<ChannelCredentials>& creds, const ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) { - return creds - ? creds->CreateChannelWithInterceptors( - target, args, std::move(interceptor_creators)) - : CreateChannelInternal("", - grpc_lame_client_channel_create( - nullptr, GRPC_STATUS_INVALID_ARGUMENT, - "Invalid credentials."), - nullptr); + return creds ? creds->CreateChannelWithInterceptors( + target, args, std::move(interceptor_creators)) + : CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>()); } } // namespace experimental diff --git a/src/cpp/client/create_channel_internal.cc b/src/cpp/client/create_channel_internal.cc index 313d682aae..a0efb97f7e 100644 --- a/src/cpp/client/create_channel_internal.cc +++ b/src/cpp/client/create_channel_internal.cc @@ -26,8 +26,8 @@ namespace grpc { std::shared_ptr<Channel> CreateChannelInternal( const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) { return std::shared_ptr<Channel>( new Channel(host, c_channel, std::move(interceptor_creators))); diff --git a/src/cpp/client/create_channel_internal.h b/src/cpp/client/create_channel_internal.h index 512fc22866..a90c92c518 100644 --- a/src/cpp/client/create_channel_internal.h +++ b/src/cpp/client/create_channel_internal.h @@ -31,8 +31,8 @@ class Channel; std::shared_ptr<Channel> CreateChannelInternal( const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators); } // namespace grpc diff --git a/src/cpp/client/create_channel_posix.cc b/src/cpp/client/create_channel_posix.cc index 8d775e7a87..3affc1ef39 100644 --- a/src/cpp/client/create_channel_posix.cc +++ b/src/cpp/client/create_channel_posix.cc @@ -34,7 +34,8 @@ std::shared_ptr<Channel> CreateInsecureChannelFromFd(const grpc::string& target, init_lib.init(); return CreateChannelInternal( "", grpc_insecure_channel_create_from_fd(target.c_str(), fd, nullptr), - nullptr); + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); } std::shared_ptr<Channel> CreateCustomInsecureChannelFromFd( @@ -46,15 +47,16 @@ std::shared_ptr<Channel> CreateCustomInsecureChannelFromFd( return CreateChannelInternal( "", grpc_insecure_channel_create_from_fd(target.c_str(), fd, &channel_args), - nullptr); + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); } namespace experimental { std::shared_ptr<Channel> CreateCustomInsecureChannelWithInterceptorsFromFd( const grpc::string& target, int fd, const ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) { internal::GrpcLibrary init_lib; init_lib.init(); diff --git a/src/cpp/client/cronet_credentials.cc b/src/cpp/client/cronet_credentials.cc index 09a76b428c..b2801764f2 100644 --- a/src/cpp/client/cronet_credentials.cc +++ b/src/cpp/client/cronet_credentials.cc @@ -31,7 +31,10 @@ class CronetChannelCredentialsImpl final : public ChannelCredentials { std::shared_ptr<grpc::Channel> CreateChannel( const string& target, const grpc::ChannelArguments& args) override { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>()); } SecureChannelCredentials* AsSecureCredentials() override { return nullptr; } @@ -39,8 +42,8 @@ class CronetChannelCredentialsImpl final : public ChannelCredentials { private: std::shared_ptr<grpc::Channel> CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) override { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); diff --git a/src/cpp/client/generic_stub.cc b/src/cpp/client/generic_stub.cc index 87902b26f0..f61c1b5317 100644 --- a/src/cpp/client/generic_stub.cc +++ b/src/cpp/client/generic_stub.cc @@ -72,4 +72,13 @@ void GenericStub::experimental_type::UnaryCall( context, request, response, std::move(on_completion)); } +void GenericStub::experimental_type::PrepareBidiStreamingCall( + ClientContext* context, const grpc::string& method, + experimental::ClientBidiReactor<ByteBuffer, ByteBuffer>* reactor) { + internal::ClientCallbackReaderWriterFactory<ByteBuffer, ByteBuffer>::Create( + stub_->channel_.get(), + internal::RpcMethod(method.c_str(), internal::RpcMethod::BIDI_STREAMING), + context, reactor); +} + } // namespace grpc diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc index b816e0c59a..241ce91803 100644 --- a/src/cpp/client/insecure_credentials.cc +++ b/src/cpp/client/insecure_credentials.cc @@ -32,13 +32,16 @@ class InsecureChannelCredentialsImpl final : public ChannelCredentials { public: std::shared_ptr<grpc::Channel> CreateChannel( const string& target, const grpc::ChannelArguments& args) override { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector<std::unique_ptr< + experimental::ClientInterceptorFactoryInterface>>()); } std::shared_ptr<grpc::Channel> CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) override { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index 7faaa20e78..4d0ed355ab 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -36,14 +36,17 @@ SecureChannelCredentials::SecureChannelCredentials( std::shared_ptr<grpc::Channel> SecureChannelCredentials::CreateChannel( const string& target, const grpc::ChannelArguments& args) { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); } std::shared_ptr<grpc::Channel> SecureChannelCredentials::CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); @@ -258,10 +261,10 @@ void MetadataCredentialsPluginWrapper::InvokePlugin( grpc_status_code* status_code, const char** error_details) { std::multimap<grpc::string, grpc::string> metadata; - // const_cast is safe since the SecureAuthContext does not take owndership and - // the object is passed as a const ref to plugin_->GetMetadata. + // const_cast is safe since the SecureAuthContext only inc/dec the refcount + // and the object is passed as a const ref to plugin_->GetMetadata. SecureAuthContext cpp_channel_auth_context( - const_cast<grpc_auth_context*>(context.channel_auth_context), false); + const_cast<grpc_auth_context*>(context.channel_auth_context)); Status status = plugin_->GetMetadata(context.service_url, context.method_name, cpp_channel_auth_context, &metadata); diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index bfb6e17ee9..4918bd5a4d 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -24,6 +24,7 @@ #include <grpcpp/security/credentials.h> #include <grpcpp/support/config.h> +#include "src/core/lib/security/credentials/credentials.h" #include "src/cpp/server/thread_pool_interface.h" namespace grpc { @@ -31,7 +32,9 @@ namespace grpc { class SecureChannelCredentials final : public ChannelCredentials { public: explicit SecureChannelCredentials(grpc_channel_credentials* c_creds); - ~SecureChannelCredentials() { grpc_channel_credentials_release(c_creds_); } + ~SecureChannelCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_channel_credentials* GetRawCreds() { return c_creds_; } std::shared_ptr<grpc::Channel> CreateChannel( @@ -42,8 +45,8 @@ class SecureChannelCredentials final : public ChannelCredentials { private: std::shared_ptr<grpc::Channel> CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) override; grpc_channel_credentials* const c_creds_; }; @@ -51,7 +54,9 @@ class SecureChannelCredentials final : public ChannelCredentials { class SecureCallCredentials final : public CallCredentials { public: explicit SecureCallCredentials(grpc_call_credentials* c_creds); - ~SecureCallCredentials() { grpc_call_credentials_release(c_creds_); } + ~SecureCallCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_call_credentials* GetRawCreds() { return c_creds_; } bool ApplyToCall(grpc_call* call) override; diff --git a/src/cpp/common/alarm.cc b/src/cpp/common/alarm.cc index 5819a4210b..148f0b9bc9 100644 --- a/src/cpp/common/alarm.cc +++ b/src/cpp/common/alarm.cc @@ -31,10 +31,10 @@ #include <grpc/support/log.h> #include "src/core/lib/debug/trace.h" -namespace grpc { +namespace grpc_impl { namespace internal { -class AlarmImpl : public CompletionQueueTag { +class AlarmImpl : public ::grpc::internal::CompletionQueueTag { public: AlarmImpl() : cq_(nullptr), tag_(nullptr) { gpr_ref_init(&refs_, 1); @@ -51,7 +51,7 @@ class AlarmImpl : public CompletionQueueTag { Unref(); return true; } - void Set(CompletionQueue* cq, gpr_timespec deadline, void* tag) { + void Set(::grpc::CompletionQueue* cq, gpr_timespec deadline, void* tag) { grpc_core::ExecCtx exec_ctx; GRPC_CQ_INTERNAL_REF(cq->cq(), "alarm"); cq_ = cq->cq(); @@ -114,13 +114,14 @@ class AlarmImpl : public CompletionQueueTag { }; } // namespace internal -static internal::GrpcLibraryInitializer g_gli_initializer; +static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer; Alarm::Alarm() : alarm_(new internal::AlarmImpl()) { g_gli_initializer.summon(); } -void Alarm::SetInternal(CompletionQueue* cq, gpr_timespec deadline, void* tag) { +void Alarm::SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline, + void* tag) { // Note that we know that alarm_ is actually an internal::AlarmImpl // but we declared it as the base pointer to avoid a forward declaration // or exposing core data structures in the C++ public headers. @@ -145,4 +146,4 @@ Alarm::~Alarm() { } void Alarm::Cancel() { static_cast<internal::AlarmImpl*>(alarm_)->Cancel(); } -} // namespace grpc +} // namespace grpc_impl diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc index 50ee9d871f..214d72f853 100644 --- a/src/cpp/common/channel_arguments.cc +++ b/src/cpp/common/channel_arguments.cc @@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) { } if (!replaced) { + strings_.push_back(grpc::string(mutator_arg.key)); args_.push_back(mutator_arg); + args_.back().key = const_cast<char*>(strings_.back().c_str()); } } diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc index 1d66dd3d1f..7a2b5afed6 100644 --- a/src/cpp/common/secure_auth_context.cc +++ b/src/cpp/common/secure_auth_context.cc @@ -22,19 +22,12 @@ namespace grpc { -SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx, - bool take_ownership) - : ctx_(ctx), take_ownership_(take_ownership) {} - -SecureAuthContext::~SecureAuthContext() { - if (take_ownership_) grpc_auth_context_release(ctx_); -} - std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector<grpc::string_ref>(); } - grpc_auth_property_iterator iter = grpc_auth_context_peer_identity(ctx_); + grpc_auth_property_iterator iter = + grpc_auth_context_peer_identity(ctx_.get()); std::vector<grpc::string_ref> identity; const grpc_auth_property* property = nullptr; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -45,20 +38,20 @@ std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { } grpc::string SecureAuthContext::GetPeerIdentityPropertyName() const { - if (!ctx_) { + if (ctx_ == nullptr) { return ""; } - const char* name = grpc_auth_context_peer_identity_property_name(ctx_); + const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get()); return name == nullptr ? "" : name; } std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues( const grpc::string& name) const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector<grpc::string_ref>(); } grpc_auth_property_iterator iter = - grpc_auth_context_find_properties_by_name(ctx_, name.c_str()); + grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str()); const grpc_auth_property* property = nullptr; std::vector<grpc::string_ref> values; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -68,9 +61,9 @@ std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues( } AuthPropertyIterator SecureAuthContext::begin() const { - if (ctx_) { + if (ctx_ != nullptr) { grpc_auth_property_iterator iter = - grpc_auth_context_property_iterator(ctx_); + grpc_auth_context_property_iterator(ctx_.get()); const grpc_auth_property* property = grpc_auth_property_iterator_next(&iter); return AuthPropertyIterator(property, &iter); @@ -85,19 +78,20 @@ AuthPropertyIterator SecureAuthContext::end() const { void SecureAuthContext::AddProperty(const grpc::string& key, const grpc::string_ref& value) { - if (!ctx_) return; - grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size()); + if (ctx_ == nullptr) return; + grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(), + value.size()); } bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) { - if (!ctx_) return false; - return grpc_auth_context_set_peer_identity_property_name(ctx_, + if (ctx_ == nullptr) return false; + return grpc_auth_context_set_peer_identity_property_name(ctx_.get(), name.c_str()) != 0; } bool SecureAuthContext::IsPeerAuthenticated() const { - if (!ctx_) return false; - return grpc_auth_context_peer_is_authenticated(ctx_) != 0; + if (ctx_ == nullptr) return false; + return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0; } } // namespace grpc diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h index 142617959c..2e8f793721 100644 --- a/src/cpp/common/secure_auth_context.h +++ b/src/cpp/common/secure_auth_context.h @@ -21,15 +21,17 @@ #include <grpcpp/security/auth_context.h> -struct grpc_auth_context; +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" namespace grpc { class SecureAuthContext final : public AuthContext { public: - SecureAuthContext(grpc_auth_context* ctx, bool take_ownership); + explicit SecureAuthContext(grpc_auth_context* ctx) + : ctx_(ctx != nullptr ? ctx->Ref() : nullptr) {} - ~SecureAuthContext() override; + ~SecureAuthContext() override = default; bool IsPeerAuthenticated() const override; @@ -50,8 +52,7 @@ class SecureAuthContext final : public AuthContext { virtual bool SetPeerIdentityPropertyName(const grpc::string& name) override; private: - grpc_auth_context* ctx_; - bool take_ownership_; + grpc_core::RefCountedPtr<grpc_auth_context> ctx_; }; } // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc index bc1387c8d7..908c46629e 100644 --- a/src/cpp/common/secure_create_auth_context.cc +++ b/src/cpp/common/secure_create_auth_context.cc @@ -20,6 +20,7 @@ #include <grpc/grpc.h> #include <grpc/grpc_security.h> #include <grpcpp/security/auth_context.h> +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/cpp/common/secure_auth_context.h" namespace grpc { @@ -28,8 +29,8 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) { if (call == nullptr) { return std::shared_ptr<const AuthContext>(); } - return std::shared_ptr<const AuthContext>( - new SecureAuthContext(grpc_call_auth_context(call), true)); + grpc_core::RefCountedPtr<grpc_auth_context> ctx(grpc_call_auth_context(call)); + return std::make_shared<SecureAuthContext>(ctx.get()); } } // namespace grpc diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc index 8abd45efb7..55da89e6c8 100644 --- a/src/cpp/common/version_cc.cc +++ b/src/cpp/common/version_cc.cc @@ -22,5 +22,5 @@ #include <grpcpp/grpcpp.h> namespace grpc { -grpc::string Version() { return "1.17.0-dev"; } +grpc::string Version() { return "1.18.0-dev"; } } // namespace grpc diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc index 78fc69a805..160590353a 100644 --- a/src/cpp/ext/filters/census/context.cc +++ b/src/cpp/ext/filters/census/context.cc @@ -28,6 +28,9 @@ using ::opencensus::trace::SpanContext; void GenerateServerContext(absl::string_view tracing, absl::string_view stats, absl::string_view primary_role, absl::string_view method, CensusContext* context) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + context->~CensusContext(); GrpcTraceContext trace_ctxt; if (TraceContextEncoding::Decode(tracing, &trace_ctxt) != TraceContextEncoding::kEncodeDecodeFailure) { @@ -42,6 +45,9 @@ void GenerateServerContext(absl::string_view tracing, absl::string_view stats, void GenerateClientContext(absl::string_view method, CensusContext* ctxt, CensusContext* parent_ctxt) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + ctxt->~CensusContext(); if (parent_ctxt != nullptr) { SpanContext span_ctxt = parent_ctxt->Context(); Span span = parent_ctxt->Span(); diff --git a/src/cpp/server/channelz/channelz_service.cc b/src/cpp/server/channelz/channelz_service.cc index 9ecb9de7e4..c44a9ac0de 100644 --- a/src/cpp/server/channelz/channelz_service.cc +++ b/src/cpp/server/channelz/channelz_service.cc @@ -79,8 +79,8 @@ Status ChannelzService::GetServer(ServerContext* unused, Status ChannelzService::GetServerSockets( ServerContext* unused, const channelz::v1::GetServerSocketsRequest* request, channelz::v1::GetServerSocketsResponse* response) { - char* json_str = grpc_channelz_get_server_sockets(request->server_id(), - request->start_socket_id()); + char* json_str = grpc_channelz_get_server_sockets( + request->server_id(), request->start_socket_id(), request->max_results()); if (json_str == nullptr) { return Status(StatusCode::INTERNAL, "grpc_channelz_get_server_sockets returned null"); diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc index c951c69d51..44aebd2f9d 100644 --- a/src/cpp/server/health/default_health_check_service.cc +++ b/src/cpp/server/health/default_health_check_service.cc @@ -42,18 +42,37 @@ DefaultHealthCheckService::DefaultHealthCheckService() { void DefaultHealthCheckService::SetServingStatus( const grpc::string& service_name, bool serving) { std::unique_lock<std::mutex> lock(mu_); + if (shutdown_) { + // Set to NOT_SERVING in case service_name is not in the map. + serving = false; + } services_map_[service_name].SetServingStatus(serving ? SERVING : NOT_SERVING); } void DefaultHealthCheckService::SetServingStatus(bool serving) { const ServingStatus status = serving ? SERVING : NOT_SERVING; std::unique_lock<std::mutex> lock(mu_); + if (shutdown_) { + return; + } for (auto& p : services_map_) { ServiceData& service_data = p.second; service_data.SetServingStatus(status); } } +void DefaultHealthCheckService::Shutdown() { + std::unique_lock<std::mutex> lock(mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto& p : services_map_) { + ServiceData& service_data = p.second; + service_data.SetServingStatus(NOT_SERVING); + } +} + DefaultHealthCheckService::ServingStatus DefaultHealthCheckService::GetServingStatus( const grpc::string& service_name) const { diff --git a/src/cpp/server/health/default_health_check_service.h b/src/cpp/server/health/default_health_check_service.h index 450bd543f5..9551cd2e2c 100644 --- a/src/cpp/server/health/default_health_check_service.h +++ b/src/cpp/server/health/default_health_check_service.h @@ -237,6 +237,8 @@ class DefaultHealthCheckService final : public HealthCheckServiceInterface { bool serving) override; void SetServingStatus(bool serving) override; + void Shutdown() override; + ServingStatus GetServingStatus(const grpc::string& service_name) const; HealthCheckServiceImpl* GetHealthCheckService( @@ -272,6 +274,7 @@ class DefaultHealthCheckService final : public HealthCheckServiceInterface { const std::shared_ptr<HealthCheckServiceImpl::CallHandler>& handler); mutable std::mutex mu_; + bool shutdown_ = false; // Guarded by mu_. std::map<grpc::string, ServiceData> services_map_; // Guarded by mu_. std::unique_ptr<HealthCheckServiceImpl> impl_; }; diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index ebb17def32..453e76eb25 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -61,7 +61,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor( metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key), StringRefFromSlice(&md[i].value))); } - SecureAuthContext context(ctx, false); + SecureAuthContext context(ctx); AuthMetadataProcessor::OutputMetadata consumed_metadata; AuthMetadataProcessor::OutputMetadata response_metadata; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 8b1658dd27..13741ce7aa 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -84,10 +84,7 @@ class ShutdownTag : public internal::CompletionQueueTag { class DummyTag : public internal::CompletionQueueTag { public: - bool FinalizeResult(void** tag, bool* status) { - *status = true; - return true; - } + bool FinalizeResult(void** tag, bool* status) { return true; } }; class UnimplementedAsyncRequestContext { @@ -239,9 +236,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { : nullptr), request_(nullptr), method_(mrd->method_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size(), - ctx_.set_server_rpc_info(method_->name(), - server->interceptor_creators_)), + call_( + mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), method_->method_type(), + server->interceptor_creators_)), server_(server), global_callbacks_(nullptr), resources_(false) { @@ -294,7 +292,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void ContinueRunAfterInterception() { { - ctx_.BeginCompletionOp(&call_, false); + ctx_.BeginCompletionOp(&call_, nullptr, nullptr); global_callbacks_->PreSynchronousRequest(&ctx_); auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get(); @@ -430,7 +428,8 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { req_->call_, req_->server_, req_->cq_, req_->server_->max_receive_message_size(), req_->ctx_.set_server_rpc_info( - req_->method_->name(), req_->server_->interceptor_creators_)); + req_->method_->name(), req_->method_->method_type(), + req_->server_->interceptor_creators_)); req_->interceptor_methods_.SetCall(call_); req_->interceptor_methods_.SetReverse(); @@ -459,7 +458,6 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { } } void ContinueRunAfterInterception() { - req_->ctx_.BeginCompletionOp(call_, true); req_->method_->handler()->RunHandler( internal::MethodHandler::HandlerParameter( call_, &req_->ctx_, req_->request_, req_->request_status_, @@ -732,14 +730,15 @@ std::shared_ptr<Channel> Server::InProcessChannel( grpc_channel_args channel_args = args.c_channel_args(); return CreateChannelInternal( "inproc", grpc_inproc_channel_create(server_, &channel_args, nullptr), - nullptr); + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); } std::shared_ptr<Channel> Server::experimental_type::InProcessChannelWithInterceptors( const ChannelArguments& args, - std::unique_ptr<std::vector< - std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>> + std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> interceptor_creators) { grpc_channel_args channel_args = args.c_channel_args(); return CreateChannelInternal( @@ -1020,7 +1019,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, } } if (*status && call_) { - context_->BeginCompletionOp(&call_wrapper_, false); + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); } *tag = tag_; if (delete_on_finalize_) { @@ -1031,7 +1030,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, void ServerInterface::BaseAsyncRequest:: ContinueFinalizeResultAfterInterception() { - context_->BeginCompletionOp(&call_wrapper_, false); + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); // Queue a tag which will be returned immediately grpc_core::ExecCtx exec_ctx; grpc_cq_begin_op(notification_cq_->cq(), this); @@ -1044,10 +1043,12 @@ void ServerInterface::BaseAsyncRequest:: ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag, const char* name) + ServerCompletionQueue* notification_cq, void* tag, const char* name, + internal::RpcMethod::RpcType type) : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, true), - name_(name) {} + name_(name), + type_(type) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -1094,6 +1095,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, call_, server_, call_cq_, server_->max_receive_message_size(), context_->set_server_rpc_info( static_cast<GenericServerContext*>(context_)->method_.c_str(), + internal::RpcMethod::BIDI_STREAMING, *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 396996e5bc..1b524bc3e8 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -17,6 +17,7 @@ */ #include <grpcpp/server_context.h> +#include <grpcpp/support/server_callback.h> #include <algorithm> #include <mutex> @@ -41,8 +42,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { public: // initial refs: one in the server context, one in the cq // must ref the call before calling constructor and after deleting this - CompletionOp(internal::Call* call) + CompletionOp(internal::Call* call, internal::ServerReactor* reactor) : call_(*call), + reactor_(reactor), has_tag_(false), tag_(nullptr), core_cq_tag_(this), @@ -124,9 +126,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { return; } /* Start a dummy op so that we can return the tag */ - GPR_CODEGEN_ASSERT(GRPC_CALL_OK == - g_core_codegen_interface->grpc_call_start_batch( - call_.call(), nullptr, 0, this, nullptr)); + GPR_CODEGEN_ASSERT( + GRPC_CALL_OK == + grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag_, nullptr)); } private: @@ -136,13 +138,14 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { } internal::Call call_; + internal::ServerReactor* reactor_; bool has_tag_; void* tag_; void* core_cq_tag_; std::mutex mu_; int refs_; bool finalized_; - int cancelled_; + int cancelled_; // This is an int (not bool) because it is passed to core bool done_intercepting_; internal::InterceptorBatchMethodsImpl interceptor_methods_; }; @@ -190,7 +193,16 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { } finalized_ = true; - if (!*status) cancelled_ = 1; + // If for some reason the incoming status is false, mark that as a + // cancellation. + // TODO(vjpai): does this ever happen? + if (!*status) { + cancelled_ = 1; + } + + if (cancelled_ && (reactor_ != nullptr)) { + reactor_->OnCancel(); + } /* Release the lock since we are going to be running through interceptors now */ lock.unlock(); @@ -247,21 +259,29 @@ void ServerContext::BindDeadlineAndMetadata(gpr_timespec deadline, ServerContext::~ServerContext() { Clear(); } void ServerContext::Clear() { - if (call_) { - grpc_call_unref(call_); - } + auth_context_.reset(); + initial_metadata_.clear(); + trailing_metadata_.clear(); + client_metadata_.Reset(); if (completion_op_) { completion_op_->Unref(); + completion_op_ = nullptr; completion_tag_.Clear(); } if (rpc_info_) { rpc_info_->Unref(); + rpc_info_ = nullptr; + } + if (call_) { + auto* call = call_; + call_ = nullptr; + grpc_call_unref(call); } - // Don't need to clear out call_, completion_op_, or rpc_info_ because this is - // either called from destructor or just before Setup } -void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) { +void ServerContext::BeginCompletionOp(internal::Call* call, + std::function<void(bool)> callback, + internal::ServerReactor* reactor) { GPR_ASSERT(!completion_op_); if (rpc_info_) { rpc_info_->Ref(); @@ -269,10 +289,11 @@ void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) { grpc_call_ref(call->call()); completion_op_ = new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) - CompletionOp(call); - if (callback) { - completion_tag_.Set(call->call(), nullptr, completion_op_); + CompletionOp(call, reactor); + if (callback != nullptr) { + completion_tag_.Set(call->call(), std::move(callback), completion_op_); completion_op_->set_core_cq_tag(&completion_tag_); + completion_op_->set_tag(completion_op_); } else if (has_notify_when_done_tag_) { completion_op_->set_tag(async_notify_when_done_tag_); } diff --git a/src/csharp/Grpc.Core.Tests/AppDomainUnloadTest.cs b/src/csharp/Grpc.Core.Tests/AppDomainUnloadTest.cs index 3a161763fd..b10d7a1045 100644 --- a/src/csharp/Grpc.Core.Tests/AppDomainUnloadTest.cs +++ b/src/csharp/Grpc.Core.Tests/AppDomainUnloadTest.cs @@ -25,7 +25,7 @@ namespace Grpc.Core.Tests { public class AppDomainUnloadTest { -#if NETCOREAPP1_0 +#if NETCOREAPP1_1 || NETCOREAPP2_1 [Test] [Ignore("Not supported for CoreCLR")] public void AppDomainUnloadHookCanCleanupAbandonedCall() diff --git a/src/csharp/Grpc.Core.Tests/Grpc.Core.Tests.csproj b/src/csharp/Grpc.Core.Tests/Grpc.Core.Tests.csproj index d58f046824..178931a3d7 100755 --- a/src/csharp/Grpc.Core.Tests/Grpc.Core.Tests.csproj +++ b/src/csharp/Grpc.Core.Tests/Grpc.Core.Tests.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Core.Tests</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Core.Tests</PackageId> diff --git a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs index e7d8939978..5c7d48f786 100644 --- a/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs +++ b/src/csharp/Grpc.Core.Tests/Internal/AsyncCallServerTest.cs @@ -49,7 +49,7 @@ namespace Grpc.Core.Internal.Tests fakeCall = new FakeNativeCall(); asyncCallServer = new AsyncCallServer<string, string>( - Marshallers.StringMarshaller.Serializer, Marshallers.StringMarshaller.Deserializer, + Marshallers.StringMarshaller.ContextualSerializer, Marshallers.StringMarshaller.ContextualDeserializer, server); asyncCallServer.InitializeForTesting(fakeCall); } diff --git a/src/csharp/Grpc.Core.Tests/MarshallerTest.cs b/src/csharp/Grpc.Core.Tests/MarshallerTest.cs index 97f64a0575..ad3e81d61f 100644 --- a/src/csharp/Grpc.Core.Tests/MarshallerTest.cs +++ b/src/csharp/Grpc.Core.Tests/MarshallerTest.cs @@ -69,11 +69,8 @@ namespace Grpc.Core.Tests Assert.AreSame(contextualSerializer, marshaller.ContextualSerializer); Assert.AreSame(contextualDeserializer, marshaller.ContextualDeserializer); - - // test that emulated serializer and deserializer work - var origMsg = "abc"; - var serialized = marshaller.Serializer(origMsg); - Assert.AreEqual(origMsg, marshaller.Deserializer(serialized)); + Assert.Throws(typeof(NotImplementedException), () => marshaller.Serializer("abc")); + Assert.Throws(typeof(NotImplementedException), () => marshaller.Deserializer(new byte[] {1, 2, 3})); } class FakeSerializationContext : SerializationContext diff --git a/src/csharp/Grpc.Core.Tests/NUnitMain.cs b/src/csharp/Grpc.Core.Tests/NUnitMain.cs index 49cb8cd3b9..3b206603f1 100644 --- a/src/csharp/Grpc.Core.Tests/NUnitMain.cs +++ b/src/csharp/Grpc.Core.Tests/NUnitMain.cs @@ -34,11 +34,7 @@ namespace Grpc.Core.Tests { // Make logger immune to NUnit capturing stdout and stderr to workaround https://github.com/nunit/nunit/issues/1406. GrpcEnvironment.SetLogger(new ConsoleLogger()); -#if NETCOREAPP1_0 - return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args, new ExtendedTextWrapper(Console.Out), Console.In); -#else - return new AutoRun().Execute(args); -#endif + return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); } } } diff --git a/src/csharp/Grpc.Core.Tests/SanityTest.cs b/src/csharp/Grpc.Core.Tests/SanityTest.cs index 0904453b6e..f785f70f4c 100644 --- a/src/csharp/Grpc.Core.Tests/SanityTest.cs +++ b/src/csharp/Grpc.Core.Tests/SanityTest.cs @@ -31,7 +31,7 @@ namespace Grpc.Core.Tests public class SanityTest { // TODO: make sanity test work for CoreCLR as well -#if !NETCOREAPP1_0 +#if !NETCOREAPP1_1 && !NETCOREAPP2_1 /// <summary> /// Because we depend on a native library, sometimes when things go wrong, the /// entire NUnit test process crashes. To be able to track down problems better, @@ -44,7 +44,7 @@ namespace Grpc.Core.Tests public void TestsJsonUpToDate() { var discoveredTests = DiscoverAllTestClasses(); - var testsFromFile + var testsFromFile = JsonConvert.DeserializeObject<Dictionary<string, List<string>>>(ReadTestsJson()); Assert.AreEqual(discoveredTests, testsFromFile); diff --git a/src/csharp/Grpc.Core/DeserializationContext.cs b/src/csharp/Grpc.Core/DeserializationContext.cs index 5b6372ef85..d69e0db5bd 100644 --- a/src/csharp/Grpc.Core/DeserializationContext.cs +++ b/src/csharp/Grpc.Core/DeserializationContext.cs @@ -16,6 +16,8 @@ #endregion +using System; + namespace Grpc.Core { /// <summary> @@ -41,6 +43,9 @@ namespace Grpc.Core /// (as there is no practical reason for doing so) and <c>DeserializationContext</c> implementations are free to assume so. /// </summary> /// <returns>byte array containing the entire payload.</returns> - public abstract byte[] PayloadAsNewBuffer(); + public virtual byte[] PayloadAsNewBuffer() + { + throw new NotImplementedException(); + } } } diff --git a/src/csharp/Grpc.Core/Internal/AsyncCall.cs b/src/csharp/Grpc.Core/Internal/AsyncCall.cs index 4cdf0ee6a7..b6d687f71e 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCall.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCall.cs @@ -54,7 +54,7 @@ namespace Grpc.Core.Internal ClientSideStatus? finishedStatus; public AsyncCall(CallInvocationDetails<TRequest, TResponse> callDetails) - : base(callDetails.RequestMarshaller.Serializer, callDetails.ResponseMarshaller.Deserializer) + : base(callDetails.RequestMarshaller.ContextualSerializer, callDetails.ResponseMarshaller.ContextualDeserializer) { this.details = callDetails.WithOptions(callDetails.Options.Normalize()); this.initialMetadataSent = true; // we always send metadata at the very beginning of the call. diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs index a93dc34620..39c9f7c616 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallBase.cs @@ -40,8 +40,8 @@ namespace Grpc.Core.Internal static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<AsyncCallBase<TWrite, TRead>>(); protected static readonly Status DeserializeResponseFailureStatus = new Status(StatusCode.Internal, "Failed to deserialize response message."); - readonly Func<TWrite, byte[]> serializer; - readonly Func<byte[], TRead> deserializer; + readonly Action<TWrite, SerializationContext> serializer; + readonly Func<DeserializationContext, TRead> deserializer; protected readonly object myLock = new object(); @@ -63,7 +63,7 @@ namespace Grpc.Core.Internal protected bool initialMetadataSent; protected long streamingWritesCounter; // Number of streaming send operations started so far. - public AsyncCallBase(Func<TWrite, byte[]> serializer, Func<byte[], TRead> deserializer) + public AsyncCallBase(Action<TWrite, SerializationContext> serializer, Func<DeserializationContext, TRead> deserializer) { this.serializer = GrpcPreconditions.CheckNotNull(serializer); this.deserializer = GrpcPreconditions.CheckNotNull(deserializer); @@ -215,14 +215,26 @@ namespace Grpc.Core.Internal protected byte[] UnsafeSerialize(TWrite msg) { - return serializer(msg); + DefaultSerializationContext context = null; + try + { + context = DefaultSerializationContext.GetInitializedThreadLocal(); + serializer(msg, context); + return context.GetPayload(); + } + finally + { + context?.Reset(); + } } protected Exception TryDeserialize(byte[] payload, out TRead msg) { + DefaultDeserializationContext context = null; try { - msg = deserializer(payload); + context = DefaultDeserializationContext.GetInitializedThreadLocal(payload); + msg = deserializer(context); return null; } catch (Exception e) @@ -230,6 +242,11 @@ namespace Grpc.Core.Internal msg = default(TRead); return e; } + finally + { + context?.Reset(); + + } } /// <summary> diff --git a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs index 0ceca4abb8..0bf1fb3b7d 100644 --- a/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs +++ b/src/csharp/Grpc.Core/Internal/AsyncCallServer.cs @@ -37,7 +37,7 @@ namespace Grpc.Core.Internal readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); readonly Server server; - public AsyncCallServer(Func<TResponse, byte[]> serializer, Func<byte[], TRequest> deserializer, Server server) : base(serializer, deserializer) + public AsyncCallServer(Action<TResponse, SerializationContext> serializer, Func<DeserializationContext, TRequest> deserializer, Server server) : base(serializer, deserializer) { this.server = GrpcPreconditions.CheckNotNull(server); } diff --git a/src/csharp/Grpc.Core/Internal/DefaultDeserializationContext.cs b/src/csharp/Grpc.Core/Internal/DefaultDeserializationContext.cs new file mode 100644 index 0000000000..7ace80e8d5 --- /dev/null +++ b/src/csharp/Grpc.Core/Internal/DefaultDeserializationContext.cs @@ -0,0 +1,66 @@ +#region Copyright notice and license + +// Copyright 2018 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core.Utils; +using System; +using System.Threading; + +namespace Grpc.Core.Internal +{ + internal class DefaultDeserializationContext : DeserializationContext + { + static readonly ThreadLocal<DefaultDeserializationContext> threadLocalInstance = + new ThreadLocal<DefaultDeserializationContext>(() => new DefaultDeserializationContext(), false); + + byte[] payload; + bool alreadyCalledPayloadAsNewBuffer; + + public DefaultDeserializationContext() + { + Reset(); + } + + public override int PayloadLength => payload.Length; + + public override byte[] PayloadAsNewBuffer() + { + GrpcPreconditions.CheckState(!alreadyCalledPayloadAsNewBuffer); + alreadyCalledPayloadAsNewBuffer = true; + return payload; + } + + public void Initialize(byte[] payload) + { + this.payload = GrpcPreconditions.CheckNotNull(payload); + this.alreadyCalledPayloadAsNewBuffer = false; + } + + public void Reset() + { + this.payload = null; + this.alreadyCalledPayloadAsNewBuffer = true; // mark payload as read + } + + public static DefaultDeserializationContext GetInitializedThreadLocal(byte[] payload) + { + var instance = threadLocalInstance.Value; + instance.Initialize(payload); + return instance; + } + } +} diff --git a/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs b/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs new file mode 100644 index 0000000000..cceb194879 --- /dev/null +++ b/src/csharp/Grpc.Core/Internal/DefaultSerializationContext.cs @@ -0,0 +1,62 @@ +#region Copyright notice and license + +// Copyright 2018 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core.Utils; +using System.Threading; + +namespace Grpc.Core.Internal +{ + internal class DefaultSerializationContext : SerializationContext + { + static readonly ThreadLocal<DefaultSerializationContext> threadLocalInstance = + new ThreadLocal<DefaultSerializationContext>(() => new DefaultSerializationContext(), false); + + bool isComplete; + byte[] payload; + + public DefaultSerializationContext() + { + Reset(); + } + + public override void Complete(byte[] payload) + { + GrpcPreconditions.CheckState(!isComplete); + this.isComplete = true; + this.payload = payload; + } + + internal byte[] GetPayload() + { + return this.payload; + } + + public void Reset() + { + this.isComplete = false; + this.payload = null; + } + + public static DefaultSerializationContext GetInitializedThreadLocal() + { + var instance = threadLocalInstance.Value; + instance.Reset(); + return instance; + } + } +} diff --git a/src/csharp/Grpc.Core/Internal/NativeExtension.cs b/src/csharp/Grpc.Core/Internal/NativeExtension.cs index f526b913af..5177b69fd9 100644 --- a/src/csharp/Grpc.Core/Internal/NativeExtension.cs +++ b/src/csharp/Grpc.Core/Internal/NativeExtension.cs @@ -83,13 +83,13 @@ namespace Grpc.Core.Internal // See https://github.com/grpc/grpc/pull/7303 for one option. var assemblyDirectory = Path.GetDirectoryName(GetAssemblyPath()); - // With old-style VS projects, the native libraries get copied using a .targets rule to the build output folder + // With "classic" VS projects, the native libraries get copied using a .targets rule to the build output folder // alongside the compiled assembly. // With dotnet cli projects targeting net45 framework, the native libraries (just the required ones) // are similarly copied to the built output folder, through the magic of Microsoft.NETCore.Platforms. var classicPath = Path.Combine(assemblyDirectory, GetNativeLibraryFilename()); - // With dotnet cli project targeting netcoreapp1.0, projects will use Grpc.Core assembly directly in the location where it got restored + // With dotnet cli project targeting netcoreappX.Y, projects will use Grpc.Core assembly directly in the location where it got restored // by nuget. We locate the native libraries based on known structure of Grpc.Core nuget package. // When "dotnet publish" is used, the runtimes directory is copied next to the published assemblies. string runtimesDirectory = string.Format("runtimes/{0}/native", GetPlatformString()); diff --git a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs index 81522cf8fe..ec732e8c7f 100644 --- a/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs +++ b/src/csharp/Grpc.Core/Internal/ServerCallHandler.cs @@ -52,8 +52,8 @@ namespace Grpc.Core.Internal public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq) { var asyncCall = new AsyncCallServer<TRequest, TResponse>( - method.ResponseMarshaller.Serializer, - method.RequestMarshaller.Deserializer, + method.ResponseMarshaller.ContextualSerializer, + method.RequestMarshaller.ContextualDeserializer, newRpc.Server); asyncCall.Initialize(newRpc.Call, cq); @@ -116,8 +116,8 @@ namespace Grpc.Core.Internal public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq) { var asyncCall = new AsyncCallServer<TRequest, TResponse>( - method.ResponseMarshaller.Serializer, - method.RequestMarshaller.Deserializer, + method.ResponseMarshaller.ContextualSerializer, + method.RequestMarshaller.ContextualDeserializer, newRpc.Server); asyncCall.Initialize(newRpc.Call, cq); @@ -179,8 +179,8 @@ namespace Grpc.Core.Internal public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq) { var asyncCall = new AsyncCallServer<TRequest, TResponse>( - method.ResponseMarshaller.Serializer, - method.RequestMarshaller.Deserializer, + method.ResponseMarshaller.ContextualSerializer, + method.RequestMarshaller.ContextualDeserializer, newRpc.Server); asyncCall.Initialize(newRpc.Call, cq); @@ -242,8 +242,8 @@ namespace Grpc.Core.Internal public async Task HandleCall(ServerRpcNew newRpc, CompletionQueueSafeHandle cq) { var asyncCall = new AsyncCallServer<TRequest, TResponse>( - method.ResponseMarshaller.Serializer, - method.RequestMarshaller.Deserializer, + method.ResponseMarshaller.ContextualSerializer, + method.RequestMarshaller.ContextualDeserializer, newRpc.Server); asyncCall.Initialize(newRpc.Call, cq); diff --git a/src/csharp/Grpc.Core/Marshaller.cs b/src/csharp/Grpc.Core/Marshaller.cs index 0af9aa586b..34a1849cd7 100644 --- a/src/csharp/Grpc.Core/Marshaller.cs +++ b/src/csharp/Grpc.Core/Marshaller.cs @@ -41,6 +41,8 @@ namespace Grpc.Core { this.serializer = GrpcPreconditions.CheckNotNull(serializer, nameof(serializer)); this.deserializer = GrpcPreconditions.CheckNotNull(deserializer, nameof(deserializer)); + // contextual serialization/deserialization is emulated to make the marshaller + // usable with the grpc library (required for backward compatibility). this.contextualSerializer = EmulateContextualSerializer; this.contextualDeserializer = EmulateContextualDeserializer; } @@ -57,10 +59,10 @@ namespace Grpc.Core { this.contextualSerializer = GrpcPreconditions.CheckNotNull(serializer, nameof(serializer)); this.contextualDeserializer = GrpcPreconditions.CheckNotNull(deserializer, nameof(deserializer)); - // TODO(jtattermusch): once gRPC C# library switches to using contextual (de)serializer, - // emulating the simple (de)serializer will become unnecessary. - this.serializer = EmulateSimpleSerializer; - this.deserializer = EmulateSimpleDeserializer; + // gRPC only uses contextual serializer/deserializer internally, so emulating the legacy + // (de)serializer is not necessary. + this.serializer = (msg) => { throw new NotImplementedException(); }; + this.deserializer = (payload) => { throw new NotImplementedException(); }; } /// <summary> @@ -85,25 +87,6 @@ namespace Grpc.Core /// </summary> public Func<DeserializationContext, T> ContextualDeserializer => this.contextualDeserializer; - // for backward compatibility, emulate the simple serializer using the contextual one - private byte[] EmulateSimpleSerializer(T msg) - { - // TODO(jtattermusch): avoid the allocation by passing a thread-local instance - // This code will become unnecessary once gRPC C# library switches to using contextual (de)serializer. - var context = new EmulatedSerializationContext(); - this.contextualSerializer(msg, context); - return context.GetPayload(); - } - - // for backward compatibility, emulate the simple deserializer using the contextual one - private T EmulateSimpleDeserializer(byte[] payload) - { - // TODO(jtattermusch): avoid the allocation by passing a thread-local instance - // This code will become unnecessary once gRPC C# library switches to using contextual (de)serializer. - var context = new EmulatedDeserializationContext(payload); - return this.contextualDeserializer(context); - } - // for backward compatibility, emulate the contextual serializer using the simple one private void EmulateContextualSerializer(T message, SerializationContext context) { @@ -116,44 +99,6 @@ namespace Grpc.Core { return this.deserializer(context.PayloadAsNewBuffer()); } - - internal class EmulatedSerializationContext : SerializationContext - { - bool isComplete; - byte[] payload; - - public override void Complete(byte[] payload) - { - GrpcPreconditions.CheckState(!isComplete); - this.isComplete = true; - this.payload = payload; - } - - internal byte[] GetPayload() - { - return this.payload; - } - } - - internal class EmulatedDeserializationContext : DeserializationContext - { - readonly byte[] payload; - bool alreadyCalledPayloadAsNewBuffer; - - public EmulatedDeserializationContext(byte[] payload) - { - this.payload = GrpcPreconditions.CheckNotNull(payload); - } - - public override int PayloadLength => payload.Length; - - public override byte[] PayloadAsNewBuffer() - { - GrpcPreconditions.CheckState(!alreadyCalledPayloadAsNewBuffer); - alreadyCalledPayloadAsNewBuffer = true; - return payload; - } - } } /// <summary> diff --git a/src/csharp/Grpc.Core/SerializationContext.cs b/src/csharp/Grpc.Core/SerializationContext.cs index cf4d1595da..9aef2adbcd 100644 --- a/src/csharp/Grpc.Core/SerializationContext.cs +++ b/src/csharp/Grpc.Core/SerializationContext.cs @@ -16,6 +16,8 @@ #endregion +using System; + namespace Grpc.Core { /// <summary> @@ -29,6 +31,9 @@ namespace Grpc.Core /// payload which must not be accessed afterwards. /// </summary> /// <param name="payload">the serialized form of current message</param> - public abstract void Complete(byte[] payload); + public virtual void Complete(byte[] payload) + { + throw new NotImplementedException(); + } } } diff --git a/src/csharp/Grpc.Core/ServerServiceDefinition.cs b/src/csharp/Grpc.Core/ServerServiceDefinition.cs index 07c6aa1796..b040ab379c 100644 --- a/src/csharp/Grpc.Core/ServerServiceDefinition.cs +++ b/src/csharp/Grpc.Core/ServerServiceDefinition.cs @@ -72,7 +72,7 @@ namespace Grpc.Core } /// <summary> - /// Adds a definitions for a single request - single response method. + /// Adds a definition for a single request - single response method. /// </summary> /// <typeparam name="TRequest">The request message class.</typeparam> /// <typeparam name="TResponse">The response message class.</typeparam> @@ -90,7 +90,7 @@ namespace Grpc.Core } /// <summary> - /// Adds a definitions for a client streaming method. + /// Adds a definition for a client streaming method. /// </summary> /// <typeparam name="TRequest">The request message class.</typeparam> /// <typeparam name="TResponse">The response message class.</typeparam> @@ -108,7 +108,7 @@ namespace Grpc.Core } /// <summary> - /// Adds a definitions for a server streaming method. + /// Adds a definition for a server streaming method. /// </summary> /// <typeparam name="TRequest">The request message class.</typeparam> /// <typeparam name="TResponse">The response message class.</typeparam> @@ -126,7 +126,7 @@ namespace Grpc.Core } /// <summary> - /// Adds a definitions for a bidirectional streaming method. + /// Adds a definition for a bidirectional streaming method. /// </summary> /// <typeparam name="TRequest">The request message class.</typeparam> /// <typeparam name="TResponse">The response message class.</typeparam> diff --git a/src/csharp/Grpc.Core/ServiceBinderBase.cs b/src/csharp/Grpc.Core/ServiceBinderBase.cs new file mode 100644 index 0000000000..d4909f4a26 --- /dev/null +++ b/src/csharp/Grpc.Core/ServiceBinderBase.cs @@ -0,0 +1,101 @@ +#region Copyright notice and license + +// Copyright 2018 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using Grpc.Core.Interceptors; +using Grpc.Core.Internal; +using Grpc.Core.Utils; + +namespace Grpc.Core +{ + /// <summary> + /// Allows binding server-side method implementations in alternative serving stacks. + /// Instances of this class are usually populated by the <c>BindService</c> method + /// that is part of the autogenerated code for a protocol buffers service definition. + /// <seealso cref="ServerServiceDefinition"/> + /// </summary> + public class ServiceBinderBase + { + /// <summary> + /// Adds a definition for a single request - single response method. + /// </summary> + /// <typeparam name="TRequest">The request message class.</typeparam> + /// <typeparam name="TResponse">The response message class.</typeparam> + /// <param name="method">The method.</param> + /// <param name="handler">The method handler.</param> + public virtual void AddMethod<TRequest, TResponse>( + Method<TRequest, TResponse> method, + UnaryServerMethod<TRequest, TResponse> handler) + where TRequest : class + where TResponse : class + { + throw new NotImplementedException(); + } + + /// <summary> + /// Adds a definition for a client streaming method. + /// </summary> + /// <typeparam name="TRequest">The request message class.</typeparam> + /// <typeparam name="TResponse">The response message class.</typeparam> + /// <param name="method">The method.</param> + /// <param name="handler">The method handler.</param> + public virtual void AddMethod<TRequest, TResponse>( + Method<TRequest, TResponse> method, + ClientStreamingServerMethod<TRequest, TResponse> handler) + where TRequest : class + where TResponse : class + { + throw new NotImplementedException(); + } + + /// <summary> + /// Adds a definition for a server streaming method. + /// </summary> + /// <typeparam name="TRequest">The request message class.</typeparam> + /// <typeparam name="TResponse">The response message class.</typeparam> + /// <param name="method">The method.</param> + /// <param name="handler">The method handler.</param> + public virtual void AddMethod<TRequest, TResponse>( + Method<TRequest, TResponse> method, + ServerStreamingServerMethod<TRequest, TResponse> handler) + where TRequest : class + where TResponse : class + { + throw new NotImplementedException(); + } + + /// <summary> + /// Adds a definition for a bidirectional streaming method. + /// </summary> + /// <typeparam name="TRequest">The request message class.</typeparam> + /// <typeparam name="TResponse">The response message class.</typeparam> + /// <param name="method">The method.</param> + /// <param name="handler">The method handler.</param> + public virtual void AddMethod<TRequest, TResponse>( + Method<TRequest, TResponse> method, + DuplexStreamingServerMethod<TRequest, TResponse> handler) + where TRequest : class + where TResponse : class + { + throw new NotImplementedException(); + } + } +} diff --git a/src/csharp/Grpc.Core/Version.csproj.include b/src/csharp/Grpc.Core/Version.csproj.include index ed0d884365..4fffe4f644 100755 --- a/src/csharp/Grpc.Core/Version.csproj.include +++ b/src/csharp/Grpc.Core/Version.csproj.include @@ -1,7 +1,7 @@ <!-- This file is generated --> <Project> <PropertyGroup> - <GrpcCsharpVersion>1.17.0-dev</GrpcCsharpVersion> + <GrpcCsharpVersion>1.18.0-dev</GrpcCsharpVersion> <GoogleProtobufVersion>3.6.1</GoogleProtobufVersion> </PropertyGroup> </Project> diff --git a/src/csharp/Grpc.Core/VersionInfo.cs b/src/csharp/Grpc.Core/VersionInfo.cs index 14714c8c4a..633880189c 100644 --- a/src/csharp/Grpc.Core/VersionInfo.cs +++ b/src/csharp/Grpc.Core/VersionInfo.cs @@ -33,11 +33,11 @@ namespace Grpc.Core /// <summary> /// Current <c>AssemblyFileVersion</c> of gRPC C# assemblies /// </summary> - public const string CurrentAssemblyFileVersion = "1.17.0.0"; + public const string CurrentAssemblyFileVersion = "1.18.0.0"; /// <summary> /// Current version of gRPC C# /// </summary> - public const string CurrentVersion = "1.17.0-dev"; + public const string CurrentVersion = "1.18.0-dev"; } } diff --git a/src/csharp/Grpc.Examples.MathClient/Grpc.Examples.MathClient.csproj b/src/csharp/Grpc.Examples.MathClient/Grpc.Examples.MathClient.csproj index db4e3ef4e3..1afcd9fba0 100755 --- a/src/csharp/Grpc.Examples.MathClient/Grpc.Examples.MathClient.csproj +++ b/src/csharp/Grpc.Examples.MathClient/Grpc.Examples.MathClient.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Examples.MathClient</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Examples.MathClient</PackageId> diff --git a/src/csharp/Grpc.Examples.MathServer/Grpc.Examples.MathServer.csproj b/src/csharp/Grpc.Examples.MathServer/Grpc.Examples.MathServer.csproj index b12b418d01..75ef6d1008 100755 --- a/src/csharp/Grpc.Examples.MathServer/Grpc.Examples.MathServer.csproj +++ b/src/csharp/Grpc.Examples.MathServer/Grpc.Examples.MathServer.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Examples.MathServer</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Examples.MathServer</PackageId> diff --git a/src/csharp/Grpc.Examples.Tests/Grpc.Examples.Tests.csproj b/src/csharp/Grpc.Examples.Tests/Grpc.Examples.Tests.csproj index 7493eb8051..93d112a0c5 100755 --- a/src/csharp/Grpc.Examples.Tests/Grpc.Examples.Tests.csproj +++ b/src/csharp/Grpc.Examples.Tests/Grpc.Examples.Tests.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Examples.Tests</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Examples.Tests</PackageId> diff --git a/src/csharp/Grpc.Examples.Tests/NUnitMain.cs b/src/csharp/Grpc.Examples.Tests/NUnitMain.cs index bcb8b46b64..107df64809 100644 --- a/src/csharp/Grpc.Examples.Tests/NUnitMain.cs +++ b/src/csharp/Grpc.Examples.Tests/NUnitMain.cs @@ -34,11 +34,7 @@ namespace Grpc.Examples.Tests { // Make logger immune to NUnit capturing stdout and stderr to workaround https://github.com/nunit/nunit/issues/1406. GrpcEnvironment.SetLogger(new ConsoleLogger()); -#if NETCOREAPP1_0 - return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args, new ExtendedTextWrapper(Console.Out), Console.In); -#else - return new AutoRun().Execute(args); -#endif + return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); } } } diff --git a/src/csharp/Grpc.Examples/Grpc.Examples.csproj b/src/csharp/Grpc.Examples/Grpc.Examples.csproj index baa3b4ce6c..9ce2b59d03 100755 --- a/src/csharp/Grpc.Examples/Grpc.Examples.csproj +++ b/src/csharp/Grpc.Examples/Grpc.Examples.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Examples</AssemblyName> <PackageId>Grpc.Examples</PackageId> <TreatWarningsAsErrors>true</TreatWarningsAsErrors> diff --git a/src/csharp/Grpc.Examples/MathGrpc.cs b/src/csharp/Grpc.Examples/MathGrpc.cs index 9578bb4d81..e5be387e67 100644 --- a/src/csharp/Grpc.Examples/MathGrpc.cs +++ b/src/csharp/Grpc.Examples/MathGrpc.cs @@ -287,6 +287,18 @@ namespace Math { .AddMethod(__Method_Sum, serviceImpl.Sum).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, MathBase serviceImpl) + { + serviceBinder.AddMethod(__Method_Div, serviceImpl.Div); + serviceBinder.AddMethod(__Method_DivMany, serviceImpl.DivMany); + serviceBinder.AddMethod(__Method_Fib, serviceImpl.Fib); + serviceBinder.AddMethod(__Method_Sum, serviceImpl.Sum); + } + } } #endregion diff --git a/src/csharp/Grpc.HealthCheck.Tests/Grpc.HealthCheck.Tests.csproj b/src/csharp/Grpc.HealthCheck.Tests/Grpc.HealthCheck.Tests.csproj index 616e56df10..2a037a72e5 100755 --- a/src/csharp/Grpc.HealthCheck.Tests/Grpc.HealthCheck.Tests.csproj +++ b/src/csharp/Grpc.HealthCheck.Tests/Grpc.HealthCheck.Tests.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.HealthCheck.Tests</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.HealthCheck.Tests</PackageId> diff --git a/src/csharp/Grpc.HealthCheck.Tests/NUnitMain.cs b/src/csharp/Grpc.HealthCheck.Tests/NUnitMain.cs index 365551e895..db6d32a5b2 100644 --- a/src/csharp/Grpc.HealthCheck.Tests/NUnitMain.cs +++ b/src/csharp/Grpc.HealthCheck.Tests/NUnitMain.cs @@ -34,11 +34,7 @@ namespace Grpc.HealthCheck.Tests { // Make logger immune to NUnit capturing stdout and stderr to workaround https://github.com/nunit/nunit/issues/1406. GrpcEnvironment.SetLogger(new ConsoleLogger()); -#if NETCOREAPP1_0 - return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args, new ExtendedTextWrapper(Console.Out), Console.In); -#else - return new AutoRun().Execute(args); -#endif + return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); } } } diff --git a/src/csharp/Grpc.HealthCheck/Health.cs b/src/csharp/Grpc.HealthCheck/Health.cs index a90f261d28..2c3bb45c3c 100644 --- a/src/csharp/Grpc.HealthCheck/Health.cs +++ b/src/csharp/Grpc.HealthCheck/Health.cs @@ -25,15 +25,17 @@ namespace Grpc.Health.V1 { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "ChtncnBjL2hlYWx0aC92MS9oZWFsdGgucHJvdG8SDmdycGMuaGVhbHRoLnYx", - "IiUKEkhlYWx0aENoZWNrUmVxdWVzdBIPCgdzZXJ2aWNlGAEgASgJIpQBChNI", + "IiUKEkhlYWx0aENoZWNrUmVxdWVzdBIPCgdzZXJ2aWNlGAEgASgJIqkBChNI", "ZWFsdGhDaGVja1Jlc3BvbnNlEkEKBnN0YXR1cxgBIAEoDjIxLmdycGMuaGVh", - "bHRoLnYxLkhlYWx0aENoZWNrUmVzcG9uc2UuU2VydmluZ1N0YXR1cyI6Cg1T", + "bHRoLnYxLkhlYWx0aENoZWNrUmVzcG9uc2UuU2VydmluZ1N0YXR1cyJPCg1T", "ZXJ2aW5nU3RhdHVzEgsKB1VOS05PV04QABILCgdTRVJWSU5HEAESDwoLTk9U", - "X1NFUlZJTkcQAjJaCgZIZWFsdGgSUAoFQ2hlY2sSIi5ncnBjLmhlYWx0aC52", - "MS5IZWFsdGhDaGVja1JlcXVlc3QaIy5ncnBjLmhlYWx0aC52MS5IZWFsdGhD", - "aGVja1Jlc3BvbnNlQmEKEWlvLmdycGMuaGVhbHRoLnYxQgtIZWFsdGhQcm90", - "b1ABWixnb29nbGUuZ29sYW5nLm9yZy9ncnBjL2hlYWx0aC9ncnBjX2hlYWx0", - "aF92MaoCDkdycGMuSGVhbHRoLlYxYgZwcm90bzM=")); + "X1NFUlZJTkcQAhITCg9TRVJWSUNFX1VOS05PV04QAzKuAQoGSGVhbHRoElAK", + "BUNoZWNrEiIuZ3JwYy5oZWFsdGgudjEuSGVhbHRoQ2hlY2tSZXF1ZXN0GiMu", + "Z3JwYy5oZWFsdGgudjEuSGVhbHRoQ2hlY2tSZXNwb25zZRJSCgVXYXRjaBIi", + "LmdycGMuaGVhbHRoLnYxLkhlYWx0aENoZWNrUmVxdWVzdBojLmdycGMuaGVh", + "bHRoLnYxLkhlYWx0aENoZWNrUmVzcG9uc2UwAUJhChFpby5ncnBjLmhlYWx0", + "aC52MUILSGVhbHRoUHJvdG9QAVosZ29vZ2xlLmdvbGFuZy5vcmcvZ3JwYy9o", + "ZWFsdGgvZ3JwY19oZWFsdGhfdjGqAg5HcnBjLkhlYWx0aC5WMWIGcHJvdG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { @@ -309,6 +311,10 @@ namespace Grpc.Health.V1 { [pbr::OriginalName("UNKNOWN")] Unknown = 0, [pbr::OriginalName("SERVING")] Serving = 1, [pbr::OriginalName("NOT_SERVING")] NotServing = 2, + /// <summary> + /// Used only by the Watch method. + /// </summary> + [pbr::OriginalName("SERVICE_UNKNOWN")] ServiceUnknown = 3, } } diff --git a/src/csharp/Grpc.HealthCheck/HealthGrpc.cs b/src/csharp/Grpc.HealthCheck/HealthGrpc.cs index 5e79c04d2a..51956f2f23 100644 --- a/src/csharp/Grpc.HealthCheck/HealthGrpc.cs +++ b/src/csharp/Grpc.HealthCheck/HealthGrpc.cs @@ -40,6 +40,13 @@ namespace Grpc.Health.V1 { __Marshaller_grpc_health_v1_HealthCheckRequest, __Marshaller_grpc_health_v1_HealthCheckResponse); + static readonly grpc::Method<global::Grpc.Health.V1.HealthCheckRequest, global::Grpc.Health.V1.HealthCheckResponse> __Method_Watch = new grpc::Method<global::Grpc.Health.V1.HealthCheckRequest, global::Grpc.Health.V1.HealthCheckResponse>( + grpc::MethodType.ServerStreaming, + __ServiceName, + "Watch", + __Marshaller_grpc_health_v1_HealthCheckRequest, + __Marshaller_grpc_health_v1_HealthCheckResponse); + /// <summary>Service descriptor</summary> public static global::Google.Protobuf.Reflection.ServiceDescriptor Descriptor { @@ -49,11 +56,44 @@ namespace Grpc.Health.V1 { /// <summary>Base class for server-side implementations of Health</summary> public abstract partial class HealthBase { + /// <summary> + /// If the requested service is unknown, the call will fail with status + /// NOT_FOUND. + /// </summary> + /// <param name="request">The request received from the client.</param> + /// <param name="context">The context of the server-side call handler being invoked.</param> + /// <returns>The response to send back to the client (wrapped by a task).</returns> public virtual global::System.Threading.Tasks.Task<global::Grpc.Health.V1.HealthCheckResponse> Check(global::Grpc.Health.V1.HealthCheckRequest request, grpc::ServerCallContext context) { throw new grpc::RpcException(new grpc::Status(grpc::StatusCode.Unimplemented, "")); } + /// <summary> + /// Performs a watch for the serving status of the requested service. + /// The server will immediately send back a message indicating the current + /// serving status. It will then subsequently send a new message whenever + /// the service's serving status changes. + /// + /// If the requested service is unknown when the call is received, the + /// server will send a message setting the serving status to + /// SERVICE_UNKNOWN but will *not* terminate the call. If at some + /// future point, the serving status of the service becomes known, the + /// server will send a new message with the service's serving status. + /// + /// If the call terminates with status UNIMPLEMENTED, then clients + /// should assume this method is not supported and should not retry the + /// call. If the call terminates with any other status (including OK), + /// clients should retry the call with appropriate exponential backoff. + /// </summary> + /// <param name="request">The request received from the client.</param> + /// <param name="responseStream">Used for sending responses back to the client.</param> + /// <param name="context">The context of the server-side call handler being invoked.</param> + /// <returns>A task indicating completion of the handler.</returns> + public virtual global::System.Threading.Tasks.Task Watch(global::Grpc.Health.V1.HealthCheckRequest request, grpc::IServerStreamWriter<global::Grpc.Health.V1.HealthCheckResponse> responseStream, grpc::ServerCallContext context) + { + throw new grpc::RpcException(new grpc::Status(grpc::StatusCode.Unimplemented, "")); + } + } /// <summary>Client for Health</summary> @@ -79,22 +119,104 @@ namespace Grpc.Health.V1 { { } + /// <summary> + /// If the requested service is unknown, the call will fail with status + /// NOT_FOUND. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="headers">The initial metadata to send with the call. This parameter is optional.</param> + /// <param name="deadline">An optional deadline for the call. The call will be cancelled if deadline is hit.</param> + /// <param name="cancellationToken">An optional token for canceling the call.</param> + /// <returns>The response received from the server.</returns> public virtual global::Grpc.Health.V1.HealthCheckResponse Check(global::Grpc.Health.V1.HealthCheckRequest request, grpc::Metadata headers = null, global::System.DateTime? deadline = null, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken)) { return Check(request, new grpc::CallOptions(headers, deadline, cancellationToken)); } + /// <summary> + /// If the requested service is unknown, the call will fail with status + /// NOT_FOUND. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="options">The options for the call.</param> + /// <returns>The response received from the server.</returns> public virtual global::Grpc.Health.V1.HealthCheckResponse Check(global::Grpc.Health.V1.HealthCheckRequest request, grpc::CallOptions options) { return CallInvoker.BlockingUnaryCall(__Method_Check, null, options, request); } + /// <summary> + /// If the requested service is unknown, the call will fail with status + /// NOT_FOUND. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="headers">The initial metadata to send with the call. This parameter is optional.</param> + /// <param name="deadline">An optional deadline for the call. The call will be cancelled if deadline is hit.</param> + /// <param name="cancellationToken">An optional token for canceling the call.</param> + /// <returns>The call object.</returns> public virtual grpc::AsyncUnaryCall<global::Grpc.Health.V1.HealthCheckResponse> CheckAsync(global::Grpc.Health.V1.HealthCheckRequest request, grpc::Metadata headers = null, global::System.DateTime? deadline = null, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken)) { return CheckAsync(request, new grpc::CallOptions(headers, deadline, cancellationToken)); } + /// <summary> + /// If the requested service is unknown, the call will fail with status + /// NOT_FOUND. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="options">The options for the call.</param> + /// <returns>The call object.</returns> public virtual grpc::AsyncUnaryCall<global::Grpc.Health.V1.HealthCheckResponse> CheckAsync(global::Grpc.Health.V1.HealthCheckRequest request, grpc::CallOptions options) { return CallInvoker.AsyncUnaryCall(__Method_Check, null, options, request); } + /// <summary> + /// Performs a watch for the serving status of the requested service. + /// The server will immediately send back a message indicating the current + /// serving status. It will then subsequently send a new message whenever + /// the service's serving status changes. + /// + /// If the requested service is unknown when the call is received, the + /// server will send a message setting the serving status to + /// SERVICE_UNKNOWN but will *not* terminate the call. If at some + /// future point, the serving status of the service becomes known, the + /// server will send a new message with the service's serving status. + /// + /// If the call terminates with status UNIMPLEMENTED, then clients + /// should assume this method is not supported and should not retry the + /// call. If the call terminates with any other status (including OK), + /// clients should retry the call with appropriate exponential backoff. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="headers">The initial metadata to send with the call. This parameter is optional.</param> + /// <param name="deadline">An optional deadline for the call. The call will be cancelled if deadline is hit.</param> + /// <param name="cancellationToken">An optional token for canceling the call.</param> + /// <returns>The call object.</returns> + public virtual grpc::AsyncServerStreamingCall<global::Grpc.Health.V1.HealthCheckResponse> Watch(global::Grpc.Health.V1.HealthCheckRequest request, grpc::Metadata headers = null, global::System.DateTime? deadline = null, global::System.Threading.CancellationToken cancellationToken = default(global::System.Threading.CancellationToken)) + { + return Watch(request, new grpc::CallOptions(headers, deadline, cancellationToken)); + } + /// <summary> + /// Performs a watch for the serving status of the requested service. + /// The server will immediately send back a message indicating the current + /// serving status. It will then subsequently send a new message whenever + /// the service's serving status changes. + /// + /// If the requested service is unknown when the call is received, the + /// server will send a message setting the serving status to + /// SERVICE_UNKNOWN but will *not* terminate the call. If at some + /// future point, the serving status of the service becomes known, the + /// server will send a new message with the service's serving status. + /// + /// If the call terminates with status UNIMPLEMENTED, then clients + /// should assume this method is not supported and should not retry the + /// call. If the call terminates with any other status (including OK), + /// clients should retry the call with appropriate exponential backoff. + /// </summary> + /// <param name="request">The request to send to the server.</param> + /// <param name="options">The options for the call.</param> + /// <returns>The call object.</returns> + public virtual grpc::AsyncServerStreamingCall<global::Grpc.Health.V1.HealthCheckResponse> Watch(global::Grpc.Health.V1.HealthCheckRequest request, grpc::CallOptions options) + { + return CallInvoker.AsyncServerStreamingCall(__Method_Watch, null, options, request); + } /// <summary>Creates a new instance of client from given <c>ClientBaseConfiguration</c>.</summary> protected override HealthClient NewInstance(ClientBaseConfiguration configuration) { @@ -107,7 +229,18 @@ namespace Grpc.Health.V1 { public static grpc::ServerServiceDefinition BindService(HealthBase serviceImpl) { return grpc::ServerServiceDefinition.CreateBuilder() - .AddMethod(__Method_Check, serviceImpl.Check).Build(); + .AddMethod(__Method_Check, serviceImpl.Check) + .AddMethod(__Method_Watch, serviceImpl.Watch).Build(); + } + + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, HealthBase serviceImpl) + { + serviceBinder.AddMethod(__Method_Check, serviceImpl.Check); + serviceBinder.AddMethod(__Method_Watch, serviceImpl.Watch); } } diff --git a/src/csharp/Grpc.IntegrationTesting.Client/Grpc.IntegrationTesting.Client.csproj b/src/csharp/Grpc.IntegrationTesting.Client/Grpc.IntegrationTesting.Client.csproj index 35713156ea..1cd4b83e1e 100755 --- a/src/csharp/Grpc.IntegrationTesting.Client/Grpc.IntegrationTesting.Client.csproj +++ b/src/csharp/Grpc.IntegrationTesting.Client/Grpc.IntegrationTesting.Client.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.IntegrationTesting.Client</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.IntegrationTesting.Client</PackageId> diff --git a/src/csharp/Grpc.IntegrationTesting.QpsWorker/Grpc.IntegrationTesting.QpsWorker.csproj b/src/csharp/Grpc.IntegrationTesting.QpsWorker/Grpc.IntegrationTesting.QpsWorker.csproj index 3ecefe3bc4..2890a7df58 100755 --- a/src/csharp/Grpc.IntegrationTesting.QpsWorker/Grpc.IntegrationTesting.QpsWorker.csproj +++ b/src/csharp/Grpc.IntegrationTesting.QpsWorker/Grpc.IntegrationTesting.QpsWorker.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.IntegrationTesting.QpsWorker</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.IntegrationTesting.QpsWorker</PackageId> diff --git a/src/csharp/Grpc.IntegrationTesting.Server/Grpc.IntegrationTesting.Server.csproj b/src/csharp/Grpc.IntegrationTesting.Server/Grpc.IntegrationTesting.Server.csproj index 1092b2c21e..ee718958bc 100755 --- a/src/csharp/Grpc.IntegrationTesting.Server/Grpc.IntegrationTesting.Server.csproj +++ b/src/csharp/Grpc.IntegrationTesting.Server/Grpc.IntegrationTesting.Server.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.IntegrationTesting.Server</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.IntegrationTesting.Server</PackageId> @@ -19,7 +19,7 @@ <Reference Include="System" /> <Reference Include="Microsoft.CSharp" /> </ItemGroup> - + <ItemGroup> <Compile Include="..\Grpc.Core\Version.cs" /> </ItemGroup> diff --git a/src/csharp/Grpc.IntegrationTesting.StressClient/Grpc.IntegrationTesting.StressClient.csproj b/src/csharp/Grpc.IntegrationTesting.StressClient/Grpc.IntegrationTesting.StressClient.csproj index 22272547f6..99926497e4 100755 --- a/src/csharp/Grpc.IntegrationTesting.StressClient/Grpc.IntegrationTesting.StressClient.csproj +++ b/src/csharp/Grpc.IntegrationTesting.StressClient/Grpc.IntegrationTesting.StressClient.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.IntegrationTesting.StressClient</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.IntegrationTesting.StressClient</PackageId> diff --git a/src/csharp/Grpc.IntegrationTesting/BenchmarkServiceGrpc.cs b/src/csharp/Grpc.IntegrationTesting/BenchmarkServiceGrpc.cs index b5738593f2..3431b5fa18 100644 --- a/src/csharp/Grpc.IntegrationTesting/BenchmarkServiceGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/BenchmarkServiceGrpc.cs @@ -324,6 +324,19 @@ namespace Grpc.Testing { .AddMethod(__Method_StreamingBothWays, serviceImpl.StreamingBothWays).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, BenchmarkServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_UnaryCall, serviceImpl.UnaryCall); + serviceBinder.AddMethod(__Method_StreamingCall, serviceImpl.StreamingCall); + serviceBinder.AddMethod(__Method_StreamingFromClient, serviceImpl.StreamingFromClient); + serviceBinder.AddMethod(__Method_StreamingFromServer, serviceImpl.StreamingFromServer); + serviceBinder.AddMethod(__Method_StreamingBothWays, serviceImpl.StreamingBothWays); + } + } } #endregion diff --git a/src/csharp/Grpc.IntegrationTesting/Control.cs b/src/csharp/Grpc.IntegrationTesting/Control.cs index 6e00348451..368b86659a 100644 --- a/src/csharp/Grpc.IntegrationTesting/Control.cs +++ b/src/csharp/Grpc.IntegrationTesting/Control.cs @@ -34,7 +34,7 @@ namespace Grpc.Testing { "U2VjdXJpdHlQYXJhbXMSEwoLdXNlX3Rlc3RfY2EYASABKAgSHAoUc2VydmVy", "X2hvc3Rfb3ZlcnJpZGUYAiABKAkSEQoJY3JlZF90eXBlGAMgASgJIk0KCkNo", "YW5uZWxBcmcSDAoEbmFtZRgBIAEoCRITCglzdHJfdmFsdWUYAiABKAlIABIT", - "CglpbnRfdmFsdWUYAyABKAVIAEIHCgV2YWx1ZSLvBAoMQ2xpZW50Q29uZmln", + "CglpbnRfdmFsdWUYAyABKAVIAEIHCgV2YWx1ZSKiBQoMQ2xpZW50Q29uZmln", "EhYKDnNlcnZlcl90YXJnZXRzGAEgAygJEi0KC2NsaWVudF90eXBlGAIgASgO", "MhguZ3JwYy50ZXN0aW5nLkNsaWVudFR5cGUSNQoPc2VjdXJpdHlfcGFyYW1z", "GAMgASgLMhwuZ3JwYy50ZXN0aW5nLlNlY3VyaXR5UGFyYW1zEiQKHG91dHN0", @@ -48,59 +48,60 @@ namespace Grpc.Testing { "GA4gASgFEhgKEG90aGVyX2NsaWVudF9hcGkYDyABKAkSLgoMY2hhbm5lbF9h", "cmdzGBAgAygLMhguZ3JwYy50ZXN0aW5nLkNoYW5uZWxBcmcSFgoOdGhyZWFk", "c19wZXJfY3EYESABKAUSGwoTbWVzc2FnZXNfcGVyX3N0cmVhbRgSIAEoBRIY", - "ChB1c2VfY29hbGVzY2VfYXBpGBMgASgIIjgKDENsaWVudFN0YXR1cxIoCgVz", - "dGF0cxgBIAEoCzIZLmdycGMudGVzdGluZy5DbGllbnRTdGF0cyIVCgRNYXJr", - "Eg0KBXJlc2V0GAEgASgIImgKCkNsaWVudEFyZ3MSKwoFc2V0dXAYASABKAsy", - "Gi5ncnBjLnRlc3RpbmcuQ2xpZW50Q29uZmlnSAASIgoEbWFyaxgCIAEoCzIS", - "LmdycGMudGVzdGluZy5NYXJrSABCCQoHYXJndHlwZSL9AgoMU2VydmVyQ29u", - "ZmlnEi0KC3NlcnZlcl90eXBlGAEgASgOMhguZ3JwYy50ZXN0aW5nLlNlcnZl", - "clR5cGUSNQoPc2VjdXJpdHlfcGFyYW1zGAIgASgLMhwuZ3JwYy50ZXN0aW5n", - "LlNlY3VyaXR5UGFyYW1zEgwKBHBvcnQYBCABKAUSHAoUYXN5bmNfc2VydmVy", - "X3RocmVhZHMYByABKAUSEgoKY29yZV9saW1pdBgIIAEoBRIzCg5wYXlsb2Fk", - "X2NvbmZpZxgJIAEoCzIbLmdycGMudGVzdGluZy5QYXlsb2FkQ29uZmlnEhEK", - "CWNvcmVfbGlzdBgKIAMoBRIYChBvdGhlcl9zZXJ2ZXJfYXBpGAsgASgJEhYK", - "DnRocmVhZHNfcGVyX2NxGAwgASgFEhwKE3Jlc291cmNlX3F1b3RhX3NpemUY", - "6QcgASgFEi8KDGNoYW5uZWxfYXJncxjqByADKAsyGC5ncnBjLnRlc3Rpbmcu", - "Q2hhbm5lbEFyZyJoCgpTZXJ2ZXJBcmdzEisKBXNldHVwGAEgASgLMhouZ3Jw", - "Yy50ZXN0aW5nLlNlcnZlckNvbmZpZ0gAEiIKBG1hcmsYAiABKAsyEi5ncnBj", - "LnRlc3RpbmcuTWFya0gAQgkKB2FyZ3R5cGUiVQoMU2VydmVyU3RhdHVzEigK", - "BXN0YXRzGAEgASgLMhkuZ3JwYy50ZXN0aW5nLlNlcnZlclN0YXRzEgwKBHBv", - "cnQYAiABKAUSDQoFY29yZXMYAyABKAUiDQoLQ29yZVJlcXVlc3QiHQoMQ29y", - "ZVJlc3BvbnNlEg0KBWNvcmVzGAEgASgFIgYKBFZvaWQi/QEKCFNjZW5hcmlv", - "EgwKBG5hbWUYASABKAkSMQoNY2xpZW50X2NvbmZpZxgCIAEoCzIaLmdycGMu", - "dGVzdGluZy5DbGllbnRDb25maWcSEwoLbnVtX2NsaWVudHMYAyABKAUSMQoN", - "c2VydmVyX2NvbmZpZxgEIAEoCzIaLmdycGMudGVzdGluZy5TZXJ2ZXJDb25m", - "aWcSEwoLbnVtX3NlcnZlcnMYBSABKAUSFgoOd2FybXVwX3NlY29uZHMYBiAB", - "KAUSGQoRYmVuY2htYXJrX3NlY29uZHMYByABKAUSIAoYc3Bhd25fbG9jYWxf", - "d29ya2VyX2NvdW50GAggASgFIjYKCVNjZW5hcmlvcxIpCglzY2VuYXJpb3MY", - "ASADKAsyFi5ncnBjLnRlc3RpbmcuU2NlbmFyaW8ihAQKFVNjZW5hcmlvUmVz", - "dWx0U3VtbWFyeRILCgNxcHMYASABKAESGwoTcXBzX3Blcl9zZXJ2ZXJfY29y", - "ZRgCIAEoARIaChJzZXJ2ZXJfc3lzdGVtX3RpbWUYAyABKAESGAoQc2VydmVy", - "X3VzZXJfdGltZRgEIAEoARIaChJjbGllbnRfc3lzdGVtX3RpbWUYBSABKAES", - "GAoQY2xpZW50X3VzZXJfdGltZRgGIAEoARISCgpsYXRlbmN5XzUwGAcgASgB", - "EhIKCmxhdGVuY3lfOTAYCCABKAESEgoKbGF0ZW5jeV85NRgJIAEoARISCgps", - "YXRlbmN5Xzk5GAogASgBEhMKC2xhdGVuY3lfOTk5GAsgASgBEhgKEHNlcnZl", - "cl9jcHVfdXNhZ2UYDCABKAESJgoec3VjY2Vzc2Z1bF9yZXF1ZXN0c19wZXJf", - "c2Vjb25kGA0gASgBEiIKGmZhaWxlZF9yZXF1ZXN0c19wZXJfc2Vjb25kGA4g", - "ASgBEiAKGGNsaWVudF9wb2xsc19wZXJfcmVxdWVzdBgPIAEoARIgChhzZXJ2", - "ZXJfcG9sbHNfcGVyX3JlcXVlc3QYECABKAESIgoac2VydmVyX3F1ZXJpZXNf", - "cGVyX2NwdV9zZWMYESABKAESIgoaY2xpZW50X3F1ZXJpZXNfcGVyX2NwdV9z", - "ZWMYEiABKAEigwMKDlNjZW5hcmlvUmVzdWx0EigKCHNjZW5hcmlvGAEgASgL", - "MhYuZ3JwYy50ZXN0aW5nLlNjZW5hcmlvEi4KCWxhdGVuY2llcxgCIAEoCzIb", - "LmdycGMudGVzdGluZy5IaXN0b2dyYW1EYXRhEi8KDGNsaWVudF9zdGF0cxgD", - "IAMoCzIZLmdycGMudGVzdGluZy5DbGllbnRTdGF0cxIvCgxzZXJ2ZXJfc3Rh", - "dHMYBCADKAsyGS5ncnBjLnRlc3RpbmcuU2VydmVyU3RhdHMSFAoMc2VydmVy", - "X2NvcmVzGAUgAygFEjQKB3N1bW1hcnkYBiABKAsyIy5ncnBjLnRlc3Rpbmcu", - "U2NlbmFyaW9SZXN1bHRTdW1tYXJ5EhYKDmNsaWVudF9zdWNjZXNzGAcgAygI", - "EhYKDnNlcnZlcl9zdWNjZXNzGAggAygIEjkKD3JlcXVlc3RfcmVzdWx0cxgJ", - "IAMoCzIgLmdycGMudGVzdGluZy5SZXF1ZXN0UmVzdWx0Q291bnQqQQoKQ2xp", - "ZW50VHlwZRIPCgtTWU5DX0NMSUVOVBAAEhAKDEFTWU5DX0NMSUVOVBABEhAK", - "DE9USEVSX0NMSUVOVBACKlsKClNlcnZlclR5cGUSDwoLU1lOQ19TRVJWRVIQ", - "ABIQCgxBU1lOQ19TRVJWRVIQARIYChRBU1lOQ19HRU5FUklDX1NFUlZFUhAC", - "EhAKDE9USEVSX1NFUlZFUhADKnIKB1JwY1R5cGUSCQoFVU5BUlkQABINCglT", - "VFJFQU1JTkcQARIZChVTVFJFQU1JTkdfRlJPTV9DTElFTlQQAhIZChVTVFJF", - "QU1JTkdfRlJPTV9TRVJWRVIQAxIXChNTVFJFQU1JTkdfQk9USF9XQVlTEARi", - "BnByb3RvMw==")); + "ChB1c2VfY29hbGVzY2VfYXBpGBMgASgIEjEKKW1lZGlhbl9sYXRlbmN5X2Nv", + "bGxlY3Rpb25faW50ZXJ2YWxfbWlsbGlzGBQgASgFIjgKDENsaWVudFN0YXR1", + "cxIoCgVzdGF0cxgBIAEoCzIZLmdycGMudGVzdGluZy5DbGllbnRTdGF0cyIV", + "CgRNYXJrEg0KBXJlc2V0GAEgASgIImgKCkNsaWVudEFyZ3MSKwoFc2V0dXAY", + "ASABKAsyGi5ncnBjLnRlc3RpbmcuQ2xpZW50Q29uZmlnSAASIgoEbWFyaxgC", + "IAEoCzISLmdycGMudGVzdGluZy5NYXJrSABCCQoHYXJndHlwZSL9AgoMU2Vy", + "dmVyQ29uZmlnEi0KC3NlcnZlcl90eXBlGAEgASgOMhguZ3JwYy50ZXN0aW5n", + "LlNlcnZlclR5cGUSNQoPc2VjdXJpdHlfcGFyYW1zGAIgASgLMhwuZ3JwYy50", + "ZXN0aW5nLlNlY3VyaXR5UGFyYW1zEgwKBHBvcnQYBCABKAUSHAoUYXN5bmNf", + "c2VydmVyX3RocmVhZHMYByABKAUSEgoKY29yZV9saW1pdBgIIAEoBRIzCg5w", + "YXlsb2FkX2NvbmZpZxgJIAEoCzIbLmdycGMudGVzdGluZy5QYXlsb2FkQ29u", + "ZmlnEhEKCWNvcmVfbGlzdBgKIAMoBRIYChBvdGhlcl9zZXJ2ZXJfYXBpGAsg", + "ASgJEhYKDnRocmVhZHNfcGVyX2NxGAwgASgFEhwKE3Jlc291cmNlX3F1b3Rh", + "X3NpemUY6QcgASgFEi8KDGNoYW5uZWxfYXJncxjqByADKAsyGC5ncnBjLnRl", + "c3RpbmcuQ2hhbm5lbEFyZyJoCgpTZXJ2ZXJBcmdzEisKBXNldHVwGAEgASgL", + "MhouZ3JwYy50ZXN0aW5nLlNlcnZlckNvbmZpZ0gAEiIKBG1hcmsYAiABKAsy", + "Ei5ncnBjLnRlc3RpbmcuTWFya0gAQgkKB2FyZ3R5cGUiVQoMU2VydmVyU3Rh", + "dHVzEigKBXN0YXRzGAEgASgLMhkuZ3JwYy50ZXN0aW5nLlNlcnZlclN0YXRz", + "EgwKBHBvcnQYAiABKAUSDQoFY29yZXMYAyABKAUiDQoLQ29yZVJlcXVlc3Qi", + "HQoMQ29yZVJlc3BvbnNlEg0KBWNvcmVzGAEgASgFIgYKBFZvaWQi/QEKCFNj", + "ZW5hcmlvEgwKBG5hbWUYASABKAkSMQoNY2xpZW50X2NvbmZpZxgCIAEoCzIa", + "LmdycGMudGVzdGluZy5DbGllbnRDb25maWcSEwoLbnVtX2NsaWVudHMYAyAB", + "KAUSMQoNc2VydmVyX2NvbmZpZxgEIAEoCzIaLmdycGMudGVzdGluZy5TZXJ2", + "ZXJDb25maWcSEwoLbnVtX3NlcnZlcnMYBSABKAUSFgoOd2FybXVwX3NlY29u", + "ZHMYBiABKAUSGQoRYmVuY2htYXJrX3NlY29uZHMYByABKAUSIAoYc3Bhd25f", + "bG9jYWxfd29ya2VyX2NvdW50GAggASgFIjYKCVNjZW5hcmlvcxIpCglzY2Vu", + "YXJpb3MYASADKAsyFi5ncnBjLnRlc3RpbmcuU2NlbmFyaW8ihAQKFVNjZW5h", + "cmlvUmVzdWx0U3VtbWFyeRILCgNxcHMYASABKAESGwoTcXBzX3Blcl9zZXJ2", + "ZXJfY29yZRgCIAEoARIaChJzZXJ2ZXJfc3lzdGVtX3RpbWUYAyABKAESGAoQ", + "c2VydmVyX3VzZXJfdGltZRgEIAEoARIaChJjbGllbnRfc3lzdGVtX3RpbWUY", + "BSABKAESGAoQY2xpZW50X3VzZXJfdGltZRgGIAEoARISCgpsYXRlbmN5XzUw", + "GAcgASgBEhIKCmxhdGVuY3lfOTAYCCABKAESEgoKbGF0ZW5jeV85NRgJIAEo", + "ARISCgpsYXRlbmN5Xzk5GAogASgBEhMKC2xhdGVuY3lfOTk5GAsgASgBEhgK", + "EHNlcnZlcl9jcHVfdXNhZ2UYDCABKAESJgoec3VjY2Vzc2Z1bF9yZXF1ZXN0", + "c19wZXJfc2Vjb25kGA0gASgBEiIKGmZhaWxlZF9yZXF1ZXN0c19wZXJfc2Vj", + "b25kGA4gASgBEiAKGGNsaWVudF9wb2xsc19wZXJfcmVxdWVzdBgPIAEoARIg", + "ChhzZXJ2ZXJfcG9sbHNfcGVyX3JlcXVlc3QYECABKAESIgoac2VydmVyX3F1", + "ZXJpZXNfcGVyX2NwdV9zZWMYESABKAESIgoaY2xpZW50X3F1ZXJpZXNfcGVy", + "X2NwdV9zZWMYEiABKAEigwMKDlNjZW5hcmlvUmVzdWx0EigKCHNjZW5hcmlv", + "GAEgASgLMhYuZ3JwYy50ZXN0aW5nLlNjZW5hcmlvEi4KCWxhdGVuY2llcxgC", + "IAEoCzIbLmdycGMudGVzdGluZy5IaXN0b2dyYW1EYXRhEi8KDGNsaWVudF9z", + "dGF0cxgDIAMoCzIZLmdycGMudGVzdGluZy5DbGllbnRTdGF0cxIvCgxzZXJ2", + "ZXJfc3RhdHMYBCADKAsyGS5ncnBjLnRlc3RpbmcuU2VydmVyU3RhdHMSFAoM", + "c2VydmVyX2NvcmVzGAUgAygFEjQKB3N1bW1hcnkYBiABKAsyIy5ncnBjLnRl", + "c3RpbmcuU2NlbmFyaW9SZXN1bHRTdW1tYXJ5EhYKDmNsaWVudF9zdWNjZXNz", + "GAcgAygIEhYKDnNlcnZlcl9zdWNjZXNzGAggAygIEjkKD3JlcXVlc3RfcmVz", + "dWx0cxgJIAMoCzIgLmdycGMudGVzdGluZy5SZXF1ZXN0UmVzdWx0Q291bnQq", + "VgoKQ2xpZW50VHlwZRIPCgtTWU5DX0NMSUVOVBAAEhAKDEFTWU5DX0NMSUVO", + "VBABEhAKDE9USEVSX0NMSUVOVBACEhMKD0NBTExCQUNLX0NMSUVOVBADKlsK", + "ClNlcnZlclR5cGUSDwoLU1lOQ19TRVJWRVIQABIQCgxBU1lOQ19TRVJWRVIQ", + "ARIYChRBU1lOQ19HRU5FUklDX1NFUlZFUhACEhAKDE9USEVSX1NFUlZFUhAD", + "KnIKB1JwY1R5cGUSCQoFVU5BUlkQABINCglTVFJFQU1JTkcQARIZChVTVFJF", + "QU1JTkdfRlJPTV9DTElFTlQQAhIZChVTVFJFQU1JTkdfRlJPTV9TRVJWRVIQ", + "AxIXChNTVFJFQU1JTkdfQk9USF9XQVlTEARiBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { global::Grpc.Testing.PayloadsReflection.Descriptor, global::Grpc.Testing.StatsReflection.Descriptor, }, new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Grpc.Testing.ClientType), typeof(global::Grpc.Testing.ServerType), typeof(global::Grpc.Testing.RpcType), }, new pbr::GeneratedClrTypeInfo[] { @@ -109,7 +110,7 @@ namespace Grpc.Testing { new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.LoadParams), global::Grpc.Testing.LoadParams.Parser, new[]{ "ClosedLoop", "Poisson" }, new[]{ "Load" }, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.SecurityParams), global::Grpc.Testing.SecurityParams.Parser, new[]{ "UseTestCa", "ServerHostOverride", "CredType" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.ChannelArg), global::Grpc.Testing.ChannelArg.Parser, new[]{ "Name", "StrValue", "IntValue" }, new[]{ "Value" }, null, null), - new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.ClientConfig), global::Grpc.Testing.ClientConfig.Parser, new[]{ "ServerTargets", "ClientType", "SecurityParams", "OutstandingRpcsPerChannel", "ClientChannels", "AsyncClientThreads", "RpcType", "LoadParams", "PayloadConfig", "HistogramParams", "CoreList", "CoreLimit", "OtherClientApi", "ChannelArgs", "ThreadsPerCq", "MessagesPerStream", "UseCoalesceApi" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.ClientConfig), global::Grpc.Testing.ClientConfig.Parser, new[]{ "ServerTargets", "ClientType", "SecurityParams", "OutstandingRpcsPerChannel", "ClientChannels", "AsyncClientThreads", "RpcType", "LoadParams", "PayloadConfig", "HistogramParams", "CoreList", "CoreLimit", "OtherClientApi", "ChannelArgs", "ThreadsPerCq", "MessagesPerStream", "UseCoalesceApi", "MedianLatencyCollectionIntervalMillis" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.ClientStatus), global::Grpc.Testing.ClientStatus.Parser, new[]{ "Stats" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.Mark), global::Grpc.Testing.Mark.Parser, new[]{ "Reset" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::Grpc.Testing.ClientArgs), global::Grpc.Testing.ClientArgs.Parser, new[]{ "Setup", "Mark" }, new[]{ "Argtype" }, null, null), @@ -140,6 +141,7 @@ namespace Grpc.Testing { /// used for some language-specific variants /// </summary> [pbr::OriginalName("OTHER_CLIENT")] OtherClient = 2, + [pbr::OriginalName("CALLBACK_CLIENT")] CallbackClient = 3, } public enum ServerType { @@ -1054,6 +1056,7 @@ namespace Grpc.Testing { threadsPerCq_ = other.threadsPerCq_; messagesPerStream_ = other.messagesPerStream_; useCoalesceApi_ = other.useCoalesceApi_; + medianLatencyCollectionIntervalMillis_ = other.medianLatencyCollectionIntervalMillis_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -1278,6 +1281,21 @@ namespace Grpc.Testing { } } + /// <summary>Field number for the "median_latency_collection_interval_millis" field.</summary> + public const int MedianLatencyCollectionIntervalMillisFieldNumber = 20; + private int medianLatencyCollectionIntervalMillis_; + /// <summary> + /// If 0, disabled. Else, specifies the period between gathering latency + /// medians in milliseconds. + /// </summary> + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int MedianLatencyCollectionIntervalMillis { + get { return medianLatencyCollectionIntervalMillis_; } + set { + medianLatencyCollectionIntervalMillis_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as ClientConfig); @@ -1308,6 +1326,7 @@ namespace Grpc.Testing { if (ThreadsPerCq != other.ThreadsPerCq) return false; if (MessagesPerStream != other.MessagesPerStream) return false; if (UseCoalesceApi != other.UseCoalesceApi) return false; + if (MedianLatencyCollectionIntervalMillis != other.MedianLatencyCollectionIntervalMillis) return false; return Equals(_unknownFields, other._unknownFields); } @@ -1331,6 +1350,7 @@ namespace Grpc.Testing { if (ThreadsPerCq != 0) hash ^= ThreadsPerCq.GetHashCode(); if (MessagesPerStream != 0) hash ^= MessagesPerStream.GetHashCode(); if (UseCoalesceApi != false) hash ^= UseCoalesceApi.GetHashCode(); + if (MedianLatencyCollectionIntervalMillis != 0) hash ^= MedianLatencyCollectionIntervalMillis.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -1403,6 +1423,10 @@ namespace Grpc.Testing { output.WriteRawTag(152, 1); output.WriteBool(UseCoalesceApi); } + if (MedianLatencyCollectionIntervalMillis != 0) { + output.WriteRawTag(160, 1); + output.WriteInt32(MedianLatencyCollectionIntervalMillis); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -1456,6 +1480,9 @@ namespace Grpc.Testing { if (UseCoalesceApi != false) { size += 2 + 1; } + if (MedianLatencyCollectionIntervalMillis != 0) { + size += 2 + pb::CodedOutputStream.ComputeInt32Size(MedianLatencyCollectionIntervalMillis); + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -1524,6 +1551,9 @@ namespace Grpc.Testing { if (other.UseCoalesceApi != false) { UseCoalesceApi = other.UseCoalesceApi; } + if (other.MedianLatencyCollectionIntervalMillis != 0) { + MedianLatencyCollectionIntervalMillis = other.MedianLatencyCollectionIntervalMillis; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -1616,6 +1646,10 @@ namespace Grpc.Testing { UseCoalesceApi = input.ReadBool(); break; } + case 160: { + MedianLatencyCollectionIntervalMillis = input.ReadInt32(); + break; + } } } } diff --git a/src/csharp/Grpc.IntegrationTesting/EmptyServiceGrpc.cs b/src/csharp/Grpc.IntegrationTesting/EmptyServiceGrpc.cs index 2d233fbdc0..7e77f8d114 100644 --- a/src/csharp/Grpc.IntegrationTesting/EmptyServiceGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/EmptyServiceGrpc.cs @@ -80,6 +80,14 @@ namespace Grpc.Testing { return grpc::ServerServiceDefinition.CreateBuilder().Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, EmptyServiceBase serviceImpl) + { + } + } } #endregion diff --git a/src/csharp/Grpc.IntegrationTesting/Grpc.IntegrationTesting.csproj b/src/csharp/Grpc.IntegrationTesting/Grpc.IntegrationTesting.csproj index 8daf3fa98b..c342f8a107 100755 --- a/src/csharp/Grpc.IntegrationTesting/Grpc.IntegrationTesting.csproj +++ b/src/csharp/Grpc.IntegrationTesting/Grpc.IntegrationTesting.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.IntegrationTesting</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.IntegrationTesting</PackageId> diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs index e83a8a7274..4750353082 100644 --- a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs +++ b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs @@ -45,7 +45,7 @@ namespace Grpc.IntegrationTesting [Option("server_host", Default = "localhost")] public string ServerHost { get; set; } - [Option("server_host_override", Default = TestCredentials.DefaultHostOverride)] + [Option("server_host_override")] public string ServerHostOverride { get; set; } [Option("server_port", Required = true)] diff --git a/src/csharp/Grpc.IntegrationTesting/MetricsGrpc.cs b/src/csharp/Grpc.IntegrationTesting/MetricsGrpc.cs index 9f16f41ac1..c66a9a9161 100644 --- a/src/csharp/Grpc.IntegrationTesting/MetricsGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/MetricsGrpc.cs @@ -193,6 +193,16 @@ namespace Grpc.Testing { .AddMethod(__Method_GetGauge, serviceImpl.GetGauge).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, MetricsServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_GetAllGauges, serviceImpl.GetAllGauges); + serviceBinder.AddMethod(__Method_GetGauge, serviceImpl.GetGauge); + } + } } #endregion diff --git a/src/csharp/Grpc.IntegrationTesting/NUnitMain.cs b/src/csharp/Grpc.IntegrationTesting/NUnitMain.cs index 9d24762e0a..4135186275 100644 --- a/src/csharp/Grpc.IntegrationTesting/NUnitMain.cs +++ b/src/csharp/Grpc.IntegrationTesting/NUnitMain.cs @@ -34,11 +34,7 @@ namespace Grpc.IntegrationTesting { // Make logger immune to NUnit capturing stdout and stderr to workaround https://github.com/nunit/nunit/issues/1406. GrpcEnvironment.SetLogger(new ConsoleLogger()); -#if NETCOREAPP1_0 - return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args, new ExtendedTextWrapper(Console.Out), Console.In); -#else - return new AutoRun().Execute(args); -#endif + return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); } } } diff --git a/src/csharp/Grpc.IntegrationTesting/ReportQpsScenarioServiceGrpc.cs b/src/csharp/Grpc.IntegrationTesting/ReportQpsScenarioServiceGrpc.cs index 1da0548cb4..954c172272 100644 --- a/src/csharp/Grpc.IntegrationTesting/ReportQpsScenarioServiceGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/ReportQpsScenarioServiceGrpc.cs @@ -143,6 +143,15 @@ namespace Grpc.Testing { .AddMethod(__Method_ReportScenario, serviceImpl.ReportScenario).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, ReportQpsScenarioServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_ReportScenario, serviceImpl.ReportScenario); + } + } } #endregion diff --git a/src/csharp/Grpc.IntegrationTesting/TestGrpc.cs b/src/csharp/Grpc.IntegrationTesting/TestGrpc.cs index 2176916b43..d125fd5627 100644 --- a/src/csharp/Grpc.IntegrationTesting/TestGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/TestGrpc.cs @@ -539,6 +539,22 @@ namespace Grpc.Testing { .AddMethod(__Method_UnimplementedCall, serviceImpl.UnimplementedCall).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, TestServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_EmptyCall, serviceImpl.EmptyCall); + serviceBinder.AddMethod(__Method_UnaryCall, serviceImpl.UnaryCall); + serviceBinder.AddMethod(__Method_CacheableUnaryCall, serviceImpl.CacheableUnaryCall); + serviceBinder.AddMethod(__Method_StreamingOutputCall, serviceImpl.StreamingOutputCall); + serviceBinder.AddMethod(__Method_StreamingInputCall, serviceImpl.StreamingInputCall); + serviceBinder.AddMethod(__Method_FullDuplexCall, serviceImpl.FullDuplexCall); + serviceBinder.AddMethod(__Method_HalfDuplexCall, serviceImpl.HalfDuplexCall); + serviceBinder.AddMethod(__Method_UnimplementedCall, serviceImpl.UnimplementedCall); + } + } /// <summary> /// A simple service NOT implemented at servers so clients can test for @@ -661,6 +677,15 @@ namespace Grpc.Testing { .AddMethod(__Method_UnimplementedCall, serviceImpl.UnimplementedCall).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, UnimplementedServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_UnimplementedCall, serviceImpl.UnimplementedCall); + } + } /// <summary> /// A service used to control reconnect server. @@ -779,6 +804,16 @@ namespace Grpc.Testing { .AddMethod(__Method_Stop, serviceImpl.Stop).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, ReconnectServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_Start, serviceImpl.Start); + serviceBinder.AddMethod(__Method_Stop, serviceImpl.Stop); + } + } } #endregion diff --git a/src/csharp/Grpc.IntegrationTesting/WorkerServiceGrpc.cs b/src/csharp/Grpc.IntegrationTesting/WorkerServiceGrpc.cs index b9e8f91231..5b22337d53 100644 --- a/src/csharp/Grpc.IntegrationTesting/WorkerServiceGrpc.cs +++ b/src/csharp/Grpc.IntegrationTesting/WorkerServiceGrpc.cs @@ -321,6 +321,18 @@ namespace Grpc.Testing { .AddMethod(__Method_QuitWorker, serviceImpl.QuitWorker).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, WorkerServiceBase serviceImpl) + { + serviceBinder.AddMethod(__Method_RunServer, serviceImpl.RunServer); + serviceBinder.AddMethod(__Method_RunClient, serviceImpl.RunClient); + serviceBinder.AddMethod(__Method_CoreCount, serviceImpl.CoreCount); + serviceBinder.AddMethod(__Method_QuitWorker, serviceImpl.QuitWorker); + } + } } #endregion diff --git a/src/csharp/Grpc.Microbenchmarks/Grpc.Microbenchmarks.csproj b/src/csharp/Grpc.Microbenchmarks/Grpc.Microbenchmarks.csproj index d39d46cf1b..5b1656080a 100644 --- a/src/csharp/Grpc.Microbenchmarks/Grpc.Microbenchmarks.csproj +++ b/src/csharp/Grpc.Microbenchmarks/Grpc.Microbenchmarks.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Microbenchmarks</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Microbenchmarks</PackageId> diff --git a/src/csharp/Grpc.Reflection.Tests/Grpc.Reflection.Tests.csproj b/src/csharp/Grpc.Reflection.Tests/Grpc.Reflection.Tests.csproj index 0c12f38f25..8b586c6ecb 100755 --- a/src/csharp/Grpc.Reflection.Tests/Grpc.Reflection.Tests.csproj +++ b/src/csharp/Grpc.Reflection.Tests/Grpc.Reflection.Tests.csproj @@ -4,7 +4,7 @@ <Import Project="..\Grpc.Core\Common.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <AssemblyName>Grpc.Reflection.Tests</AssemblyName> <OutputType>Exe</OutputType> <PackageId>Grpc.Reflection.Tests</PackageId> diff --git a/src/csharp/Grpc.Reflection.Tests/NUnitMain.cs b/src/csharp/Grpc.Reflection.Tests/NUnitMain.cs index 49ed1cc8d4..de4b4af6cf 100644 --- a/src/csharp/Grpc.Reflection.Tests/NUnitMain.cs +++ b/src/csharp/Grpc.Reflection.Tests/NUnitMain.cs @@ -34,11 +34,7 @@ namespace Grpc.Reflection.Tests { // Make logger immune to NUnit capturing stdout and stderr to workaround https://github.com/nunit/nunit/issues/1406. GrpcEnvironment.SetLogger(new ConsoleLogger()); -#if NETCOREAPP1_0 - return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args, new ExtendedTextWrapper(Console.Out), Console.In); -#else - return new AutoRun().Execute(args); -#endif + return new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); } } } diff --git a/src/csharp/Grpc.Reflection/ReflectionGrpc.cs b/src/csharp/Grpc.Reflection/ReflectionGrpc.cs index c00075b7c6..ed55c2f584 100644 --- a/src/csharp/Grpc.Reflection/ReflectionGrpc.cs +++ b/src/csharp/Grpc.Reflection/ReflectionGrpc.cs @@ -123,6 +123,15 @@ namespace Grpc.Reflection.V1Alpha { .AddMethod(__Method_ServerReflectionInfo, serviceImpl.ServerReflectionInfo).Build(); } + /// <summary>Register service method implementations with a service binder. Useful when customizing the service binding logic. + /// Note: this method is part of an experimental API that can change or be removed without any prior notice.</summary> + /// <param name="serviceBinder">Service methods will be bound by calling <c>AddMethod</c> on this object.</param> + /// <param name="serviceImpl">An object implementing the server-side handling logic.</param> + public static void BindService(grpc::ServiceBinderBase serviceBinder, ServerReflectionBase serviceImpl) + { + serviceBinder.AddMethod(__Method_ServerReflectionInfo, serviceImpl.ServerReflectionInfo); + } + } } #endregion diff --git a/src/csharp/Grpc.Tools.Tests/Grpc.Tools.Tests.csproj b/src/csharp/Grpc.Tools.Tests/Grpc.Tools.Tests.csproj index a2d4874eec..cfb40f44ae 100644 --- a/src/csharp/Grpc.Tools.Tests/Grpc.Tools.Tests.csproj +++ b/src/csharp/Grpc.Tools.Tests/Grpc.Tools.Tests.csproj @@ -3,7 +3,7 @@ <Import Project="..\Grpc.Core\Version.csproj.include" /> <PropertyGroup> - <TargetFrameworks>net45;netcoreapp1.0</TargetFrameworks> + <TargetFrameworks>net45;netcoreapp1.1</TargetFrameworks> <OutputType>Exe</OutputType> </PropertyGroup> diff --git a/src/csharp/Grpc.Tools.Tests/NUnitMain.cs b/src/csharp/Grpc.Tools.Tests/NUnitMain.cs index 418c33820e..d30d608aa3 100644 --- a/src/csharp/Grpc.Tools.Tests/NUnitMain.cs +++ b/src/csharp/Grpc.Tools.Tests/NUnitMain.cs @@ -24,10 +24,6 @@ namespace Grpc.Tools.Tests static class NUnitMain { public static int Main(string[] args) => -#if NETCOREAPP1_0 || NETCOREAPP1_1 new AutoRun(typeof(NUnitMain).GetTypeInfo().Assembly).Execute(args); -#else - new AutoRun().Execute(args); -#endif }; } diff --git a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets index 5f76c03ce5..3fe1ccc918 100644 --- a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets @@ -22,9 +22,8 @@ <Target Name="gRPC_ResolvePluginFullPath" AfterTargets="Protobuf_ResolvePlatform"> <PropertyGroup> <!-- TODO(kkm): Do not use Protobuf_PackagedToolsPath, roll gRPC's own. --> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath> <gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/$(gRPC_PluginFileName)</gRPC_PluginFullPath> </PropertyGroup> diff --git a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets index 1d233d23a8..26f9efb5a8 100644 --- a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets +++ b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets @@ -74,9 +74,8 @@ <!-- Next try OS and CPU resolved by ProtoToolsPlatform. --> <Protobuf_ToolsOs Condition=" '$(Protobuf_ToolsOs)' == '' ">$(_Protobuf_ToolsOs)</Protobuf_ToolsOs> <Protobuf_ToolsCpu Condition=" '$(Protobuf_ToolsCpu)' == '' ">$(_Protobuf_ToolsCpu)</Protobuf_ToolsCpu> - <!-- TODO(kkm): Do not package windows x64 builds (#13098). --> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' " - >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\protoc.exe</Protobuf_ProtocFullPath> + >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\protoc.exe</Protobuf_ProtocFullPath> <Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' " >$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/protoc</Protobuf_ProtocFullPath> </PropertyGroup> diff --git a/src/csharp/build_packages_dotnetcli.bat b/src/csharp/build_packages_dotnetcli.bat index 27688360e9..76d4f14390 100755 --- a/src/csharp/build_packages_dotnetcli.bat +++ b/src/csharp/build_packages_dotnetcli.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.17.0-dev +set VERSION=1.18.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/csharp/build_unitypackage.bat b/src/csharp/build_unitypackage.bat index dd74de0491..3334d24c11 100644 --- a/src/csharp/build_unitypackage.bat +++ b/src/csharp/build_unitypackage.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.17.0-dev +set VERSION=1.18.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec index a95a120d21..55ca6048bc 100644 --- a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec +++ b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec @@ -42,7 +42,7 @@ Pod::Spec.new do |s| # exclamation mark ensures that other "regular" pods will be able to find it as it'll be installed # before them. s.name = '!ProtoCompiler-gRPCPlugin' - v = '1.17.0-dev' + v = '1.18.0-dev' s.version = v s.summary = 'The gRPC ProtoC plugin generates Objective-C files from .proto services.' s.description = <<-DESC diff --git a/src/objective-c/GRPCClient/GRPCCall.h b/src/objective-c/GRPCClient/GRPCCall.h index e0ef8b1391..ddc6ae054d 100644 --- a/src/objective-c/GRPCClient/GRPCCall.h +++ b/src/objective-c/GRPCClient/GRPCCall.h @@ -253,7 +253,8 @@ extern id const kGRPCTrailersKey; + (void)setCallSafety:(GRPCCallSafety)callSafety host:(NSString *)host path:(NSString *)path; /** - * Set the dispatch queue to be used for callbacks. + * Set the dispatch queue to be used for callbacks. Current implementation requires \a queue to be a + * serial queue. * * This configuration is only effective before the call starts. */ diff --git a/src/objective-c/GRPCClient/private/version.h b/src/objective-c/GRPCClient/private/version.h index d5463c0b4c..0be0e3c9a0 100644 --- a/src/objective-c/GRPCClient/private/version.h +++ b/src/objective-c/GRPCClient/private/version.h @@ -22,4 +22,4 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.17.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" diff --git a/src/objective-c/README.md b/src/objective-c/README.md index 32e3956a1e..83775f86e1 100644 --- a/src/objective-c/README.md +++ b/src/objective-c/README.md @@ -242,3 +242,12 @@ pod `gRPC-Core`, :podspec => "." # assuming gRPC-Core.podspec is in the same dir These steps should allow gRPC to use OpenSSL and drop BoringSSL dependency. If you see any issue, file an issue to us. + +## Upgrade issue with BoringSSL +If you were using an old version of gRPC (<= v1.14) which depended on pod `BoringSSL` rather than +`BoringSSL-GRPC` and meet issue with the library like: +``` +ld: framework not found openssl +``` +updating `-framework openssl` in Other Linker Flags to `-framework openssl_grpc` in your project +may resolve this issue (see [#16821](https://github.com/grpc/grpc/issues/16821)). diff --git a/src/objective-c/tests/version.h b/src/objective-c/tests/version.h index ca27c03b3c..f2fd692070 100644 --- a/src/objective-c/tests/version.h +++ b/src/objective-c/tests/version.h @@ -22,5 +22,5 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.17.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" #define GRPC_C_VERSION_STRING @"7.0.0-dev" diff --git a/src/php/composer.json b/src/php/composer.json index d54db91b5f..9c298c0e85 100644 --- a/src/php/composer.json +++ b/src/php/composer.json @@ -2,7 +2,7 @@ "name": "grpc/grpc-dev", "description": "gRPC library for PHP - for Developement use only", "license": "Apache-2.0", - "version": "1.17.0", + "version": "1.18.0", "require": { "php": ">=5.5.0", "google/protobuf": "^v3.3.0" diff --git a/src/php/ext/grpc/channel.c b/src/php/ext/grpc/channel.c index b17f3d9a61..c06bdea7fe 100644 --- a/src/php/ext/grpc/channel.c +++ b/src/php/ext/grpc/channel.c @@ -393,6 +393,8 @@ PHP_METHOD(Channel, __construct) { channel->wrapper->target = strdup(target); channel->wrapper->args_hashstr = strdup(sha1str); channel->wrapper->creds_hashstr = NULL; + channel->wrapper->creds = creds; + channel->wrapper->args = args; if (creds != NULL && creds->hashstr != NULL) { php_grpc_int creds_hashstr_len = strlen(creds->hashstr); char *channel_creds_hashstr = malloc(creds_hashstr_len + 1); diff --git a/src/php/ext/grpc/channel.h b/src/php/ext/grpc/channel.h index 27752c9a3f..ce17c4a58a 100644 --- a/src/php/ext/grpc/channel.h +++ b/src/php/ext/grpc/channel.h @@ -19,6 +19,7 @@ #ifndef NET_GRPC_PHP_GRPC_CHANNEL_H_ #define NET_GRPC_PHP_GRPC_CHANNEL_H_ +#include "channel_credentials.h" #include "php_grpc.h" /* Class entry for the PHP Channel class */ @@ -32,6 +33,8 @@ typedef struct _grpc_channel_wrapper { char *creds_hashstr; size_t ref_count; gpr_mu mu; + grpc_channel_args args; + wrapped_grpc_channel_credentials *creds; } grpc_channel_wrapper; /* Wrapper struct for grpc_channel that can be associated with a PHP object */ diff --git a/src/php/ext/grpc/config.m4 b/src/php/ext/grpc/config.m4 index fa54ebd920..9ec2c7cf04 100755 --- a/src/php/ext/grpc/config.m4 +++ b/src/php/ext/grpc/config.m4 @@ -103,7 +103,7 @@ if test "$PHP_COVERAGE" = "yes"; then AC_MSG_ERROR([ccache must be disabled when --enable-coverage option is used. You can disable ccache by setting environment variable CCACHE_DISABLE=1.]) fi - lcov_version_list="1.5 1.6 1.7 1.9 1.10 1.11" + lcov_version_list="1.5 1.6 1.7 1.9 1.10 1.11 1.12 1.13" AC_CHECK_PROG(LCOV, lcov, lcov) AC_CHECK_PROG(GENHTML, genhtml, genhtml) diff --git a/src/php/ext/grpc/php_grpc.c b/src/php/ext/grpc/php_grpc.c index fabd98975d..111c6f4867 100644 --- a/src/php/ext/grpc/php_grpc.c +++ b/src/php/ext/grpc/php_grpc.c @@ -26,6 +26,8 @@ #include "call_credentials.h" #include "server_credentials.h" #include "completion_queue.h" +#include <ext/spl/spl_exceptions.h> +#include <zend_exceptions.h> ZEND_DECLARE_MODULE_GLOBALS(grpc) static PHP_GINIT_FUNCTION(grpc); @@ -86,6 +88,125 @@ ZEND_GET_MODULE(grpc) } */ /* }}} */ +void create_new_channel( + wrapped_grpc_channel *channel, + char *target, + grpc_channel_args args, + wrapped_grpc_channel_credentials *creds) { + if (creds == NULL) { + channel->wrapper->wrapped = grpc_insecure_channel_create(target, &args, + NULL); + } else { + channel->wrapper->wrapped = + grpc_secure_channel_create(creds->wrapped, target, &args, NULL); + } +} + +void acquire_persistent_locks() { + zval *data; + PHP_GRPC_HASH_FOREACH_VAL_START(&grpc_persistent_list, data) + php_grpc_zend_resource *rsrc = + (php_grpc_zend_resource*) PHP_GRPC_HASH_VALPTR_TO_VAL(data) + if (rsrc == NULL) { + break; + } + channel_persistent_le_t* le = rsrc->ptr; + + gpr_mu_lock(&le->channel->mu); + PHP_GRPC_HASH_FOREACH_END() +} + +void release_persistent_locks() { + zval *data; + PHP_GRPC_HASH_FOREACH_VAL_START(&grpc_persistent_list, data) + php_grpc_zend_resource *rsrc = + (php_grpc_zend_resource*) PHP_GRPC_HASH_VALPTR_TO_VAL(data) + if (rsrc == NULL) { + break; + } + channel_persistent_le_t* le = rsrc->ptr; + + gpr_mu_unlock(&le->channel->mu); + PHP_GRPC_HASH_FOREACH_END() +} + +void destroy_grpc_channels() { + zval *data; + PHP_GRPC_HASH_FOREACH_VAL_START(&grpc_persistent_list, data) + php_grpc_zend_resource *rsrc = + (php_grpc_zend_resource*) PHP_GRPC_HASH_VALPTR_TO_VAL(data) + if (rsrc == NULL) { + break; + } + channel_persistent_le_t* le = rsrc->ptr; + + wrapped_grpc_channel wrapped_channel; + wrapped_channel.wrapper = le->channel; + grpc_channel_wrapper *channel = wrapped_channel.wrapper; + grpc_channel_destroy(channel->wrapped); + PHP_GRPC_HASH_FOREACH_END() +} + +void restart_channels() { + zval *data; + PHP_GRPC_HASH_FOREACH_VAL_START(&grpc_persistent_list, data) + php_grpc_zend_resource *rsrc = + (php_grpc_zend_resource*) PHP_GRPC_HASH_VALPTR_TO_VAL(data) + if (rsrc == NULL) { + break; + } + channel_persistent_le_t* le = rsrc->ptr; + + wrapped_grpc_channel wrapped_channel; + wrapped_channel.wrapper = le->channel; + grpc_channel_wrapper *channel = wrapped_channel.wrapper; + create_new_channel(&wrapped_channel, channel->target, channel->args, + channel->creds); + gpr_mu_unlock(&channel->mu); + PHP_GRPC_HASH_FOREACH_END() +} + +void prefork() { + acquire_persistent_locks(); +} + +void postfork_child() { + TSRMLS_FETCH(); + + // loop through persistant list and destroy all underlying grpc_channel objs + destroy_grpc_channels(); + + // clear completion queue + grpc_php_shutdown_completion_queue(TSRMLS_C); + + // clean-up grpc_core + grpc_shutdown(); + if (grpc_is_initialized() > 0) { + zend_throw_exception(spl_ce_UnexpectedValueException, + "Oops, failed to shutdown gRPC Core after fork()", + 1 TSRMLS_CC); + } + + // restart grpc_core + grpc_init(); + grpc_php_init_completion_queue(TSRMLS_C); + + // re-create grpc_channel and point wrapped to it + // unlock wrapped grpc channel mutex + restart_channels(); +} + +void postfork_parent() { + release_persistent_locks(); +} + +void register_fork_handlers() { + if (getenv("GRPC_ENABLE_FORK_SUPPORT")) { +#ifdef GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK + pthread_atfork(&prefork, &postfork_parent, &postfork_child); +#endif // GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK + } +} /* {{{ PHP_MINIT_FUNCTION */ @@ -265,6 +386,7 @@ PHP_MINFO_FUNCTION(grpc) { PHP_RINIT_FUNCTION(grpc) { if (!GRPC_G(initialized)) { grpc_init(); + register_fork_handlers(); grpc_php_init_completion_queue(TSRMLS_C); GRPC_G(initialized) = 1; } diff --git a/src/php/ext/grpc/version.h b/src/php/ext/grpc/version.h index 70f8bbbf40..1ddf90a667 100644 --- a/src/php/ext/grpc/version.h +++ b/src/php/ext/grpc/version.h @@ -20,6 +20,6 @@ #ifndef VERSION_H #define VERSION_H -#define PHP_GRPC_VERSION "1.17.0dev" +#define PHP_GRPC_VERSION "1.18.0dev" #endif /* VERSION_H */ diff --git a/src/php/tests/interop/interop_client.php b/src/php/tests/interop/interop_client.php index c865678f70..19cbf21bc2 100755 --- a/src/php/tests/interop/interop_client.php +++ b/src/php/tests/interop/interop_client.php @@ -530,7 +530,7 @@ function _makeStub($args) throw new Exception('Missing argument: --test_case is required'); } - if ($args['server_port'] === 443) { + if ($args['server_port'] === '443') { $server_address = $args['server_host']; } else { $server_address = $args['server_host'].':'.$args['server_port']; @@ -538,7 +538,7 @@ function _makeStub($args) $test_case = $args['test_case']; - $host_override = 'foo.test.google.fr'; + $host_override = ''; if (array_key_exists('server_host_override', $args)) { $host_override = $args['server_host_override']; } @@ -565,7 +565,9 @@ function _makeStub($args) $ssl_credentials = Grpc\ChannelCredentials::createSsl(); } $opts['credentials'] = $ssl_credentials; - $opts['grpc.ssl_target_name_override'] = $host_override; + if (!empty($host_override)) { + $opts['grpc.ssl_target_name_override'] = $host_override; + } } else { $opts['credentials'] = Grpc\ChannelCredentials::createInsecure(); } diff --git a/src/proto/grpc/channelz/BUILD b/src/proto/grpc/channelz/BUILD index bdb03d5e2d..b6b485e3e8 100644 --- a/src/proto/grpc/channelz/BUILD +++ b/src/proto/grpc/channelz/BUILD @@ -24,3 +24,10 @@ grpc_proto_library( has_services = True, well_known_protos = True, ) + +filegroup( + name = "channelz_proto_file", + srcs = [ + "channelz.proto", + ], +) diff --git a/src/proto/grpc/channelz/channelz.proto b/src/proto/grpc/channelz/channelz.proto index 6202e5e817..20de23f7fa 100644 --- a/src/proto/grpc/channelz/channelz.proto +++ b/src/proto/grpc/channelz/channelz.proto @@ -177,6 +177,7 @@ message SubchannelRef { // SocketRef is a reference to a Socket. message SocketRef { + // The globally unique id for this socket. Must be a positive number. int64 socket_id = 3; // An optional name associated with the socket. string name = 4; @@ -288,7 +289,8 @@ message SocketData { // include stream level or TCP level flow control info. google.protobuf.Int64Value remote_flow_control_window = 12; - // Socket options set on this socket. May be absent. + // Socket options set on this socket. May be absent if 'summary' is set + // on GetSocketRequest. repeated SocketOption option = 13; } @@ -439,12 +441,21 @@ service Channelz { message GetTopChannelsRequest { // start_channel_id indicates that only channels at or above this id should be // included in the results. + // To request the first page, this should be set to 0. To request + // subsequent pages, the client generates this value by adding 1 to + // the highest seen result ID. int64 start_channel_id = 1; + + // If non-zero, the server will return a page of results containing + // at most this many items. If zero, the server will choose a + // reasonable page size. Must never be negative. + int64 max_results = 2; } message GetTopChannelsResponse { // list of channels that the connection detail service knows about. Sorted in // ascending channel_id order. + // Must contain at least 1 result, otherwise 'end' must be true. repeated Channel channel = 1; // If set, indicates that the list of channels is the final list. Requesting // more channels can only return more if they are created after this RPC @@ -455,12 +466,21 @@ message GetTopChannelsResponse { message GetServersRequest { // start_server_id indicates that only servers at or above this id should be // included in the results. + // To request the first page, this must be set to 0. To request + // subsequent pages, the client generates this value by adding 1 to + // the highest seen result ID. int64 start_server_id = 1; + + // If non-zero, the server will return a page of results containing + // at most this many items. If zero, the server will choose a + // reasonable page size. Must never be negative. + int64 max_results = 2; } message GetServersResponse { // list of servers that the connection detail service knows about. Sorted in // ascending server_id order. + // Must contain at least 1 result, otherwise 'end' must be true. repeated Server server = 1; // If set, indicates that the list of servers is the final list. Requesting // more servers will only return more if they are created after this RPC @@ -483,12 +503,21 @@ message GetServerSocketsRequest { int64 server_id = 1; // start_socket_id indicates that only sockets at or above this id should be // included in the results. + // To request the first page, this must be set to 0. To request + // subsequent pages, the client generates this value by adding 1 to + // the highest seen result ID. int64 start_socket_id = 2; + + // If non-zero, the server will return a page of results containing + // at most this many items. If zero, the server will choose a + // reasonable page size. Must never be negative. + int64 max_results = 3; } message GetServerSocketsResponse { // list of socket refs that the connection detail service knows about. Sorted in // ascending socket_id order. + // Must contain at least 1 result, otherwise 'end' must be true. repeated SocketRef socket_ref = 1; // If set, indicates that the list of sockets is the final list. Requesting // more sockets will only return more if they are created after this RPC @@ -521,6 +550,11 @@ message GetSubchannelResponse { message GetSocketRequest { // socket_id is the identifier of the specific socket to get. int64 socket_id = 1; + + // If true, the response will contain only high level information + // that is inexpensive to obtain. Fields thay may be omitted are + // documented. + bool summary = 2; } message GetSocketResponse { diff --git a/src/proto/grpc/health/v1/BUILD b/src/proto/grpc/health/v1/BUILD index 97642985c9..38a7d992e0 100644 --- a/src/proto/grpc/health/v1/BUILD +++ b/src/proto/grpc/health/v1/BUILD @@ -29,4 +29,3 @@ filegroup( "health.proto", ], ) - diff --git a/src/proto/grpc/testing/BUILD b/src/proto/grpc/testing/BUILD index 7048911b9a..9876d5160a 100644 --- a/src/proto/grpc/testing/BUILD +++ b/src/proto/grpc/testing/BUILD @@ -50,7 +50,8 @@ grpc_proto_library( grpc_proto_library( name = "echo_proto", srcs = ["echo.proto"], - deps = ["echo_messages_proto"], + deps = ["echo_messages_proto", + "simple_messages_proto"], generate_mocks = True, ) @@ -120,6 +121,12 @@ grpc_proto_library( ) grpc_proto_library( + name = "simple_messages_proto", + srcs = ["simple_messages.proto"], + has_services = False, +) + +grpc_proto_library( name = "stats_proto", srcs = ["stats.proto"], has_services = False, diff --git a/src/proto/grpc/testing/compiler_test.proto b/src/proto/grpc/testing/compiler_test.proto index db5ca48331..9fa5590a59 100644 --- a/src/proto/grpc/testing/compiler_test.proto +++ b/src/proto/grpc/testing/compiler_test.proto @@ -20,6 +20,9 @@ syntax = "proto3"; // Ignored detached comment +// The comments in this file are not meant for readability +// but rather to test to make sure that the code generator +// properly preserves comments on files, services, and RPCs // Ignored package leading comment package grpc.testing; diff --git a/src/proto/grpc/testing/echo.proto b/src/proto/grpc/testing/echo.proto index 13e28320fc..977858f6bc 100644 --- a/src/proto/grpc/testing/echo.proto +++ b/src/proto/grpc/testing/echo.proto @@ -16,11 +16,15 @@ syntax = "proto3"; import "src/proto/grpc/testing/echo_messages.proto"; +import "src/proto/grpc/testing/simple_messages.proto"; package grpc.testing; service EchoTestService { rpc Echo(EchoRequest) returns (EchoResponse); + // A service which checks that the initial metadata sent over contains some + // expected key value pair + rpc CheckClientInitialMetadata(SimpleRequest) returns (SimpleResponse); rpc RequestStream(stream EchoRequest) returns (EchoResponse); rpc ResponseStream(EchoRequest) returns (stream EchoResponse); rpc BidiStream(stream EchoRequest) returns (stream EchoResponse); diff --git a/src/proto/grpc/testing/simple_messages.proto b/src/proto/grpc/testing/simple_messages.proto new file mode 100644 index 0000000000..4b65d18909 --- /dev/null +++ b/src/proto/grpc/testing/simple_messages.proto @@ -0,0 +1,24 @@ + +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package grpc.testing; + +message SimpleRequest { +} + +message SimpleResponse { +} diff --git a/src/python/grpcio/grpc/BUILD.bazel b/src/python/grpcio/grpc/BUILD.bazel index 2e6839ef2d..6958ccdfb6 100644 --- a/src/python/grpcio/grpc/BUILD.bazel +++ b/src/python/grpcio/grpc/BUILD.bazel @@ -13,7 +13,6 @@ py_library( ":interceptor", ":server", "//src/python/grpcio/grpc/_cython:cygrpc", - "//src/python/grpcio/grpc/beta", "//src/python/grpcio/grpc/experimental", "//src/python/grpcio/grpc/framework", requirement('enum34'), diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index df98dd10ad..daf869b156 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -15,12 +15,14 @@ import abc import enum +import logging import sys - import six from grpc._cython import cygrpc as _cygrpc +logging.getLogger(__name__).addHandler(logging.NullHandler()) + ############################## Future Interface ############################### @@ -264,6 +266,22 @@ class StatusCode(enum.Enum): UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated') +############################# gRPC Status ################################ + + +class Status(six.with_metaclass(abc.ABCMeta)): + """Describes the status of an RPC. + + This is an EXPERIMENTAL API. + + Attributes: + code: A StatusCode object to be sent to the client. + details: An ASCII-encodable string to be sent to the client upon + termination of the RPC. + trailing_metadata: The trailing :term:`metadata` in the RPC. + """ + + ############################# gRPC Exceptions ################################ @@ -387,7 +405,8 @@ class ClientCallDetails(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready mechanism. + wait_for_ready: This is an EXPERIMENTAL argument. An optional flag t + enable wait for ready mechanism. """ @@ -654,8 +673,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: The response value for the RPC. @@ -683,8 +702,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: The response value for the RPC and a Call value for the RPC. @@ -712,8 +731,8 @@ class UnaryUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: An object that is both a Call for the RPC and a Future. @@ -744,8 +763,8 @@ class UnaryStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: An optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: An object that is both a Call for the RPC and an iterator of @@ -776,8 +795,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: The response value for the RPC. @@ -806,8 +825,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: The response value for the RPC and a Call object for the RPC. @@ -835,8 +854,8 @@ class StreamUnaryMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: An object that is both a Call for the RPC and a Future. @@ -867,8 +886,8 @@ class StreamStreamMultiCallable(six.with_metaclass(abc.ABCMeta)): metadata: Optional :term:`metadata` to be transmitted to the service-side of the RPC. credentials: An optional CallCredentials for the RPC. - wait_for_ready: An optional flag to enable wait for ready - mechanism + wait_for_ready: This is an EXPERIMENTAL argument. An optional + flag to enable wait for ready mechanism Returns: An object that is both a Call for the RPC and an iterator of @@ -1116,6 +1135,25 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): raise NotImplementedError() @abc.abstractmethod + def abort_with_status(self, status): + """Raises an exception to terminate the RPC with a non-OK status. + + The status passed as argument will supercede any existing status code, + status message and trailing metadata. + + This is an EXPERIMENTAL API. + + Args: + status: A grpc.Status object. The status code in it must not be + StatusCode.OK. + + Raises: + Exception: An exception is always raised to signal the abortion the + RPC to the gRPC runtime. + """ + raise NotImplementedError() + + @abc.abstractmethod def set_code(self, code): """Sets the value to be used as status code upon RPC completion. @@ -1720,7 +1758,7 @@ def server(thread_pool, handlers. The interceptors are given control in the order they are specified. This is an EXPERIMENTAL API. options: An optional list of key-value pairs (channel args in gRPC runtime) - to configure the channel. + to configure the channel. maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server will service before returning RESOURCE_EXHAUSTED status, or None to indicate no limit. @@ -1744,6 +1782,7 @@ __all__ = ( 'Future', 'ChannelConnectivity', 'StatusCode', + 'Status', 'RpcError', 'RpcContext', 'Call', diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index c17824563d..9b990f490d 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -46,7 +46,7 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin): # Hack to determine if these are JWT creds and we need to pass # additional_claims when getting a token - self._is_jwt = 'additional_claims' in inspect.getargspec( + self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method credentials.get_access_token).args def __call__(self, context, callback): diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 3ff7658748..8051fb306c 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -25,7 +25,6 @@ from grpc._cython import cygrpc from grpc.framework.foundation import callable_util _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) _USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__) @@ -176,6 +175,7 @@ def _event_handler(state, response_deserializer): return handle_event +#pylint: disable=too-many-statements def _consume_request_iterator(request_iterator, state, call, request_serializer, event_handler): if cygrpc.is_fork_support_enabled(): @@ -499,6 +499,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def _prepare(self, request, timeout, metadata, wait_for_ready): deadline, serialized_request, rendezvous = _start_unary_request( @@ -525,17 +526,18 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready) if state is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: call = self._channel.segregated_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, (( operations, None, - ),)) + ),), self._context) event = call.next_event() _handle_event(event, state, self._response_deserializer) - return state, call, + return state, call def __call__(self, request, @@ -566,13 +568,14 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): state, operations, deadline, rendezvous = self._prepare( request, timeout, metadata, wait_for_ready) if state is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - (operations,), event_handler) + (operations,), event_handler, self._context) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -587,6 +590,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def __call__(self, request, @@ -599,7 +603,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) if serialized_request is None: - raise rendezvous + raise rendezvous # pylint: disable-msg=raising-bad-type else: state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) operationses = ( @@ -615,9 +619,10 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): ) event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, None, deadline, metadata, None if credentials is None else credentials._credentials, - operationses, event_handler) + operationses, event_handler, self._context) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -632,6 +637,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def _blocking(self, request_iterator, timeout, metadata, credentials, wait_for_ready): @@ -640,10 +646,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) call = self._channel.segregated_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses_and_tags( - metadata, initial_metadata_flags)) + metadata, initial_metadata_flags), self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, None) while True: @@ -653,7 +660,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): state.condition.notify_all() if not state.due: break - return state, call, + return state, call def __call__(self, request_iterator, @@ -687,10 +694,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( wait_for_ready) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operationses( - metadata, initial_metadata_flags), event_handler) + metadata, initial_metadata_flags), event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -706,6 +714,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._context = cygrpc.build_context() def __call__(self, request_iterator, @@ -727,9 +736,10 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): ) event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - 0, self._method, None, deadline, metadata, None + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, + None, deadline, metadata, None if credentials is None else credentials._credentials, operationses, - event_handler) + event_handler, self._context) _consume_request_iterator(request_iterator, state, call, self._request_serializer, event_handler) return _Rendezvous(state, call, self._response_deserializer, deadline) @@ -745,10 +755,10 @@ class _InitialMetadataFlags(int): def with_wait_for_ready(self, wait_for_ready): if wait_for_ready is not None: if wait_for_ready: - self = self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ + return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) elif not wait_for_ready: - self = self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ + return self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) return self @@ -789,7 +799,7 @@ def _channel_managed_call_management(state): # pylint: disable=too-many-arguments def create(flags, method, host, deadline, metadata, credentials, - operationses, event_handler): + operationses, event_handler, context): """Creates a cygrpc.IntegratedCall. Args: @@ -804,7 +814,7 @@ def _channel_managed_call_management(state): started on the call. event_handler: A behavior to call to handle the events resultant from the operations on the call. - + context: Context object for distributed tracing. Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ @@ -815,7 +825,7 @@ def _channel_managed_call_management(state): with state.lock: call = state.channel.integrated_call(flags, method, host, deadline, metadata, credentials, - operationses_and_tags) + operationses_and_tags, context) if state.managed_calls == 0: state.managed_calls = 1 _run_channel_spin_thread(state) diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 42f3a4e614..f69127e38e 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -21,7 +21,6 @@ import grpc from grpc._cython import cygrpc _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { cygrpc.ConnectivityState.idle: diff --git a/src/python/grpcio/grpc/_cython/BUILD.bazel b/src/python/grpcio/grpc/_cython/BUILD.bazel index cfd3a51d9b..42db7b8721 100644 --- a/src/python/grpcio/grpc/_cython/BUILD.bazel +++ b/src/python/grpcio/grpc/_cython/BUILD.bazel @@ -6,45 +6,47 @@ pyx_library( name = "cygrpc", srcs = [ "__init__.py", - "cygrpc.pxd", - "cygrpc.pyx", + "_cygrpc/_hooks.pxd.pxi", "_cygrpc/_hooks.pyx.pxi", - "_cygrpc/grpc_string.pyx.pxi", + "_cygrpc/arguments.pxd.pxi", "_cygrpc/arguments.pyx.pxi", + "_cygrpc/call.pxd.pxi", "_cygrpc/call.pyx.pxi", + "_cygrpc/channel.pxd.pxi", "_cygrpc/channel.pyx.pxi", - "_cygrpc/credentials.pyx.pxi", + "_cygrpc/channelz.pyx.pxi", + "_cygrpc/completion_queue.pxd.pxi", "_cygrpc/completion_queue.pyx.pxi", - "_cygrpc/event.pyx.pxi", - "_cygrpc/fork_posix.pyx.pxi", - "_cygrpc/metadata.pyx.pxi", - "_cygrpc/operation.pyx.pxi", - "_cygrpc/records.pyx.pxi", - "_cygrpc/security.pyx.pxi", - "_cygrpc/server.pyx.pxi", - "_cygrpc/tag.pyx.pxi", - "_cygrpc/time.pyx.pxi", - "_cygrpc/grpc_gevent.pyx.pxi", - "_cygrpc/grpc.pxi", - "_cygrpc/_hooks.pxd.pxi", - "_cygrpc/arguments.pxd.pxi", - "_cygrpc/call.pxd.pxi", - "_cygrpc/channel.pxd.pxi", "_cygrpc/credentials.pxd.pxi", - "_cygrpc/completion_queue.pxd.pxi", + "_cygrpc/credentials.pyx.pxi", "_cygrpc/event.pxd.pxi", + "_cygrpc/event.pyx.pxi", "_cygrpc/fork_posix.pxd.pxi", + "_cygrpc/fork_posix.pyx.pxi", + "_cygrpc/grpc.pxi", + "_cygrpc/grpc_gevent.pxd.pxi", + "_cygrpc/grpc_gevent.pyx.pxi", + "_cygrpc/grpc_string.pyx.pxi", "_cygrpc/metadata.pxd.pxi", + "_cygrpc/metadata.pyx.pxi", "_cygrpc/operation.pxd.pxi", + "_cygrpc/operation.pyx.pxi", + "_cygrpc/propagation_bits.pxd.pxi", + "_cygrpc/propagation_bits.pyx.pxi", "_cygrpc/records.pxd.pxi", + "_cygrpc/records.pyx.pxi", "_cygrpc/security.pxd.pxi", + "_cygrpc/security.pyx.pxi", "_cygrpc/server.pxd.pxi", + "_cygrpc/server.pyx.pxi", "_cygrpc/tag.pxd.pxi", + "_cygrpc/tag.pyx.pxi", "_cygrpc/time.pxd.pxi", - "_cygrpc/grpc_gevent.pxd.pxi", + "_cygrpc/time.pyx.pxi", + "cygrpc.pxd", + "cygrpc.pyx", ], deps = [ "//:grpc", ], ) - diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/_hooks.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/_hooks.pyx.pxi index 38cf629dc2..cd4a51a635 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/_hooks.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/_hooks.pyx.pxi @@ -15,3 +15,18 @@ cdef object _custom_op_on_c_call(int op, grpc_call *call): raise NotImplementedError("No custom hooks are implemented") + +def install_census_context_from_call(Call call): + pass + +def uninstall_context(): + pass + +def build_context(): + pass + +cdef class CensusContext: + pass + +def set_census_context_on_call(_CallState call_state, CensusContext census_ctx): + pass diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index a81ff4d823..135d224095 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -159,7 +159,8 @@ cdef void _call( _ChannelState channel_state, _CallState call_state, grpc_completion_queue *c_completion_queue, on_success, int flags, method, host, object deadline, CallCredentials credentials, - object operationses_and_user_tags, object metadata) except *: + object operationses_and_user_tags, object metadata, + object context) except *: """Invokes an RPC. Args: @@ -185,6 +186,7 @@ cdef void _call( which is an object to be used as a tag. A SendInitialMetadataOperation must be present in the first element of this value. metadata: The metadata for this call. + context: Context object for distributed tracing. """ cdef grpc_slice method_slice cdef grpc_slice host_slice @@ -208,6 +210,8 @@ cdef void _call( grpc_slice_unref(method_slice) if host_slice_ptr: grpc_slice_unref(host_slice) + if context is not None: + set_census_context_on_call(call_state, context) if credentials is not None: c_call_credentials = credentials.c() c_call_error = grpc_call_set_credentials( @@ -257,7 +261,8 @@ cdef class IntegratedCall: cdef IntegratedCall _integrated_call( _ChannelState state, int flags, method, host, object deadline, - object metadata, CallCredentials credentials, operationses_and_user_tags): + object metadata, CallCredentials credentials, operationses_and_user_tags, + object context): call_state = _CallState() def on_success(started_tags): @@ -266,7 +271,7 @@ cdef IntegratedCall _integrated_call( _call( state, call_state, state.c_call_completion_queue, on_success, flags, - method, host, deadline, credentials, operationses_and_user_tags, metadata) + method, host, deadline, credentials, operationses_and_user_tags, metadata, context) return IntegratedCall(state, call_state) @@ -308,7 +313,8 @@ cdef class SegregatedCall: cdef SegregatedCall _segregated_call( _ChannelState state, int flags, method, host, object deadline, - object metadata, CallCredentials credentials, operationses_and_user_tags): + object metadata, CallCredentials credentials, operationses_and_user_tags, + object context): cdef _CallState call_state = _CallState() cdef SegregatedCall segregated_call cdef grpc_completion_queue *c_completion_queue @@ -325,7 +331,8 @@ cdef SegregatedCall _segregated_call( try: _call( state, call_state, c_completion_queue, on_success, flags, method, host, - deadline, credentials, operationses_and_user_tags, metadata) + deadline, credentials, operationses_and_user_tags, metadata, + context) except: _destroy_c_completion_queue(c_completion_queue) raise @@ -443,10 +450,11 @@ cdef class Channel: def integrated_call( self, int flags, method, host, object deadline, object metadata, - CallCredentials credentials, operationses_and_tags): + CallCredentials credentials, operationses_and_tags, + object context = None): return _integrated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags) + operationses_and_tags, context) def next_call_event(self): def on_success(tag): @@ -461,10 +469,11 @@ cdef class Channel: def segregated_call( self, int flags, method, host, object deadline, object metadata, - CallCredentials credentials, operationses_and_tags): + CallCredentials credentials, operationses_and_tags, + object context = None): return _segregated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags) + operationses_and_tags, context) def check_connectivity_state(self, bint try_to_connect): with self._state.condition: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channelz.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channelz.pyx.pxi new file mode 100644 index 0000000000..36c8cd121c --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channelz.pyx.pxi @@ -0,0 +1,71 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def channelz_get_top_channels(start_channel_id): + cdef char *c_returned_str = grpc_channelz_get_top_channels( + start_channel_id, + ) + if c_returned_str == NULL: + raise ValueError('Failed to get top channels, please ensure your' \ + ' start_channel_id==%s is valid' % start_channel_id) + return c_returned_str + +def channelz_get_servers(start_server_id): + cdef char *c_returned_str = grpc_channelz_get_servers(start_server_id) + if c_returned_str == NULL: + raise ValueError('Failed to get servers, please ensure your' \ + ' start_server_id==%s is valid' % start_server_id) + return c_returned_str + +def channelz_get_server(server_id): + cdef char *c_returned_str = grpc_channelz_get_server(server_id) + if c_returned_str == NULL: + raise ValueError('Failed to get the server, please ensure your' \ + ' server_id==%s is valid' % server_id) + return c_returned_str + +def channelz_get_server_sockets(server_id, start_socket_id, max_results): + cdef char *c_returned_str = grpc_channelz_get_server_sockets( + server_id, + start_socket_id, + max_results, + ) + if c_returned_str == NULL: + raise ValueError('Failed to get server sockets, please ensure your' \ + ' server_id==%s and start_socket_id==%s and' \ + ' max_results==%s is valid' % + (server_id, start_socket_id, max_results)) + return c_returned_str + +def channelz_get_channel(channel_id): + cdef char *c_returned_str = grpc_channelz_get_channel(channel_id) + if c_returned_str == NULL: + raise ValueError('Failed to get the channel, please ensure your' \ + ' channel_id==%s is valid' % (channel_id)) + return c_returned_str + +def channelz_get_subchannel(subchannel_id): + cdef char *c_returned_str = grpc_channelz_get_subchannel(subchannel_id) + if c_returned_str == NULL: + raise ValueError('Failed to get the subchannel, please ensure your' \ + ' subchannel_id==%s is valid' % (subchannel_id)) + return c_returned_str + +def channelz_get_socket(socket_id): + cdef char *c_returned_str = grpc_channelz_get_socket(socket_id) + if c_returned_str == NULL: + raise ValueError('Failed to get the socket, please ensure your' \ + ' socket_id==%s is valid' % (socket_id)) + return c_returned_str diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi index 141116df5d..3c33b46dbb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi @@ -49,7 +49,7 @@ cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline): cdef _interpret_event(grpc_event c_event): cdef _Tag tag if c_event.type == GRPC_QUEUE_TIMEOUT: - # NOTE(nathaniel): For now we coopt ConnectivityEvent here. + # TODO(ericgribkoff) Do not coopt ConnectivityEvent here. return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None) elif c_event.type == GRPC_QUEUE_SHUTDOWN: # NOTE(nathaniel): For now we coopt ConnectivityEvent here. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi index 8d73215247..1cef726970 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi @@ -15,7 +15,7 @@ cdef class CallCredentials: - cdef grpc_call_credentials *c(self) + cdef grpc_call_credentials *c(self) except * # TODO(https://github.com/grpc/grpc/issues/12531): remove. cdef grpc_call_credentials *c_credentials @@ -36,7 +36,7 @@ cdef class MetadataPluginCallCredentials(CallCredentials): cdef readonly object _metadata_plugin cdef readonly bytes _name - cdef grpc_call_credentials *c(self) + cdef grpc_call_credentials *c(self) except * cdef grpc_call_credentials *_composition(call_credentialses) @@ -46,12 +46,12 @@ cdef class CompositeCallCredentials(CallCredentials): cdef readonly tuple _call_credentialses - cdef grpc_call_credentials *c(self) + cdef grpc_call_credentials *c(self) except * cdef class ChannelCredentials: - cdef grpc_channel_credentials *c(self) + cdef grpc_channel_credentials *c(self) except * # TODO(https://github.com/grpc/grpc/issues/12531): remove. cdef grpc_channel_credentials *c_credentials @@ -68,7 +68,7 @@ cdef class SSLChannelCredentials(ChannelCredentials): cdef readonly object _private_key cdef readonly object _certificate_chain - cdef grpc_channel_credentials *c(self) + cdef grpc_channel_credentials *c(self) except * cdef class CompositeChannelCredentials(ChannelCredentials): @@ -76,7 +76,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials): cdef readonly tuple _call_credentialses cdef readonly ChannelCredentials _channel_credentials - cdef grpc_channel_credentials *c(self) + cdef grpc_channel_credentials *c(self) except * cdef class ServerCertificateConfig: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index 63048e8da0..2f51be40ce 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -35,7 +35,7 @@ def _spawn_callback_async(callback, args): cdef class CallCredentials: - cdef grpc_call_credentials *c(self): + cdef grpc_call_credentials *c(self) except *: raise NotImplementedError() @@ -61,6 +61,7 @@ cdef int _get_metadata( cdef void _destroy(void *state) with gil: cpython.Py_DECREF(<object>state) + grpc_shutdown() cdef class MetadataPluginCallCredentials(CallCredentials): @@ -69,13 +70,14 @@ cdef class MetadataPluginCallCredentials(CallCredentials): self._metadata_plugin = metadata_plugin self._name = name - cdef grpc_call_credentials *c(self): + cdef grpc_call_credentials *c(self) except *: cdef grpc_metadata_credentials_plugin c_metadata_plugin c_metadata_plugin.get_metadata = _get_metadata c_metadata_plugin.destroy = _destroy c_metadata_plugin.state = <void *>self._metadata_plugin c_metadata_plugin.type = self._name cpython.Py_INCREF(self._metadata_plugin) + fork_handlers_and_grpc_init() return grpc_metadata_credentials_create_from_plugin(c_metadata_plugin, NULL) @@ -101,13 +103,13 @@ cdef class CompositeCallCredentials(CallCredentials): def __cinit__(self, call_credentialses): self._call_credentialses = call_credentialses - cdef grpc_call_credentials *c(self): + cdef grpc_call_credentials *c(self) except *: return _composition(self._call_credentialses) cdef class ChannelCredentials: - cdef grpc_channel_credentials *c(self): + cdef grpc_channel_credentials *c(self) except *: raise NotImplementedError() @@ -129,11 +131,13 @@ cdef class SSLSessionCacheLRU: cdef class SSLChannelCredentials(ChannelCredentials): def __cinit__(self, pem_root_certificates, private_key, certificate_chain): + if pem_root_certificates is not None and not isinstance(pem_root_certificates, bytes): + raise TypeError('expected certificate to be bytes, got %s' % (type(pem_root_certificates))) self._pem_root_certificates = pem_root_certificates self._private_key = private_key self._certificate_chain = certificate_chain - cdef grpc_channel_credentials *c(self): + cdef grpc_channel_credentials *c(self) except *: cdef const char *c_pem_root_certificates cdef grpc_ssl_pem_key_cert_pair c_pem_key_certificate_pair if self._pem_root_certificates is None: @@ -162,7 +166,7 @@ cdef class CompositeChannelCredentials(ChannelCredentials): self._call_credentialses = call_credentialses self._channel_credentials = channel_credentials - cdef grpc_channel_credentials *c(self): + cdef grpc_channel_credentials *c(self) except *: cdef grpc_channel_credentials *c_channel_credentials c_channel_credentials = self._channel_credentials.c() cdef grpc_call_credentials *c_call_credentials_composition = _composition( diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index 23428f0b0c..fc7a9ba439 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -13,6 +13,7 @@ # limitations under the License. cimport libc.time +from libc.stdint cimport intptr_t # Typedef types with approximately the same semantics to provide their names to @@ -121,7 +122,6 @@ cdef extern from "grpc/grpc.h": GRPC_STATUS_DATA_LOSS GRPC_STATUS__DO_NOT_USE - const char *GRPC_ARG_PRIMARY_USER_AGENT_STRING const char *GRPC_ARG_ENABLE_CENSUS const char *GRPC_ARG_MAX_CONCURRENT_STREAMS const char *GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH @@ -190,12 +190,6 @@ cdef extern from "grpc/grpc.h": size_t arguments_length "num_args" grpc_arg *arguments "args" - ctypedef enum grpc_compression_level: - GRPC_COMPRESS_LEVEL_NONE - GRPC_COMPRESS_LEVEL_LOW - GRPC_COMPRESS_LEVEL_MED - GRPC_COMPRESS_LEVEL_HIGH - ctypedef enum grpc_stream_compression_level: GRPC_STREAM_COMPRESS_LEVEL_NONE GRPC_STREAM_COMPRESS_LEVEL_LOW @@ -391,6 +385,16 @@ cdef extern from "grpc/grpc.h": void grpc_server_cancel_all_calls(grpc_server *server) nogil void grpc_server_destroy(grpc_server *server) nogil + char* grpc_channelz_get_top_channels(intptr_t start_channel_id) + char* grpc_channelz_get_servers(intptr_t start_server_id) + char* grpc_channelz_get_server(intptr_t server_id) + char* grpc_channelz_get_server_sockets(intptr_t server_id, + intptr_t start_socket_id, + intptr_t max_results) + char* grpc_channelz_get_channel(intptr_t channel_id) + char* grpc_channelz_get_subchannel(intptr_t subchannel_id) + char* grpc_channelz_get_socket(intptr_t socket_id) + cdef extern from "grpc/grpc_security.h": diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx.pxi index f9a1b2856d..a1618d04d0 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_gevent.pyx.pxi @@ -266,7 +266,7 @@ cdef grpc_error* socket_listen(grpc_custom_socket* socket) with gil: (<SocketWrapper>socket.impl).socket.listen(50) return grpc_error_none() -cdef void accept_callback_cython(SocketWrapper s): +cdef void accept_callback_cython(SocketWrapper s) except *: try: conn, address = s.socket.accept() sw = SocketWrapper() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi index fa356d913e..00a1b23a67 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc_string.pyx.pxi @@ -15,7 +15,6 @@ import logging _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) # This function will ascii encode unicode string inputs if neccesary. # In Python3, unicode strings are the default str type. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi index a18c365807..fc72ac1576 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pxd.pxi @@ -14,10 +14,10 @@ cdef void _store_c_metadata( - metadata, grpc_metadata **c_metadata, size_t *c_count) + metadata, grpc_metadata **c_metadata, size_t *c_count) except * -cdef void _release_c_metadata(grpc_metadata *c_metadata, int count) +cdef void _release_c_metadata(grpc_metadata *c_metadata, int count) except * cdef tuple _metadatum(grpc_slice key_slice, grpc_slice value_slice) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi index 53f0c7f0bb..caf867b569 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi @@ -25,7 +25,7 @@ _Metadatum = collections.namedtuple('_Metadatum', ('key', 'value',)) cdef void _store_c_metadata( - metadata, grpc_metadata **c_metadata, size_t *c_count): + metadata, grpc_metadata **c_metadata, size_t *c_count) except *: if metadata is None: c_count[0] = 0 c_metadata[0] = NULL @@ -45,7 +45,7 @@ cdef void _store_c_metadata( c_metadata[0][index].value = _slice_from_bytes(encoded_value) -cdef void _release_c_metadata(grpc_metadata *c_metadata, int count): +cdef void _release_c_metadata(grpc_metadata *c_metadata, int count) except *: if 0 < count: for index in range(count): grpc_slice_unref(c_metadata[index].key) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/operation.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/operation.pxd.pxi index 69a2a4989e..c9df32dadf 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/operation.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/operation.pxd.pxi @@ -15,8 +15,8 @@ cdef class Operation: - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * # TODO(https://github.com/grpc/grpc/issues/7950): Eliminate this! cdef grpc_op c_op @@ -29,8 +29,8 @@ cdef class SendInitialMetadataOperation(Operation): cdef grpc_metadata *_c_initial_metadata cdef size_t _c_initial_metadata_count - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class SendMessageOperation(Operation): @@ -39,16 +39,16 @@ cdef class SendMessageOperation(Operation): cdef readonly int _flags cdef grpc_byte_buffer *_c_message_byte_buffer - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class SendCloseFromClientOperation(Operation): cdef readonly int _flags - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class SendStatusFromServerOperation(Operation): @@ -61,8 +61,8 @@ cdef class SendStatusFromServerOperation(Operation): cdef size_t _c_trailing_metadata_count cdef grpc_slice _c_details - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class ReceiveInitialMetadataOperation(Operation): @@ -71,8 +71,8 @@ cdef class ReceiveInitialMetadataOperation(Operation): cdef tuple _initial_metadata cdef grpc_metadata_array _c_initial_metadata - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class ReceiveMessageOperation(Operation): @@ -81,8 +81,8 @@ cdef class ReceiveMessageOperation(Operation): cdef grpc_byte_buffer *_c_message_byte_buffer cdef bytes _message - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class ReceiveStatusOnClientOperation(Operation): @@ -97,8 +97,8 @@ cdef class ReceiveStatusOnClientOperation(Operation): cdef str _details cdef str _error_string - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * cdef class ReceiveCloseOnServerOperation(Operation): @@ -107,5 +107,5 @@ cdef class ReceiveCloseOnServerOperation(Operation): cdef object _cancelled cdef int _c_cancelled - cdef void c(self) - cdef void un_c(self) + cdef void c(self) except * + cdef void un_c(self) except * diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/operation.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/operation.pyx.pxi index 454627f570..c8a390106a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/operation.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/operation.pyx.pxi @@ -15,10 +15,10 @@ cdef class Operation: - cdef void c(self): + cdef void c(self) except *: raise NotImplementedError() - cdef void un_c(self): + cdef void un_c(self) except *: raise NotImplementedError() @@ -31,7 +31,7 @@ cdef class SendInitialMetadataOperation(Operation): def type(self): return GRPC_OP_SEND_INITIAL_METADATA - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_SEND_INITIAL_METADATA self.c_op.flags = self._flags _store_c_metadata( @@ -41,7 +41,7 @@ cdef class SendInitialMetadataOperation(Operation): self.c_op.data.send_initial_metadata.count = self._c_initial_metadata_count self.c_op.data.send_initial_metadata.maybe_compression_level.is_set = 0 - cdef void un_c(self): + cdef void un_c(self) except *: _release_c_metadata( self._c_initial_metadata, self._c_initial_metadata_count) @@ -55,7 +55,7 @@ cdef class SendMessageOperation(Operation): def type(self): return GRPC_OP_SEND_MESSAGE - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_SEND_MESSAGE self.c_op.flags = self._flags cdef grpc_slice message_slice = grpc_slice_from_copied_buffer( @@ -65,7 +65,7 @@ cdef class SendMessageOperation(Operation): grpc_slice_unref(message_slice) self.c_op.data.send_message.send_message = self._c_message_byte_buffer - cdef void un_c(self): + cdef void un_c(self) except *: grpc_byte_buffer_destroy(self._c_message_byte_buffer) @@ -77,11 +77,11 @@ cdef class SendCloseFromClientOperation(Operation): def type(self): return GRPC_OP_SEND_CLOSE_FROM_CLIENT - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_SEND_CLOSE_FROM_CLIENT self.c_op.flags = self._flags - cdef void un_c(self): + cdef void un_c(self) except *: pass @@ -96,7 +96,7 @@ cdef class SendStatusFromServerOperation(Operation): def type(self): return GRPC_OP_SEND_STATUS_FROM_SERVER - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_SEND_STATUS_FROM_SERVER self.c_op.flags = self._flags _store_c_metadata( @@ -110,7 +110,7 @@ cdef class SendStatusFromServerOperation(Operation): self._c_details = _slice_from_bytes(_encode(self._details)) self.c_op.data.send_status_from_server.status_details = &self._c_details - cdef void un_c(self): + cdef void un_c(self) except *: grpc_slice_unref(self._c_details) _release_c_metadata( self._c_trailing_metadata, self._c_trailing_metadata_count) @@ -124,14 +124,14 @@ cdef class ReceiveInitialMetadataOperation(Operation): def type(self): return GRPC_OP_RECV_INITIAL_METADATA - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_RECV_INITIAL_METADATA self.c_op.flags = self._flags grpc_metadata_array_init(&self._c_initial_metadata) self.c_op.data.receive_initial_metadata.receive_initial_metadata = ( &self._c_initial_metadata) - cdef void un_c(self): + cdef void un_c(self) except *: self._initial_metadata = _metadata(&self._c_initial_metadata) grpc_metadata_array_destroy(&self._c_initial_metadata) @@ -147,13 +147,13 @@ cdef class ReceiveMessageOperation(Operation): def type(self): return GRPC_OP_RECV_MESSAGE - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_RECV_MESSAGE self.c_op.flags = self._flags self.c_op.data.receive_message.receive_message = ( &self._c_message_byte_buffer) - cdef void un_c(self): + cdef void un_c(self) except *: cdef grpc_byte_buffer_reader message_reader cdef bint message_reader_status cdef grpc_slice message_slice @@ -189,7 +189,7 @@ cdef class ReceiveStatusOnClientOperation(Operation): def type(self): return GRPC_OP_RECV_STATUS_ON_CLIENT - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_RECV_STATUS_ON_CLIENT self.c_op.flags = self._flags grpc_metadata_array_init(&self._c_trailing_metadata) @@ -202,7 +202,7 @@ cdef class ReceiveStatusOnClientOperation(Operation): self.c_op.data.receive_status_on_client.error_string = ( &self._c_error_string) - cdef void un_c(self): + cdef void un_c(self) except *: self._trailing_metadata = _metadata(&self._c_trailing_metadata) grpc_metadata_array_destroy(&self._c_trailing_metadata) self._code = self._c_code @@ -235,12 +235,12 @@ cdef class ReceiveCloseOnServerOperation(Operation): def type(self): return GRPC_OP_RECV_CLOSE_ON_SERVER - cdef void c(self): + cdef void c(self) except *: self.c_op.type = GRPC_OP_RECV_CLOSE_ON_SERVER self.c_op.flags = self._flags self.c_op.data.receive_close_on_server.cancelled = &self._c_cancelled - cdef void un_c(self): + cdef void un_c(self) except *: self._cancelled = bool(self._c_cancelled) def cancelled(self): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pxd.pxi new file mode 100644 index 0000000000..cd6e94c816 --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pxd.pxi @@ -0,0 +1,20 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cdef extern from "grpc/impl/codegen/propagation_bits.h": + cdef int _GRPC_PROPAGATE_DEADLINE "GRPC_PROPAGATE_DEADLINE" + cdef int _GRPC_PROPAGATE_CENSUS_STATS_CONTEXT "GRPC_PROPAGATE_CENSUS_STATS_CONTEXT" + cdef int _GRPC_PROPAGATE_CENSUS_TRACING_CONTEXT "GRPC_PROPAGATE_CENSUS_TRACING_CONTEXT" + cdef int _GRPC_PROPAGATE_CANCELLATION "GRPC_PROPAGATE_CANCELLATION" + cdef int _GRPC_PROPAGATE_DEFAULTS "GRPC_PROPAGATE_DEFAULTS" diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pyx.pxi new file mode 100644 index 0000000000..2dcc76a2db --- /dev/null +++ b/src/python/grpcio/grpc/_cython/_cygrpc/propagation_bits.pyx.pxi @@ -0,0 +1,20 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class PropagationConstants: + GRPC_PROPAGATE_DEADLINE = _GRPC_PROPAGATE_DEADLINE + GRPC_PROPAGATE_CENSUS_STATS_CONTEXT = _GRPC_PROPAGATE_CENSUS_STATS_CONTEXT + GRPC_PROPAGATE_CENSUS_TRACING_CONTEXT = _GRPC_PROPAGATE_CENSUS_TRACING_CONTEXT + GRPC_PROPAGATE_CANCELLATION = _GRPC_PROPAGATE_CANCELLATION + GRPC_PROPAGATE_DEFAULTS = _GRPC_PROPAGATE_DEFAULTS diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi index f9d1e863ca..e89e02b171 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi @@ -19,7 +19,6 @@ import time import grpc _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) cdef class Server: @@ -129,7 +128,10 @@ cdef class Server: with nogil: grpc_server_cancel_all_calls(self.c_server) - def __dealloc__(self): + # TODO(https://github.com/grpc/grpc/issues/17515) Determine what, if any, + # portion of this is safe to call from __dealloc__, and potentially remove + # backup_shutdown_queue. + def destroy(self): if self.c_server != NULL: if not self.is_started: pass @@ -147,4 +149,8 @@ cdef class Server: while not self.is_shutdown: time.sleep(0) grpc_server_destroy(self.c_server) - grpc_shutdown() + self.c_server = NULL + + def __dealloc(self): + if self.c_server == NULL: + grpc_shutdown() diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/tag.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/tag.pxd.pxi index f9a3b5e8f4..d8ba1ea9bd 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/tag.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/tag.pxd.pxi @@ -32,7 +32,7 @@ cdef class _RequestCallTag(_Tag): cdef CallDetails call_details cdef grpc_metadata_array c_invocation_metadata - cdef void prepare(self) + cdef void prepare(self) except * cdef RequestCallEvent event(self, grpc_event c_event) @@ -44,7 +44,7 @@ cdef class _BatchOperationTag(_Tag): cdef grpc_op *c_ops cdef size_t c_nops - cdef void prepare(self) + cdef void prepare(self) except * cdef BatchOperationEvent event(self, grpc_event c_event) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/tag.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/tag.pyx.pxi index aaca458442..be5013c8f7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/tag.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/tag.pyx.pxi @@ -35,7 +35,7 @@ cdef class _RequestCallTag(_Tag): self.call = None self.call_details = None - cdef void prepare(self): + cdef void prepare(self) except *: self.call = Call() self.call_details = CallDetails() grpc_metadata_array_init(&self.c_invocation_metadata) @@ -55,7 +55,7 @@ cdef class _BatchOperationTag: self._operations = operations self._retained_call = call - cdef void prepare(self): + cdef void prepare(self) except *: self.c_nops = 0 if self._operations is None else len(self._operations) if 0 < self.c_nops: self.c_ops = <grpc_op *>gpr_malloc(sizeof(grpc_op) * self.c_nops) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/time.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/time.pxd.pxi index ce67c61eaf..1319ac0481 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/time.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/time.pxd.pxi @@ -16,4 +16,4 @@ cdef gpr_timespec _timespec_from_time(object time) -cdef double _time_from_timespec(gpr_timespec timespec) +cdef double _time_from_timespec(gpr_timespec timespec) except * diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/time.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/time.pyx.pxi index 7a668680b8..c452dd54f8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/time.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/time.pyx.pxi @@ -24,7 +24,7 @@ cdef gpr_timespec _timespec_from_time(object time): return timespec -cdef double _time_from_timespec(gpr_timespec timespec): +cdef double _time_from_timespec(gpr_timespec timespec) except *: cdef gpr_timespec real_timespec = gpr_convert_clock_type( timespec, GPR_CLOCK_REALTIME) return <double>real_timespec.seconds + <double>real_timespec.nanoseconds / 1e9 diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pxd b/src/python/grpcio/grpc/_cython/cygrpc.pxd index 8258b857bc..64cae6b34d 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pxd +++ b/src/python/grpcio/grpc/_cython/cygrpc.pxd @@ -29,6 +29,7 @@ include "_cygrpc/server.pxd.pxi" include "_cygrpc/tag.pxd.pxi" include "_cygrpc/time.pxd.pxi" include "_cygrpc/_hooks.pxd.pxi" +include "_cygrpc/propagation_bits.pxd.pxi" include "_cygrpc/grpc_gevent.pxd.pxi" diff --git a/src/python/grpcio/grpc/_cython/cygrpc.pyx b/src/python/grpcio/grpc/_cython/cygrpc.pyx index ae5c07bfc8..ce98fa3a8e 100644 --- a/src/python/grpcio/grpc/_cython/cygrpc.pyx +++ b/src/python/grpcio/grpc/_cython/cygrpc.pyx @@ -35,6 +35,8 @@ include "_cygrpc/server.pyx.pxi" include "_cygrpc/tag.pyx.pxi" include "_cygrpc/time.pyx.pxi" include "_cygrpc/_hooks.pyx.pxi" +include "_cygrpc/channelz.pyx.pxi" +include "_cygrpc/propagation_bits.pyx.pxi" include "_cygrpc/grpc_gevent.pyx.pxi" diff --git a/src/python/grpcio/grpc/_grpcio_metadata.py b/src/python/grpcio/grpc/_grpcio_metadata.py index 42b3a1ad49..7a9f173947 100644 --- a/src/python/grpcio/grpc/_grpcio_metadata.py +++ b/src/python/grpcio/grpc/_grpcio_metadata.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc/_grpcio_metadata.py.template`!!! -__version__ = """1.17.0.dev0""" +__version__ = """1.18.0.dev0""" diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 4345114026..fc0ad77eb9 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -135,9 +135,12 @@ class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): def __iter__(self): return self - def next(self): + def __next__(self): raise self._exception + def next(self): + return self.__next__() + class _UnaryOutcome(grpc.Call, grpc.Future): @@ -232,8 +235,8 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): credentials=new_credentials, wait_for_ready=new_wait_for_ready) return _UnaryOutcome(response, call) - except grpc.RpcError: - raise + except grpc.RpcError as rpc_error: + return rpc_error except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) @@ -354,8 +357,8 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): credentials=new_credentials, wait_for_ready=new_wait_for_ready) return _UnaryOutcome(response, call) - except grpc.RpcError: - raise + except grpc.RpcError as rpc_error: + return rpc_error except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index 53af2ff913..916ee080b6 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -21,7 +21,6 @@ from grpc import _common from grpc._cython import cygrpc _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) class _AuthMetadataContext( diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 7276a7fd90..eb750ef1a8 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -48,7 +48,7 @@ _CANCELLED = 'cancelled' _EMPTY_FLAGS = 0 -_UNEXPECTED_EXIT_SERVER_GRACE = 1.0 +_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0 def _serialized_request(request_event): @@ -291,6 +291,10 @@ class _Context(grpc.ServicerContext): self._state.abortion = Exception() raise self._state.abortion + def abort_with_status(self, status): + self._state.trailing_metadata = status.trailing_metadata + self.abort(status.code, status.details) + def set_code(self, code): with self._state.condition: self._state.code = code @@ -480,43 +484,51 @@ def _status(rpc_event, state, serialized_response): def _unary_response_in_pool(rpc_event, state, behavior, argument_thunk, request_deserializer, response_serializer): - argument = argument_thunk() - if argument is not None: - response, proceed = _call_behavior(rpc_event, state, behavior, argument, - request_deserializer) - if proceed: - serialized_response = _serialize_response( - rpc_event, state, response, response_serializer) - if serialized_response is not None: - _status(rpc_event, state, serialized_response) + cygrpc.install_census_context_from_call(rpc_event.call) + try: + argument = argument_thunk() + if argument is not None: + response, proceed = _call_behavior(rpc_event, state, behavior, + argument, request_deserializer) + if proceed: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + _status(rpc_event, state, serialized_response) + finally: + cygrpc.uninstall_context() def _stream_response_in_pool(rpc_event, state, behavior, argument_thunk, request_deserializer, response_serializer): - argument = argument_thunk() - if argument is not None: - response_iterator, proceed = _call_behavior( - rpc_event, state, behavior, argument, request_deserializer) - if proceed: - while True: - response, proceed = _take_response_from_response_iterator( - rpc_event, state, response_iterator) - if proceed: - if response is None: - _status(rpc_event, state, None) - break - else: - serialized_response = _serialize_response( - rpc_event, state, response, response_serializer) - if serialized_response is not None: - proceed = _send_response(rpc_event, state, - serialized_response) - if not proceed: - break - else: + cygrpc.install_census_context_from_call(rpc_event.call) + try: + argument = argument_thunk() + if argument is not None: + response_iterator, proceed = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer) + if proceed: + while True: + response, proceed = _take_response_from_response_iterator( + rpc_event, state, response_iterator) + if proceed: + if response is None: + _status(rpc_event, state, None) break - else: - break + else: + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer) + if serialized_response is not None: + proceed = _send_response( + rpc_event, state, serialized_response) + if not proceed: + break + else: + break + else: + break + finally: + cygrpc.uninstall_context() def _handle_unary_unary(rpc_event, state, method_handler, thread_pool): @@ -664,6 +676,9 @@ class _ServerState(object): self.rpc_states = set() self.due = set() + # A "volatile" flag to interrupt the daemon serving thread + self.server_deallocated = False + def _add_generic_handlers(state, generic_handlers): with state.lock: @@ -690,6 +705,7 @@ def _request_call(state): # TODO(https://github.com/grpc/grpc/issues/6597): delete this function. def _stop_serving(state): if not state.rpc_states and not state.due: + state.server.destroy() for shutdown_event in state.shutdown_events: shutdown_event.set() state.stage = _ServerStage.STOPPED @@ -703,49 +719,69 @@ def _on_call_completed(state): state.active_rpc_count -= 1 -def _serve(state): - while True: - event = state.completion_queue.poll() - if event.tag is _SHUTDOWN_TAG: +def _process_event_and_continue(state, event): + should_continue = True + if event.tag is _SHUTDOWN_TAG: + with state.lock: + state.due.remove(_SHUTDOWN_TAG) + if _stop_serving(state): + should_continue = False + elif event.tag is _REQUEST_CALL_TAG: + with state.lock: + state.due.remove(_REQUEST_CALL_TAG) + concurrency_exceeded = ( + state.maximum_concurrent_rpcs is not None and + state.active_rpc_count >= state.maximum_concurrent_rpcs) + rpc_state, rpc_future = _handle_call( + event, state.generic_handlers, state.interceptor_pipeline, + state.thread_pool, concurrency_exceeded) + if rpc_state is not None: + state.rpc_states.add(rpc_state) + if rpc_future is not None: + state.active_rpc_count += 1 + rpc_future.add_done_callback( + lambda unused_future: _on_call_completed(state)) + if state.stage is _ServerStage.STARTED: + _request_call(state) + elif _stop_serving(state): + should_continue = False + else: + rpc_state, callbacks = event.tag(event) + for callback in callbacks: + callable_util.call_logging_exceptions(callback, + 'Exception calling callback!') + if rpc_state is not None: with state.lock: - state.due.remove(_SHUTDOWN_TAG) + state.rpc_states.remove(rpc_state) if _stop_serving(state): - return - elif event.tag is _REQUEST_CALL_TAG: - with state.lock: - state.due.remove(_REQUEST_CALL_TAG) - concurrency_exceeded = ( - state.maximum_concurrent_rpcs is not None and - state.active_rpc_count >= state.maximum_concurrent_rpcs) - rpc_state, rpc_future = _handle_call( - event, state.generic_handlers, state.interceptor_pipeline, - state.thread_pool, concurrency_exceeded) - if rpc_state is not None: - state.rpc_states.add(rpc_state) - if rpc_future is not None: - state.active_rpc_count += 1 - rpc_future.add_done_callback( - lambda unused_future: _on_call_completed(state)) - if state.stage is _ServerStage.STARTED: - _request_call(state) - elif _stop_serving(state): - return - else: - rpc_state, callbacks = event.tag(event) - for callback in callbacks: - callable_util.call_logging_exceptions( - callback, 'Exception calling callback!') - if rpc_state is not None: - with state.lock: - state.rpc_states.remove(rpc_state) - if _stop_serving(state): - return + should_continue = False + return should_continue + + +def _serve(state): + while True: + timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S + event = state.completion_queue.poll(timeout) + if state.server_deallocated: + _begin_shutdown_once(state) + if event.completion_type != cygrpc.CompletionType.queue_timeout: + if not _process_event_and_continue(state, event): + return # We want to force the deletion of the previous event # ~before~ we poll again; if the event has a reference # to a shutdown Call object, this can induce spinlock. event = None +def _begin_shutdown_once(state): + with state.lock: + if state.stage is _ServerStage.STARTED: + state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) + state.stage = _ServerStage.GRACE + state.shutdown_events = [] + state.due.add(_SHUTDOWN_TAG) + + def _stop(state, grace): with state.lock: if state.stage is _ServerStage.STOPPED: @@ -753,11 +789,7 @@ def _stop(state, grace): shutdown_event.set() return shutdown_event else: - if state.stage is _ServerStage.STARTED: - state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG) - state.stage = _ServerStage.GRACE - state.shutdown_events = [] - state.due.add(_SHUTDOWN_TAG) + _begin_shutdown_once(state) shutdown_event = threading.Event() state.shutdown_events.append(shutdown_event) if grace is None: @@ -828,7 +860,9 @@ class _Server(grpc.Server): return _stop(self._state, grace) def __del__(self): - _stop(self._state, None) + # We can not grab a lock in __del__(), so set a flag to signal the + # serving daemon thread (if it exists) to initiate shutdown. + self._state.server_deallocated = True def create_server(thread_pool, generic_rpc_handlers, interceptors, options, diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index d90b34bcbd..2938a38b44 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -132,15 +132,12 @@ class _ChannelReadyFuture(grpc.Future): def result(self, timeout=None): self._block(timeout) - return None def exception(self, timeout=None): self._block(timeout) - return None def traceback(self, timeout=None): self._block(timeout) - return None def add_done_callback(self, fn): with self._condition: diff --git a/src/python/grpcio/grpc/beta/BUILD.bazel b/src/python/grpcio/grpc/beta/BUILD.bazel deleted file mode 100644 index 731be5cb25..0000000000 --- a/src/python/grpcio/grpc/beta/BUILD.bazel +++ /dev/null @@ -1,58 +0,0 @@ -load("@grpc_python_dependencies//:requirements.bzl", "requirement") -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "beta", - srcs = ["__init__.py",], - deps = [ - ":client_adaptations", - ":metadata", - ":server_adaptations", - ":implementations", - ":interfaces", - ":utilities", - ], -) - -py_library( - name = "client_adaptations", - srcs = ["_client_adaptations.py"], - imports=["../../",] -) - -py_library( - name = "metadata", - srcs = ["_metadata.py"], -) - -py_library( - name = "server_adaptations", - srcs = ["_server_adaptations.py"], - imports=["../../",], -) - -py_library( - name = "implementations", - srcs = ["implementations.py"], - imports=["../../",], -) - -py_library( - name = "interfaces", - srcs = ["interfaces.py"], - deps = [ - requirement("six"), - ], - imports=["../../",], -) - -py_library( - name = "utilities", - srcs = ["utilities.py"], - deps = [ - ":implementations", - ":interfaces", - "//src/python/grpcio/grpc/framework/foundation", - ], -) - diff --git a/src/python/grpcio/grpc/framework/foundation/callable_util.py b/src/python/grpcio/grpc/framework/foundation/callable_util.py index 36066e19df..24daf3406f 100644 --- a/src/python/grpcio/grpc/framework/foundation/callable_util.py +++ b/src/python/grpcio/grpc/framework/foundation/callable_util.py @@ -22,7 +22,6 @@ import logging import six _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) class Outcome(six.with_metaclass(abc.ABCMeta)): diff --git a/src/python/grpcio/grpc/framework/foundation/logging_pool.py b/src/python/grpcio/grpc/framework/foundation/logging_pool.py index dfb8dbdc30..216e3990db 100644 --- a/src/python/grpcio/grpc/framework/foundation/logging_pool.py +++ b/src/python/grpcio/grpc/framework/foundation/logging_pool.py @@ -18,7 +18,6 @@ import logging from concurrent import futures _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) def _wrap(behavior): diff --git a/src/python/grpcio/grpc/framework/foundation/stream_util.py b/src/python/grpcio/grpc/framework/foundation/stream_util.py index e03130cced..1faaf29bd7 100644 --- a/src/python/grpcio/grpc/framework/foundation/stream_util.py +++ b/src/python/grpcio/grpc/framework/foundation/stream_util.py @@ -20,7 +20,6 @@ from grpc.framework.foundation import stream _NO_VALUE = object() _LOGGER = logging.getLogger(__name__) -_LOGGER.addHandler(logging.NullHandler()) class TransformingConsumer(stream.Consumer): diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index 42a4a7aa04..6a1fd676ca 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -215,6 +215,7 @@ CORE_SOURCE_FILES = [ 'src/core/ext/transport/chttp2/transport/bin_encoder.cc', 'src/core/ext/transport/chttp2/transport/chttp2_plugin.cc', 'src/core/ext/transport/chttp2/transport/chttp2_transport.cc', + 'src/core/ext/transport/chttp2/transport/context_list.cc', 'src/core/ext/transport/chttp2/transport/flow_control.cc', 'src/core/ext/transport/chttp2/transport/frame_data.cc', 'src/core/ext/transport/chttp2/transport/frame_goaway.cc', @@ -321,20 +322,20 @@ CORE_SOURCE_FILES = [ 'src/core/ext/filters/client_channel/http_connect_handshaker.cc', 'src/core/ext/filters/client_channel/http_proxy.cc', 'src/core/ext/filters/client_channel/lb_policy.cc', - 'src/core/ext/filters/client_channel/lb_policy_factory.cc', 'src/core/ext/filters/client_channel/lb_policy_registry.cc', - 'src/core/ext/filters/client_channel/method_params.cc', 'src/core/ext/filters/client_channel/parse_address.cc', 'src/core/ext/filters/client_channel/proxy_mapper.cc', 'src/core/ext/filters/client_channel/proxy_mapper_registry.cc', + 'src/core/ext/filters/client_channel/request_routing.cc', 'src/core/ext/filters/client_channel/resolver.cc', 'src/core/ext/filters/client_channel/resolver_registry.cc', + 'src/core/ext/filters/client_channel/resolver_result_parsing.cc', 'src/core/ext/filters/client_channel/retry_throttle.cc', + 'src/core/ext/filters/client_channel/server_address.cc', 'src/core/ext/filters/client_channel/subchannel.cc', 'src/core/ext/filters/client_channel/subchannel_index.cc', 'src/core/ext/filters/deadline/deadline_filter.cc', 'src/core/ext/filters/client_channel/health/health.pb.c', - 'src/core/tsi/alts_transport_security.cc', 'src/core/tsi/fake_transport_security.cc', 'src/core/tsi/local_transport_security.cc', 'src/core/tsi/ssl/session_cache/ssl_session_boringssl.cc', diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py index 71113e68d9..2e91818d2c 100644 --- a/src/python/grpcio/grpc_version.py +++ b/src/python/grpcio/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_channelz/.gitignore b/src/python/grpcio_channelz/.gitignore new file mode 100644 index 0000000000..0c5da6b5af --- /dev/null +++ b/src/python/grpcio_channelz/.gitignore @@ -0,0 +1,6 @@ +*.proto +*_pb2.py +*_pb2_grpc.py +build/ +grpcio_channelz.egg-info/ +dist/ diff --git a/src/python/grpcio_channelz/MANIFEST.in b/src/python/grpcio_channelz/MANIFEST.in new file mode 100644 index 0000000000..ee93e21a69 --- /dev/null +++ b/src/python/grpcio_channelz/MANIFEST.in @@ -0,0 +1,4 @@ +include grpc_version.py +recursive-include grpc_channelz *.py +global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_channelz/README.rst b/src/python/grpcio_channelz/README.rst new file mode 100644 index 0000000000..efeaa56064 --- /dev/null +++ b/src/python/grpcio_channelz/README.rst @@ -0,0 +1,9 @@ +gRPC Python Channelz package +============================== + +Channelz is a live debug tool in gRPC Python. + +Dependencies +------------ + +Depends on the `grpcio` package, available from PyPI via `pip install grpcio`. diff --git a/src/python/grpcio_channelz/channelz_commands.py b/src/python/grpcio_channelz/channelz_commands.py new file mode 100644 index 0000000000..7f158c2a4b --- /dev/null +++ b/src/python/grpcio_channelz/channelz_commands.py @@ -0,0 +1,67 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides distutils command classes for the GRPC Python setup process.""" + +import os +import shutil + +import setuptools + +ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) +CHANNELZ_PROTO = os.path.join(ROOT_DIR, + '../../proto/grpc/channelz/channelz.proto') +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + + +class Preprocess(setuptools.Command): + """Command to copy proto modules from grpc/src/proto and LICENSE from + the root directory""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + if os.path.isfile(CHANNELZ_PROTO): + shutil.copyfile(CHANNELZ_PROTO, + os.path.join(ROOT_DIR, + 'grpc_channelz/v1/channelz.proto')) + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + + +class BuildPackageProtos(setuptools.Command): + """Command to generate project *_pb2.py modules from proto files.""" + + description = 'build grpc protobuf modules' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + # due to limitations of the proto generator, we require that only *one* + # directory is provided as an 'include' directory. We assume it's the '' key + # to `self.distribution.package_dir` (and get a key error if it's not + # there). + from grpc_tools import command + command.build_package_protos(self.distribution.package_dir['']) diff --git a/src/python/grpcio_channelz/grpc_channelz/__init__.py b/src/python/grpcio_channelz/grpc_channelz/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_channelz/grpc_channelz/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/BUILD.bazel b/src/python/grpcio_channelz/grpc_channelz/v1/BUILD.bazel new file mode 100644 index 0000000000..aae8cedb76 --- /dev/null +++ b/src/python/grpcio_channelz/grpc_channelz/v1/BUILD.bazel @@ -0,0 +1,38 @@ +load("@grpc_python_dependencies//:requirements.bzl", "requirement") +load("@org_pubref_rules_protobuf//python:rules.bzl", "py_proto_library") + +package(default_visibility = ["//visibility:public"]) + +genrule( + name = "mv_channelz_proto", + srcs = [ + "//src/proto/grpc/channelz:channelz_proto_file", + ], + outs = ["channelz.proto",], + cmd = "cp $< $@", +) + +py_proto_library( + name = "py_channelz_proto", + protos = ["mv_channelz_proto",], + imports = [ + "external/com_google_protobuf/src/", + ], + inputs = [ + "@com_google_protobuf//:well_known_protos", + ], + with_grpc = True, + deps = [ + requirement('protobuf'), + ], +) + +py_library( + name = "grpc_channelz", + srcs = ["channelz.py",], + deps = [ + ":py_channelz_proto", + "//src/python/grpcio/grpc:grpcio", + ], + imports=["../../",], +) diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/__init__.py b/src/python/grpcio_channelz/grpc_channelz/v1/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_channelz/grpc_channelz/v1/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py b/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py new file mode 100644 index 0000000000..00eca311dc --- /dev/null +++ b/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py @@ -0,0 +1,142 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Channelz debug service implementation in gRPC Python.""" + +import grpc +from grpc._cython import cygrpc + +import grpc_channelz.v1.channelz_pb2 as _channelz_pb2 +import grpc_channelz.v1.channelz_pb2_grpc as _channelz_pb2_grpc + +from google.protobuf import json_format + + +class ChannelzServicer(_channelz_pb2_grpc.ChannelzServicer): + """Servicer handling RPCs for service statuses.""" + + @staticmethod + def GetTopChannels(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_top_channels(request.start_channel_id), + _channelz_pb2.GetTopChannelsResponse(), + ) + except (ValueError, json_format.ParseError) as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetServers(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_servers(request.start_server_id), + _channelz_pb2.GetServersResponse(), + ) + except (ValueError, json_format.ParseError) as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetServer(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_server(request.server_id), + _channelz_pb2.GetServerResponse(), + ) + except ValueError as e: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(str(e)) + except json_format.ParseError as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetServerSockets(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_server_sockets(request.server_id, + request.start_socket_id, + request.max_results), + _channelz_pb2.GetServerSocketsResponse(), + ) + except ValueError as e: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(str(e)) + except json_format.ParseError as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetChannel(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_channel(request.channel_id), + _channelz_pb2.GetChannelResponse(), + ) + except ValueError as e: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(str(e)) + except json_format.ParseError as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetSubchannel(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_subchannel(request.subchannel_id), + _channelz_pb2.GetSubchannelResponse(), + ) + except ValueError as e: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(str(e)) + except json_format.ParseError as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + @staticmethod + def GetSocket(request, context): + try: + return json_format.Parse( + cygrpc.channelz_get_socket(request.socket_id), + _channelz_pb2.GetSocketResponse(), + ) + except ValueError as e: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(str(e)) + except json_format.ParseError as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + + +def add_channelz_servicer(server): + """Add Channelz servicer to a server. Channelz servicer is in charge of + pulling information from C-Core for entire process. It will allow the + server to response to Channelz queries. + + The Channelz statistic is enabled by default inside C-Core. Whether the + statistic is enabled or not is isolated from adding Channelz servicer. + That means you can query Channelz info with a Channelz-disabled channel, + and you can add Channelz servicer to a Channelz-disabled server. + + The Channelz statistic can be enabled or disabled by channel option + 'grpc.enable_channelz'. Set to 1 to enable, set to 0 to disable. + + This is an EXPERIMENTAL API. + + Args: + server: grpc.Server to which Channelz service will be added. + """ + _channelz_pb2_grpc.add_ChannelzServicer_to_server(ChannelzServicer(), + server) diff --git a/src/python/grpcio_channelz/grpc_version.py b/src/python/grpcio_channelz/grpc_version.py new file mode 100644 index 0000000000..16356ea402 --- /dev/null +++ b/src/python/grpcio_channelz/grpc_version.py @@ -0,0 +1,17 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_channelz/grpc_version.py.template`!!! + +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_channelz/setup.py b/src/python/grpcio_channelz/setup.py new file mode 100644 index 0000000000..f8c0e93913 --- /dev/null +++ b/src/python/grpcio_channelz/setup.py @@ -0,0 +1,96 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup module for the GRPC Python package's Channelz.""" + +import os +import sys + +import setuptools + +# Ensure we're in the proper directory whether or not we're being used by pip. +os.chdir(os.path.dirname(os.path.abspath(__file__))) + +# Break import-style to ensure we can actually find our local modules. +import grpc_version + + +class _NoOpCommand(setuptools.Command): + """No-op command.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + pass + + +CLASSIFIERS = [ + 'Development Status :: 5 - Production/Stable', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'License :: OSI Approved :: Apache Software License', +] + +PACKAGE_DIRECTORIES = { + '': '.', +} + +INSTALL_REQUIRES = ( + 'protobuf>=3.6.0', + 'grpcio>={version}'.format(version=grpc_version.VERSION), +) + +try: + import channelz_commands as _channelz_commands + # we are in the build environment, otherwise the above import fails + SETUP_REQUIRES = ( + 'grpcio-tools=={version}'.format(version=grpc_version.VERSION),) + COMMAND_CLASS = { + # Run preprocess from the repository *before* doing any packaging! + 'preprocess': _channelz_commands.Preprocess, + 'build_package_protos': _channelz_commands.BuildPackageProtos, + } +except ImportError: + SETUP_REQUIRES = () + COMMAND_CLASS = { + # wire up commands to no-op not to break the external dependencies + 'preprocess': _NoOpCommand, + 'build_package_protos': _NoOpCommand, + } + +setuptools.setup( + name='grpcio-channelz', + version=grpc_version.VERSION, + license='Apache License 2.0', + description='Channel Level Live Debug Information Service for gRPC', + author='The gRPC Authors', + author_email='grpc-io@googlegroups.com', + classifiers=CLASSIFIERS, + url='https://grpc.io', + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages('.'), + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, + cmdclass=COMMAND_CLASS) diff --git a/src/python/grpcio_health_checking/MANIFEST.in b/src/python/grpcio_health_checking/MANIFEST.in index 996c74a9d4..3a22311b8e 100644 --- a/src/python/grpcio_health_checking/MANIFEST.in +++ b/src/python/grpcio_health_checking/MANIFEST.in @@ -1,3 +1,4 @@ include grpc_version.py recursive-include grpc_health *.py global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_health_checking/grpc_version.py b/src/python/grpcio_health_checking/grpc_version.py index a30aac2e0b..85fa762f7e 100644 --- a/src/python/grpcio_health_checking/grpc_version.py +++ b/src/python/grpcio_health_checking/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_health_checking/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_health_checking/health_commands.py b/src/python/grpcio_health_checking/health_commands.py index 933f965aa2..3820ef0bba 100644 --- a/src/python/grpcio_health_checking/health_commands.py +++ b/src/python/grpcio_health_checking/health_commands.py @@ -20,10 +20,12 @@ import setuptools ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) HEALTH_PROTO = os.path.join(ROOT_DIR, '../../proto/grpc/health/v1/health.proto') +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') -class CopyProtoModules(setuptools.Command): - """Command to copy proto modules from grpc/src/proto.""" +class Preprocess(setuptools.Command): + """Command to copy proto modules from grpc/src/proto and LICENSE from + the root directory""" description = '' user_options = [] @@ -39,6 +41,8 @@ class CopyProtoModules(setuptools.Command): shutil.copyfile(HEALTH_PROTO, os.path.join(ROOT_DIR, 'grpc_health/v1/health.proto')) + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) class BuildPackageProtos(setuptools.Command): diff --git a/src/python/grpcio_health_checking/setup.py b/src/python/grpcio_health_checking/setup.py index db2edae2ce..5a09a80f6a 100644 --- a/src/python/grpcio_health_checking/setup.py +++ b/src/python/grpcio_health_checking/setup.py @@ -68,7 +68,7 @@ try: 'grpcio-tools=={version}'.format(version=grpc_version.VERSION),) COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! - 'preprocess': _health_commands.CopyProtoModules, + 'preprocess': _health_commands.Preprocess, 'build_package_protos': _health_commands.BuildPackageProtos, } except ImportError: diff --git a/src/python/grpcio_reflection/MANIFEST.in b/src/python/grpcio_reflection/MANIFEST.in index d6fb6ce73a..10b01fa41d 100644 --- a/src/python/grpcio_reflection/MANIFEST.in +++ b/src/python/grpcio_reflection/MANIFEST.in @@ -1,3 +1,4 @@ include grpc_version.py recursive-include grpc_reflection *.py global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_reflection/grpc_version.py b/src/python/grpcio_reflection/grpc_version.py index aafea9fe76..e62ab169a2 100644 --- a/src/python/grpcio_reflection/grpc_version.py +++ b/src/python/grpcio_reflection/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_reflection/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_reflection/reflection_commands.py b/src/python/grpcio_reflection/reflection_commands.py index 6f91f6b875..311ca4c4db 100644 --- a/src/python/grpcio_reflection/reflection_commands.py +++ b/src/python/grpcio_reflection/reflection_commands.py @@ -21,10 +21,12 @@ import setuptools ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) REFLECTION_PROTO = os.path.join( ROOT_DIR, '../../proto/grpc/reflection/v1alpha/reflection.proto') +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') -class CopyProtoModules(setuptools.Command): - """Command to copy proto modules from grpc/src/proto.""" +class Preprocess(setuptools.Command): + """Command to copy proto modules from grpc/src/proto and LICENSE from + the root directory""" description = '' user_options = [] @@ -41,6 +43,8 @@ class CopyProtoModules(setuptools.Command): REFLECTION_PROTO, os.path.join(ROOT_DIR, 'grpc_reflection/v1alpha/reflection.proto')) + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) class BuildPackageProtos(setuptools.Command): diff --git a/src/python/grpcio_reflection/setup.py b/src/python/grpcio_reflection/setup.py index b4087d87b4..f205069acd 100644 --- a/src/python/grpcio_reflection/setup.py +++ b/src/python/grpcio_reflection/setup.py @@ -69,7 +69,7 @@ try: 'grpcio-tools=={version}'.format(version=grpc_version.VERSION),) COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! - 'preprocess': _reflection_commands.CopyProtoModules, + 'preprocess': _reflection_commands.Preprocess, 'build_package_protos': _reflection_commands.BuildPackageProtos, } except ImportError: diff --git a/src/python/grpcio_status/.gitignore b/src/python/grpcio_status/.gitignore new file mode 100644 index 0000000000..19d1523efd --- /dev/null +++ b/src/python/grpcio_status/.gitignore @@ -0,0 +1,3 @@ +build/ +grpcio_status.egg-info/ +dist/ diff --git a/src/python/grpcio_status/MANIFEST.in b/src/python/grpcio_status/MANIFEST.in new file mode 100644 index 0000000000..09b8ea721e --- /dev/null +++ b/src/python/grpcio_status/MANIFEST.in @@ -0,0 +1,4 @@ +include grpc_version.py +recursive-include grpc_status *.py +global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_status/README.rst b/src/python/grpcio_status/README.rst new file mode 100644 index 0000000000..dc2f7b1dab --- /dev/null +++ b/src/python/grpcio_status/README.rst @@ -0,0 +1,9 @@ +gRPC Python Status Proto +=========================== + +Reference package for GRPC Python status proto mapping. + +Dependencies +------------ + +Depends on the `grpcio` package, available from PyPI via `pip install grpcio`. diff --git a/src/python/grpcio_status/grpc_status/BUILD.bazel b/src/python/grpcio_status/grpc_status/BUILD.bazel new file mode 100644 index 0000000000..223a077c3f --- /dev/null +++ b/src/python/grpcio_status/grpc_status/BUILD.bazel @@ -0,0 +1,14 @@ +load("@grpc_python_dependencies//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "grpc_status", + srcs = ["rpc_status.py",], + deps = [ + "//src/python/grpcio/grpc:grpcio", + requirement('protobuf'), + requirement('googleapis-common-protos'), + ], + imports=["../",], +) diff --git a/src/python/grpcio_status/grpc_status/__init__.py b/src/python/grpcio_status/grpc_status/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_status/grpc_status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_status/grpc_status/rpc_status.py b/src/python/grpcio_status/grpc_status/rpc_status.py new file mode 100644 index 0000000000..87618fa541 --- /dev/null +++ b/src/python/grpcio_status/grpc_status/rpc_status.py @@ -0,0 +1,92 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Reference implementation for status mapping in gRPC Python.""" + +import collections + +import grpc + +# TODO(https://github.com/bazelbuild/bazel/issues/6844) +# Due to Bazel issue, the namespace packages won't resolve correctly. +# Adding this unused-import as a workaround to avoid module-not-found error +# under Bazel builds. +import google.protobuf # pylint: disable=unused-import +from google.rpc import status_pb2 + +_CODE_TO_GRPC_CODE_MAPPING = {x.value[0]: x for x in grpc.StatusCode} + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + + +class _Status( + collections.namedtuple( + '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status): + pass + + +def _code_to_grpc_status_code(code): + try: + return _CODE_TO_GRPC_CODE_MAPPING[code] + except KeyError: + raise ValueError('Invalid status code %s' % code) + + +def from_call(call): + """Returns a google.rpc.status.Status message corresponding to a given grpc.Call. + + This is an EXPERIMENTAL API. + + Args: + call: A grpc.Call instance. + + Returns: + A google.rpc.status.Status message representing the status of the RPC. + + Raises: + ValueError: If the gRPC call's code or details are inconsistent with the + status code and message inside of the google.rpc.status.Status. + """ + for key, value in call.trailing_metadata(): + if key == _GRPC_DETAILS_METADATA_KEY: + rich_status = status_pb2.Status.FromString(value) + if call.code().value[0] != rich_status.code: + raise ValueError( + 'Code in Status proto (%s) doesn\'t match status code (%s)' + % (_code_to_grpc_status_code(rich_status.code), + call.code())) + if call.details() != rich_status.message: + raise ValueError( + 'Message in Status proto (%s) doesn\'t match status details (%s)' + % (rich_status.message, call.details())) + return rich_status + return None + + +def to_status(status): + """Convert a google.rpc.status.Status message to grpc.Status. + + This is an EXPERIMENTAL API. + + Args: + status: a google.rpc.status.Status message representing the non-OK status + to terminate the RPC with and communicate it to the client. + + Returns: + A grpc.Status instance representing the input google.rpc.status.Status message. + """ + return _Status( + code=_code_to_grpc_status_code(status.code), + details=status.message, + trailing_metadata=((_GRPC_DETAILS_METADATA_KEY, + status.SerializeToString()),)) diff --git a/src/python/grpcio_status/grpc_version.py b/src/python/grpcio_status/grpc_version.py new file mode 100644 index 0000000000..e009843b94 --- /dev/null +++ b/src/python/grpcio_status/grpc_version.py @@ -0,0 +1,17 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_status/grpc_version.py.template`!!! + +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_status/setup.py b/src/python/grpcio_status/setup.py new file mode 100644 index 0000000000..983d3ea430 --- /dev/null +++ b/src/python/grpcio_status/setup.py @@ -0,0 +1,93 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Setup module for the GRPC Python package's status mapping.""" + +import os + +import setuptools + +# Ensure we're in the proper directory whether or not we're being used by pip. +os.chdir(os.path.dirname(os.path.abspath(__file__))) + +# Break import-style to ensure we can actually find our local modules. +import grpc_version + + +class _NoOpCommand(setuptools.Command): + """No-op command.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + pass + + +CLASSIFIERS = [ + 'Development Status :: 5 - Production/Stable', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'License :: OSI Approved :: Apache Software License', +] + +PACKAGE_DIRECTORIES = { + '': '.', +} + +INSTALL_REQUIRES = ( + 'protobuf>=3.6.0', + 'grpcio>={version}'.format(version=grpc_version.VERSION), + 'googleapis-common-protos>=1.5.5', +) + +try: + import status_commands as _status_commands + # we are in the build environment, otherwise the above import fails + COMMAND_CLASS = { + # Run preprocess from the repository *before* doing any packaging! + 'preprocess': _status_commands.Preprocess, + 'build_package_protos': _NoOpCommand, + } +except ImportError: + COMMAND_CLASS = { + # wire up commands to no-op not to break the external dependencies + 'preprocess': _NoOpCommand, + 'build_package_protos': _NoOpCommand, + } + +setuptools.setup( + name='grpcio-status', + version=grpc_version.VERSION, + description='Status proto mapping for gRPC', + author='The gRPC Authors', + author_email='grpc-io@googlegroups.com', + url='https://grpc.io', + license='Apache License 2.0', + classifiers=CLASSIFIERS, + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages('.'), + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS) diff --git a/src/python/grpcio_status/status_commands.py b/src/python/grpcio_status/status_commands.py new file mode 100644 index 0000000000..78cd497f62 --- /dev/null +++ b/src/python/grpcio_status/status_commands.py @@ -0,0 +1,39 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides distutils command classes for the GRPC Python setup process.""" + +import os +import shutil + +import setuptools + +ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + + +class Preprocess(setuptools.Command): + """Command to copy LICENSE from root directory.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) diff --git a/src/python/grpcio_testing/MANIFEST.in b/src/python/grpcio_testing/MANIFEST.in index 39b3565217..559dfaf786 100644 --- a/src/python/grpcio_testing/MANIFEST.in +++ b/src/python/grpcio_testing/MANIFEST.in @@ -1,3 +1,4 @@ include grpc_version.py recursive-include grpc_testing *.py global-exclude *.pyc +include LICENSE diff --git a/src/python/grpcio_testing/grpc_testing/_server/_handler.py b/src/python/grpcio_testing/grpc_testing/_server/_handler.py index 0e3404b0d0..100d8195f6 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_handler.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_handler.py @@ -185,7 +185,7 @@ class _Handler(Handler): elif self._code is None: self._condition.wait() else: - return self._trailing_metadata, self._code, self._details, + return self._trailing_metadata, self._code, self._details def expire(self): with self._condition: diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py index 90eeb130d3..5b1dfeacdf 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py @@ -70,6 +70,9 @@ class ServicerContext(grpc.ServicerContext): def abort(self, code, details): raise NotImplementedError() + def abort_with_status(self, status): + raise NotImplementedError() + def set_code(self, code): self._rpc.set_code(code) diff --git a/src/python/grpcio_testing/grpc_version.py b/src/python/grpcio_testing/grpc_version.py index 876acd3142..7b4c1695fa 100644 --- a/src/python/grpcio_testing/grpc_version.py +++ b/src/python/grpcio_testing/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_testing/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_testing/setup.py b/src/python/grpcio_testing/setup.py index 6ceb1fc5c9..18db71e0f0 100644 --- a/src/python/grpcio_testing/setup.py +++ b/src/python/grpcio_testing/setup.py @@ -24,6 +24,23 @@ os.chdir(os.path.dirname(os.path.abspath(__file__))) # Break import style to ensure that we can find same-directory modules. import grpc_version + +class _NoOpCommand(setuptools.Command): + """No-op command.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + pass + + PACKAGE_DIRECTORIES = { '': '.', } @@ -33,6 +50,19 @@ INSTALL_REQUIRES = ( 'grpcio>={version}'.format(version=grpc_version.VERSION), ) +try: + import testing_commands as _testing_commands + # we are in the build environment, otherwise the above import fails + COMMAND_CLASS = { + # Run preprocess from the repository *before* doing any packaging! + 'preprocess': _testing_commands.Preprocess, + } +except ImportError: + COMMAND_CLASS = { + # wire up commands to no-op not to break the external dependencies + 'preprocess': _NoOpCommand, + } + setuptools.setup( name='grpcio-testing', version=grpc_version.VERSION, @@ -43,4 +73,5 @@ setuptools.setup( url='https://grpc.io', package_dir=PACKAGE_DIRECTORIES, packages=setuptools.find_packages('.'), - install_requires=INSTALL_REQUIRES) + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS) diff --git a/src/python/grpcio_testing/testing_commands.py b/src/python/grpcio_testing/testing_commands.py new file mode 100644 index 0000000000..fb40d37efb --- /dev/null +++ b/src/python/grpcio_testing/testing_commands.py @@ -0,0 +1,39 @@ +# Copyright 2018 gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Provides distutils command classes for the GRPC Python setup process.""" + +import os +import shutil + +import setuptools + +ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) +LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + + +class Preprocess(setuptools.Command): + """Command to copy LICENSE from root directory.""" + + description = '' + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + if os.path.isfile(LICENSE): + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index d163f6fb68..d5327711d3 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -22,7 +22,6 @@ import re import shutil import subprocess import sys -import traceback import setuptools from setuptools.command import build_ext @@ -132,7 +131,21 @@ class TestGevent(setuptools.Command): 'unit.beta._beta_features_test', # TODO(https://github.com/grpc/grpc/issues/15411) unpin gevent version # This test will stuck while running higher version of gevent - 'unit._auth_context_test.AuthContextTest.testSessionResumption') + 'unit._auth_context_test.AuthContextTest.testSessionResumption', + # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests + 'unit._metadata_flags_test', + 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_stream_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call', + 'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call', + # TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels', + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets', + 'channelz._channelz_servicer_test.ChannelzServicerTest.test_streaming_rpc' + ) description = 'run tests with gevent. Assumes grpc/gevent are installed' user_options = [] diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py index cc9b41587c..2fcd1ad617 100644 --- a/src/python/grpcio_tests/grpc_version.py +++ b/src/python/grpcio_tests/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py index 61c98fa038..800b865da6 100644 --- a/src/python/grpcio_tests/setup.py +++ b/src/python/grpcio_tests/setup.py @@ -37,11 +37,19 @@ PACKAGE_DIRECTORIES = { } INSTALL_REQUIRES = ( - 'coverage>=4.0', 'enum34>=1.0.4', + 'coverage>=4.0', + 'enum34>=1.0.4', 'grpcio>={version}'.format(version=grpc_version.VERSION), + # TODO(https://github.com/pypa/warehouse/issues/5196) + # Re-enable it once we got the name back + # 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), + 'grpcio-status>={version}'.format(version=grpc_version.VERSION), 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), 'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION), - 'oauth2client>=1.4.7', 'protobuf>=3.6.0', 'six>=1.10', 'google-auth>=1.0.0', + 'oauth2client>=1.4.7', + 'protobuf>=3.6.0', + 'six>=1.10', + 'google-auth>=1.0.0', 'requests>=2.14.2') if not PY3: diff --git a/src/python/grpcio_tests/tests/_runner.py b/src/python/grpcio_tests/tests/_runner.py index eaaa027e61..9ef0f17684 100644 --- a/src/python/grpcio_tests/tests/_runner.py +++ b/src/python/grpcio_tests/tests/_runner.py @@ -203,7 +203,7 @@ class Runner(object): check_kill_self() time.sleep(0) case_thread.join() - except: + except: # pylint: disable=try-except-raise # re-raise the exception after forcing the with-block to end raise result.set_output(augmented_case.case, stdout_pipe.output(), diff --git a/src/python/grpcio_tests/tests/channelz/BUILD.bazel b/src/python/grpcio_tests/tests/channelz/BUILD.bazel new file mode 100644 index 0000000000..63513616e7 --- /dev/null +++ b/src/python/grpcio_tests/tests/channelz/BUILD.bazel @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "channelz_servicer_test", + srcs = ["_channelz_servicer_test.py"], + main = "_channelz_servicer_test.py", + size = "small", + deps = [ + "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_channelz/grpc_channelz/v1:grpc_channelz", + "//src/python/grpcio_tests/tests/unit:test_common", + "//src/python/grpcio_tests/tests/unit/framework/common:common", + ], + imports = ["../../",], +) diff --git a/src/python/grpcio_tests/tests/channelz/__init__.py b/src/python/grpcio_tests/tests/channelz/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_tests/tests/channelz/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py new file mode 100644 index 0000000000..c63ff5cd84 --- /dev/null +++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -0,0 +1,469 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_channelz.v1.channelz.""" + +import unittest + +from concurrent import futures + +import grpc +from grpc_channelz.v1 import channelz +from grpc_channelz.v1 import channelz_pb2 +from grpc_channelz.v1 import channelz_pb2_grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_SUCCESSFUL_UNARY_UNARY = '/test/SuccessfulUnaryUnary' +_FAILED_UNARY_UNARY = '/test/FailedUnaryUnary' +_SUCCESSFUL_STREAM_STREAM = '/test/SuccessfulStreamStream' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_DISABLE_REUSE_PORT = (('grpc.so_reuseport', 0),) +_ENABLE_CHANNELZ = (('grpc.enable_channelz', 1),) +_DISABLE_CHANNELZ = (('grpc.enable_channelz', 0),) + + +def _successful_unary_unary(request, servicer_context): + return _RESPONSE + + +def _failed_unary_unary(request, servicer_context): + servicer_context.set_code(grpc.StatusCode.INTERNAL) + servicer_context.set_details("Channelz Test Intended Failure") + + +def _successful_stream_stream(request_iterator, servicer_context): + for _ in request_iterator: + yield _RESPONSE + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _SUCCESSFUL_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_successful_unary_unary) + elif handler_call_details.method == _FAILED_UNARY_UNARY: + return grpc.unary_unary_rpc_method_handler(_failed_unary_unary) + elif handler_call_details.method == _SUCCESSFUL_STREAM_STREAM: + return grpc.stream_stream_rpc_method_handler( + _successful_stream_stream) + else: + return None + + +class _ChannelServerPair(object): + + def __init__(self): + # Server will enable channelz service + self.server = grpc.server( + futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + _ENABLE_CHANNELZ) + port = self.server.add_insecure_port('[::]:0') + self.server.add_generic_rpc_handlers((_GenericHandler(),)) + self.server.start() + + # Channel will enable channelz service... + self.channel = grpc.insecure_channel('localhost:%d' % port, + _ENABLE_CHANNELZ) + + +def _generate_channel_server_pairs(n): + return [_ChannelServerPair() for i in range(n)] + + +def _close_channel_server_pairs(pairs): + for pair in pairs: + pair.server.stop(None) + pair.channel.close() + + +@unittest.skip('https://github.com/pypa/warehouse/issues/5196') +class ChannelzServicerTest(unittest.TestCase): + + def _send_successful_unary_unary(self, idx): + _, r = self._pairs[idx].channel.unary_unary( + _SUCCESSFUL_UNARY_UNARY).with_call(_REQUEST) + self.assertEqual(r.code(), grpc.StatusCode.OK) + + def _send_failed_unary_unary(self, idx): + try: + self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( + _REQUEST) + except grpc.RpcError: + return + else: + self.fail("This call supposed to fail") + + def _send_successful_stream_stream(self, idx): + response_iterator = self._pairs[idx].channel.stream_stream( + _SUCCESSFUL_STREAM_STREAM).__call__( + iter([_REQUEST] * test_constants.STREAM_LENGTH)) + cnt = 0 + for _ in response_iterator: + cnt += 1 + self.assertEqual(cnt, test_constants.STREAM_LENGTH) + + def _get_channel_id(self, idx): + """Channel id may not be consecutive""" + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertGreater(len(resp.channel), idx) + return resp.channel[idx].ref.channel_id + + def setUp(self): + self._pairs = [] + # This server is for Channelz info fetching only + # It self should not enable Channelz + self._server = grpc.server( + futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + _DISABLE_CHANNELZ) + port = self._server.add_insecure_port('[::]:0') + channelz.add_channelz_servicer(self._server) + self._server.start() + + # This channel is used to fetch Channelz info only + # Channelz should not be enabled + self._channel = grpc.insecure_channel('localhost:%d' % port, + _DISABLE_CHANNELZ) + self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + _close_channel_server_pairs(self._pairs) + + def test_get_top_channels_basic(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(resp.channel), 1) + self.assertEqual(resp.end, True) + + def test_get_top_channels_high_start_id(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=10000)) + self.assertEqual(len(resp.channel), 0) + self.assertEqual(resp.end, True) + + def test_successful_request(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_successful_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 1) + self.assertEqual(resp.channel.data.calls_failed, 0) + + def test_failed_request(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_failed_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, 1) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 1) + + def test_many_requests(self): + self._pairs = _generate_channel_server_pairs(1) + k_success = 7 + k_failed = 9 + for i in range(k_success): + self._send_successful_unary_unary(0) + for i in range(k_failed): + self._send_failed_unary_unary(0) + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + def test_many_channel(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(resp.channel), k_channels) + + def test_many_requests_many_channel(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 11 + k_failed = 13 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + # The first channel saw only successes + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(resp.channel.data.calls_started, k_success) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, 0) + + # The second channel saw only failures + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(1))) + self.assertEqual(resp.channel.data.calls_started, k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The third channel saw both successes and failures + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(2))) + self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) + self.assertEqual(resp.channel.data.calls_succeeded, k_success) + self.assertEqual(resp.channel.data.calls_failed, k_failed) + + # The fourth channel saw nothing + resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(3))) + self.assertEqual(resp.channel.data.calls_started, 0) + self.assertEqual(resp.channel.data.calls_succeeded, 0) + self.assertEqual(resp.channel.data.calls_failed, 0) + + def test_many_subchannels(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 17 + k_failed = 19 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + gtc_resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(gtc_resp.channel), k_channels) + for i in range(k_channels): + # If no call performed in the channel, there shouldn't be any subchannel + if gtc_resp.channel[i].data.calls_started == 0: + self.assertEqual(len(gtc_resp.channel[i].subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gtc_resp.channel[i].subchannel_ref[ + 0].subchannel_id)) + self.assertEqual(gtc_resp.channel[i].data.calls_started, + gsc_resp.subchannel.data.calls_started) + self.assertEqual(gtc_resp.channel[i].data.calls_succeeded, + gsc_resp.subchannel.data.calls_succeeded) + self.assertEqual(gtc_resp.channel[i].data.calls_failed, + gsc_resp.subchannel.data.calls_failed) + + def test_server_basic(self): + self._pairs = _generate_channel_server_pairs(1) + resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(resp.server), 1) + + def test_get_one_server(self): + self._pairs = _generate_channel_server_pairs(1) + gss_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gss_resp.server), 1) + gs_resp = self._channelz_stub.GetServer( + channelz_pb2.GetServerRequest( + server_id=gss_resp.server[0].ref.server_id)) + self.assertEqual(gss_resp.server[0].ref.server_id, + gs_resp.server.ref.server_id) + + def test_server_call(self): + self._pairs = _generate_channel_server_pairs(1) + k_success = 23 + k_failed = 29 + for i in range(k_success): + self._send_successful_unary_unary(0) + for i in range(k_failed): + self._send_failed_unary_unary(0) + + resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(resp.server), 1) + self.assertEqual(resp.server[0].data.calls_started, + k_success + k_failed) + self.assertEqual(resp.server[0].data.calls_succeeded, k_success) + self.assertEqual(resp.server[0].data.calls_failed, k_failed) + + def test_many_subchannels_and_sockets(self): + k_channels = 4 + self._pairs = _generate_channel_server_pairs(k_channels) + k_success = 3 + k_failed = 5 + for i in range(k_success): + self._send_successful_unary_unary(0) + self._send_successful_unary_unary(2) + for i in range(k_failed): + self._send_failed_unary_unary(1) + self._send_failed_unary_unary(2) + + gtc_resp = self._channelz_stub.GetTopChannels( + channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + self.assertEqual(len(gtc_resp.channel), k_channels) + for i in range(k_channels): + # If no call performed in the channel, there shouldn't be any subchannel + if gtc_resp.channel[i].data.calls_started == 0: + self.assertEqual(len(gtc_resp.channel[i].subchannel_ref), 0) + continue + + # Otherwise, the subchannel should exist + self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gtc_resp.channel[i].subchannel_ref[ + 0].subchannel_id)) + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_started) + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_succeeded) + # Calls started == messages sent, only valid for unary calls + self.assertEqual(gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.messages_sent) + # Only receive responses when the RPC was successful + self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, + gs_resp.socket.data.messages_received) + + def test_streaming_rpc(self): + self._pairs = _generate_channel_server_pairs(1) + # In C++, the argument for _send_successful_stream_stream is message length. + # Here the argument is still channel idx, to be consistent with the other two. + self._send_successful_stream_stream(0) + + gc_resp = self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + self.assertEqual(gc_resp.channel.data.calls_started, 1) + self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) + self.assertEqual(gc_resp.channel.data.calls_failed, 0) + # Subchannel exists + self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) + + gsc_resp = self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest( + subchannel_id=gc_resp.channel.subchannel_ref[0].subchannel_id)) + self.assertEqual(gsc_resp.subchannel.data.calls_started, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, 1) + self.assertEqual(gsc_resp.subchannel.data.calls_failed, 0) + # Socket exists + self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) + self.assertEqual(gs_resp.socket.data.streams_started, 1) + self.assertEqual(gs_resp.socket.data.streams_succeeded, 1) + self.assertEqual(gs_resp.socket.data.streams_failed, 0) + self.assertEqual(gs_resp.socket.data.messages_sent, + test_constants.STREAM_LENGTH) + self.assertEqual(gs_resp.socket.data.messages_received, + test_constants.STREAM_LENGTH) + + def test_server_sockets(self): + self._pairs = _generate_channel_server_pairs(1) + self._send_successful_unary_unary(0) + self._send_failed_unary_unary(0) + + gs_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gs_resp.server), 1) + self.assertEqual(gs_resp.server[0].data.calls_started, 2) + self.assertEqual(gs_resp.server[0].data.calls_succeeded, 1) + self.assertEqual(gs_resp.server[0].data.calls_failed, 1) + + gss_resp = self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest( + server_id=gs_resp.server[0].ref.server_id, start_socket_id=0)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + + def test_server_listen_sockets(self): + self._pairs = _generate_channel_server_pairs(1) + + gss_resp = self._channelz_stub.GetServers( + channelz_pb2.GetServersRequest(start_server_id=0)) + self.assertEqual(len(gss_resp.server), 1) + self.assertEqual(len(gss_resp.server[0].listen_socket), 1) + + gs_resp = self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest( + socket_id=gss_resp.server[0].listen_socket[0].socket_id)) + # If the RPC call failed, it will raise a grpc.RpcError + # So, if there is no exception raised, considered pass + + def test_invalid_query_get_server(self): + try: + self._channelz_stub.GetServer( + channelz_pb2.GetServerRequest(server_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_channel(self): + try: + self._channelz_stub.GetChannel( + channelz_pb2.GetChannelRequest(channel_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_subchannel(self): + try: + self._channelz_stub.GetSubchannel( + channelz_pb2.GetSubchannelRequest(subchannel_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_socket(self): + try: + self._channelz_stub.GetSocket( + channelz_pb2.GetSocketRequest(socket_id=10000)) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + def test_invalid_query_get_server_sockets(self): + try: + self._channelz_stub.GetServerSockets( + channelz_pb2.GetServerSocketsRequest( + server_id=10000, + start_socket_id=0, + )) + except BaseException as e: + self.assertIn('StatusCode.NOT_FOUND', str(e)) + else: + self.fail('Invalid query not detected') + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 350b5eebe5..c1d9436c2f 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -39,8 +39,12 @@ class HealthServicerTest(unittest.TestCase): health_pb2_grpc.add_HealthServicer_to_server(servicer, self._server) self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = health_pb2_grpc.HealthStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = health_pb2_grpc.HealthStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() def test_empty_service(self): request = health_pb2.HealthCheckRequest() diff --git a/src/python/grpcio_tests/tests/interop/BUILD.bazel b/src/python/grpcio_tests/tests/interop/BUILD.bazel index a39f30d32a..aebdbf67eb 100644 --- a/src/python/grpcio_tests/tests/interop/BUILD.bazel +++ b/src/python/grpcio_tests/tests/interop/BUILD.bazel @@ -35,6 +35,10 @@ py_library( requirement('google-auth'), requirement('requests'), requirement('enum34'), + requirement('urllib3'), + requirement('chardet'), + requirement('certifi'), + requirement('idna'), ], imports=["../../",], ) diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py index 698c37017f..56cb29477c 100644 --- a/src/python/grpcio_tests/tests/interop/client.py +++ b/src/python/grpcio_tests/tests/interop/client.py @@ -54,7 +54,6 @@ def _args(): help='replace platform root CAs with ca.pem') parser.add_argument( '--server_host_override', - default="foo.test.google.fr", type=str, help='the server host to which to claim to connect') parser.add_argument( @@ -100,10 +99,13 @@ def _stub(args): channel_credentials = grpc.composite_channel_credentials( channel_credentials, call_credentials) - channel = grpc.secure_channel(target, channel_credentials, (( - 'grpc.ssl_target_name_override', - args.server_host_override, - ),)) + channel_opts = None + if args.server_host_override: + channel_opts = (( + 'grpc.ssl_target_name_override', + args.server_host_override, + ),) + channel = grpc.secure_channel(target, channel_credentials, channel_opts) else: channel = grpc.insecure_channel(target) if args.test_case == "unimplemented_service": diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py index 721dedf0b7..c11f6c8fad 100644 --- a/src/python/grpcio_tests/tests/interop/methods.py +++ b/src/python/grpcio_tests/tests/interop/methods.py @@ -23,7 +23,6 @@ from google.auth import environment_vars as google_auth_environment_vars from google.auth.transport import grpc as google_auth_transport_grpc from google.auth.transport import requests as google_auth_transport_requests import grpc -from grpc.beta import implementations from src.proto.grpc.testing import empty_pb2 from src.proto.grpc.testing import messages_pb2 @@ -377,7 +376,7 @@ def _unimplemented_service(unimplemented_service_stub): def _custom_metadata(stub): initial_metadata_value = "test_initial_metadata_value" - trailing_metadata_value = "\x0a\x0b\x0a\x0b\x0a\x0b" + trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), (_TRAILING_METADATA_KEY, trailing_metadata_value)) @@ -391,7 +390,7 @@ def _custom_metadata(stub): if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: raise ValueError('expected trailing metadata %s, got %s' % (trailing_metadata_value, - initial_metadata[_TRAILING_METADATA_KEY])) + trailing_metadata[_TRAILING_METADATA_KEY])) # Testing with UnaryCall request = messages_pb2.SimpleRequest( @@ -422,7 +421,7 @@ def _compute_engine_creds(stub, args): def _oauth2_auth_token(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] response = _large_unary_common_behavior(stub, True, True, None) if wanted_email != response.username: raise ValueError('expected username %s, got %s' % (wanted_email, @@ -435,7 +434,7 @@ def _oauth2_auth_token(stub, args): def _jwt_token_creds(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] response = _large_unary_common_behavior(stub, True, False, None) if wanted_email != response.username: raise ValueError('expected username %s, got %s' % (wanted_email, @@ -444,7 +443,7 @@ def _jwt_token_creds(stub, args): def _per_rpc_creds(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] + wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] google_credentials, unused_project_id = google_auth.default( scopes=[args.oauth_scope]) call_credentials = grpc.metadata_call_credentials( diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py index e21ea0010a..2b735526cb 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py @@ -144,7 +144,7 @@ class _ProtoBeforeGrpcProtocStyle(object): absolute_proto_file_names) pb2_grpc_protoc_exit_code = _protoc( proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names) - return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code, + return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code class _GrpcBeforeProtoProtocStyle(object): @@ -160,7 +160,7 @@ class _GrpcBeforeProtoProtocStyle(object): proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names) pb2_protoc_exit_code = _protoc(proto_path, python_out, None, None, absolute_proto_file_names) - return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code, + return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code _PROTOC_STYLES = ( @@ -243,9 +243,9 @@ class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)): def _services_modules(self): if self.PROTOC_STYLE.grpc_in_pb2_expected(): - return self._services_pb2, self._services_pb2_grpc, + return self._services_pb2, self._services_pb2_grpc else: - return self._services_pb2_grpc, + return (self._services_pb2_grpc,) def test_imported_attributes(self): self._protoc() diff --git a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py index b46e53315e..43c90af6a7 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py @@ -223,7 +223,7 @@ def _CreateService(payload_pb2, responses_pb2, service_pb2): server.start() channel = implementations.insecure_channel('localhost', port) stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel) - yield servicer_methods, stub, + yield servicer_methods, stub server.stop(0) diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index 0488450740..fac0e44e5a 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -180,7 +180,7 @@ class StreamingSyncBenchmarkClient(BenchmarkClient): self._streams = [ _SyncStream(self._stub, self._generic, self._request, self._handle_response) - for _ in xrange(config.outstanding_rpcs_per_channel) + for _ in range(config.outstanding_rpcs_per_channel) ] self._curr_stream = 0 diff --git a/src/python/grpcio_tests/tests/qps/client_runner.py b/src/python/grpcio_tests/tests/qps/client_runner.py index e79abab3c7..a57524c74e 100644 --- a/src/python/grpcio_tests/tests/qps/client_runner.py +++ b/src/python/grpcio_tests/tests/qps/client_runner.py @@ -77,7 +77,7 @@ class ClosedLoopClientRunner(ClientRunner): def start(self): self._is_running = True self._client.start() - for _ in xrange(self._request_count): + for _ in range(self._request_count): self._client.send_request() def stop(self): diff --git a/src/python/grpcio_tests/tests/qps/worker_server.py b/src/python/grpcio_tests/tests/qps/worker_server.py index 740bdcf1eb..a03367ec63 100644 --- a/src/python/grpcio_tests/tests/qps/worker_server.py +++ b/src/python/grpcio_tests/tests/qps/worker_server.py @@ -39,7 +39,7 @@ class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer): self._quit_event = threading.Event() def RunServer(self, request_iterator, context): - config = next(request_iterator).setup + config = next(request_iterator).setup #pylint: disable=stop-iteration-return server, port = self._create_server(config) cores = multiprocessing.cpu_count() server.start() @@ -102,14 +102,14 @@ class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer): return (server, port) def RunClient(self, request_iterator, context): - config = next(request_iterator).setup + config = next(request_iterator).setup #pylint: disable=stop-iteration-return client_runners = [] qps_data = histogram.Histogram(config.histogram_params.resolution, config.histogram_params.max_possible) start_time = time.time() # Create a client for each channel - for i in xrange(config.client_channels): + for i in range(config.client_channels): server = config.server_targets[i % len(config.server_targets)] runner = self._create_client_runner(server, config, qps_data) client_runners.append(runner) diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py index bcd9e14a38..560f6d3ddb 100644 --- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -56,8 +56,12 @@ class ReflectionServicerTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - self._stub = reflection_pb2_grpc.ServerReflectionStub(channel) + self._channel = grpc.insecure_channel('localhost:%d' % port) + self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) + + def tearDown(self): + self._server.stop(None) + self._channel.close() def testFileByName(self): requests = ( diff --git a/src/python/grpcio_tests/tests/status/BUILD.bazel b/src/python/grpcio_tests/tests/status/BUILD.bazel new file mode 100644 index 0000000000..937e50498e --- /dev/null +++ b/src/python/grpcio_tests/tests/status/BUILD.bazel @@ -0,0 +1,19 @@ +load("@grpc_python_dependencies//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "grpc_status_test", + srcs = ["_grpc_status_test.py"], + main = "_grpc_status_test.py", + size = "small", + deps = [ + "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_status/grpc_status:grpc_status", + "//src/python/grpcio_tests/tests/unit:test_common", + "//src/python/grpcio_tests/tests/unit/framework/common:common", + requirement('protobuf'), + requirement('googleapis-common-protos'), + ], + imports = ["../../",], +) diff --git a/src/python/grpcio_tests/tests/status/__init__.py b/src/python/grpcio_tests/tests/status/__init__.py new file mode 100644 index 0000000000..38fdfc9c5c --- /dev/null +++ b/src/python/grpcio_tests/tests/status/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py new file mode 100644 index 0000000000..519c372a96 --- /dev/null +++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -0,0 +1,173 @@ +# Copyright 2018 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests of grpc_status.""" + +import unittest + +import logging +import traceback + +import grpc +from grpc_status import rpc_status + +from tests.unit import test_common + +from google.protobuf import any_pb2 +from google.rpc import code_pb2, status_pb2, error_details_pb2 + +_STATUS_OK = '/test/StatusOK' +_STATUS_NOT_OK = '/test/StatusNotOk' +_ERROR_DETAILS = '/test/ErrorDetails' +_INCONSISTENT = '/test/Inconsistent' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' + +_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' + +_STATUS_DETAILS = 'This is an error detail' +_STATUS_DETAILS_ANOTHER = 'This is another error detail' + + +def _ok_unary_unary(request, servicer_context): + return _RESPONSE + + +def _not_ok_unary_unary(request, servicer_context): + servicer_context.abort(grpc.StatusCode.INTERNAL, _STATUS_DETAILS) + + +def _error_details_unary_unary(request, servicer_context): + details = any_pb2.Any() + details.Pack( + error_details_pb2.DebugInfo( + stack_entries=traceback.format_stack(), + detail='Intentionally invoked')) + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + details=[details], + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +def _inconsistent_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=code_pb2.INTERNAL, + message=_STATUS_DETAILS, + ) + servicer_context.set_code(grpc.StatusCode.NOT_FOUND) + servicer_context.set_details(_STATUS_DETAILS_ANOTHER) + # User put inconsistent status information in trailing metadata + servicer_context.set_trailing_metadata(((_GRPC_DETAILS_METADATA_KEY, + rich_status.SerializeToString()),)) + + +def _invalid_code_unary_unary(request, servicer_context): + rich_status = status_pb2.Status( + code=42, + message='Invalid code', + ) + servicer_context.abort_with_status(rpc_status.to_status(rich_status)) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _STATUS_OK: + return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) + elif handler_call_details.method == _STATUS_NOT_OK: + return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) + elif handler_call_details.method == _ERROR_DETAILS: + return grpc.unary_unary_rpc_method_handler( + _error_details_unary_unary) + elif handler_call_details.method == _INCONSISTENT: + return grpc.unary_unary_rpc_method_handler( + _inconsistent_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.unary_unary_rpc_method_handler( + _invalid_code_unary_unary) + else: + return None + + +class StatusTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def test_status_ok(self): + _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) + + # Succeed RPC doesn't have status + status = rpc_status.from_call(call) + self.assertIs(status, None) + + def test_status_not_ok(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + # Failed RPC doesn't automatically generate status + status = rpc_status.from_call(rpc_error) + self.assertIs(status, None) + + def test_error_details(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) + rpc_error = exception_context.exception + + status = rpc_status.from_call(rpc_error) + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + + # Check if the underlying proto message is intact + self.assertEqual(status.details[0].Is( + error_details_pb2.DebugInfo.DESCRIPTOR), True) + info = error_details_pb2.DebugInfo() + status.details[0].Unpack(info) + self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + + def test_code_message_validation(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) + + # Code/Message validation failed + self.assertRaises(ValueError, rpc_status.from_call, rpc_error) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) + rpc_error = exception_context.exception + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + # Invalid status code exception raised during coversion + self.assertIn('Invalid status code', rpc_error.details()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py index 41f2e1b6c2..4c35b05044 100644 --- a/src/python/grpcio_tests/tests/stress/client.py +++ b/src/python/grpcio_tests/tests/stress/client.py @@ -71,7 +71,6 @@ def _args(): '--use_tls', help='Whether to use TLS', default=False, type=bool) parser.add_argument( '--server_host_override', - default="foo.test.google.fr", help='the server host to which to claim to connect', type=str) return parser.parse_args() @@ -132,9 +131,9 @@ def run_test(args): server.start() for test_server_target in test_server_targets: - for _ in xrange(args.num_channels_per_server): + for _ in range(args.num_channels_per_server): channel = _get_channel(test_server_target, args) - for _ in xrange(args.num_stubs_per_channel): + for _ in range(args.num_stubs_per_channel): stub = test_pb2_grpc.TestServiceStub(channel) runner = test_runner.TestRunner(stub, test_cases, hist, exception_queue, stop_event) diff --git a/src/python/grpcio_tests/tests/testing/_client_application.py b/src/python/grpcio_tests/tests/testing/_client_application.py index 3ddeba2373..4d42df0389 100644 --- a/src/python/grpcio_tests/tests/testing/_client_application.py +++ b/src/python/grpcio_tests/tests/testing/_client_application.py @@ -130,9 +130,9 @@ def _run_stream_stream(stub): request_pipe = _Pipe() response_iterator = stub.StreStre(iter(request_pipe)) request_pipe.add(_application_common.STREAM_STREAM_REQUEST) - first_responses = next(response_iterator), next(response_iterator), + first_responses = next(response_iterator), next(response_iterator) request_pipe.add(_application_common.STREAM_STREAM_REQUEST) - second_responses = next(response_iterator), next(response_iterator), + second_responses = next(response_iterator), next(response_iterator) request_pipe.close() try: next(response_iterator) diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index a3006d9afc..f202a3f932 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -1,5 +1,6 @@ [ "_sanity._sanity_test.SanityTest", + "channelz._channelz_servicer_test.ChannelzServicerTest", "health_check._health_servicer_test.HealthServicerTest", "interop._insecure_intraop_test.InsecureIntraopTest", "interop._secure_intraop_test.SecureIntraopTest", @@ -14,10 +15,12 @@ "protoc_plugin._split_definitions_test.SplitProtoSingleProtocExecutionProtocStyleTest", "protoc_plugin.beta_python_plugin_test.PythonPluginTest", "reflection._reflection_servicer_test.ReflectionServicerTest", + "status._grpc_status_test.StatusTest", "testing._client_test.ClientTest", "testing._server_test.FirstServiceServicerTest", "testing._time_test.StrictFakeTimeTest", "testing._time_test.StrictRealTimeTest", + "unit._abort_test.AbortTest", "unit._api_test.AllTest", "unit._api_test.ChannelConnectivityTest", "unit._api_test.ChannelTest", @@ -54,6 +57,7 @@ "unit._reconnect_test.ReconnectTest", "unit._resource_exhausted_test.ResourceExhaustedTest", "unit._rpc_test.RPCTest", + "unit._server_shutdown_test.ServerShutdown", "unit._server_ssl_cert_config_test.ServerSSLCertConfigFetcherParamsChecks", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth", diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index de33b81e32..1b462ec67a 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -3,6 +3,7 @@ load("@grpc_python_dependencies//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) GRPCIO_TESTS_UNIT = [ + "_abort_test.py", "_api_test.py", "_auth_context_test.py", "_auth_test.py", @@ -27,6 +28,7 @@ GRPCIO_TESTS_UNIT = [ # TODO(ghostwriternr): To be added later. # "_server_ssl_cert_config_test.py", "_server_test.py", + "_server_shutdown_test.py", "_session_cache_test.py", ] @@ -49,6 +51,11 @@ py_library( ) py_library( + name = "_server_shutdown_scenarios", + srcs = ["_server_shutdown_scenarios.py"], +) + +py_library( name = "_thread_pool", srcs = ["_thread_pool.py"], ) @@ -69,6 +76,7 @@ py_library( ":resources", ":test_common", ":_exit_scenarios", + ":_server_shutdown_scenarios", ":_thread_pool", ":_from_grpc_import_star", "//src/python/grpcio_tests/tests/unit/framework/common", diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py new file mode 100644 index 0000000000..6438f6897a --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -0,0 +1,124 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests server context abort mechanism""" + +import unittest +import collections +import logging + +import grpc + +from tests.unit import test_common +from tests.unit.framework.common import test_constants + +_ABORT = '/test/abort' +_ABORT_WITH_STATUS = '/test/AbortWithStatus' +_INVALID_CODE = '/test/InvalidCode' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_ABORT_DETAILS = 'Abandon ship!' +_ABORT_METADATA = (('a-trailing-metadata', '42'),) + + +class _Status( + collections.namedtuple( + '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status): + pass + + +def abort_unary_unary(request, servicer_context): + servicer_context.abort( + grpc.StatusCode.INTERNAL, + _ABORT_DETAILS, + ) + raise Exception('This line should not be executed!') + + +def abort_with_status_unary_unary(request, servicer_context): + servicer_context.abort_with_status( + _Status( + code=grpc.StatusCode.INTERNAL, + details=_ABORT_DETAILS, + trailing_metadata=_ABORT_METADATA, + )) + raise Exception('This line should not be executed!') + + +def invalid_code_unary_unary(request, servicer_context): + servicer_context.abort( + 42, + _ABORT_DETAILS, + ) + + +class _GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == _ABORT: + return grpc.unary_unary_rpc_method_handler(abort_unary_unary) + elif handler_call_details.method == _ABORT_WITH_STATUS: + return grpc.unary_unary_rpc_method_handler( + abort_with_status_unary_unary) + elif handler_call_details.method == _INVALID_CODE: + return grpc.stream_stream_rpc_method_handler( + invalid_code_unary_unary) + else: + return None + + +class AbortTest(unittest.TestCase): + + def setUp(self): + self._server = test_common.test_server() + port = self._server.add_insecure_port('[::]:0') + self._server.add_generic_rpc_handlers((_GenericHandler(),)) + self._server.start() + + self._channel = grpc.insecure_channel('localhost:%d' % port) + + def tearDown(self): + self._channel.close() + self._server.stop(0) + + def test_abort(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + def test_abort_with_status(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + self.assertEqual(rpc_error.trailing_metadata(), _ABORT_METADATA) + + def test_invalid_code(self): + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + rpc_error = exception_context.exception + + self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) + self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index 38072861a4..0dc6a8718c 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -32,6 +32,7 @@ class AllTest(unittest.TestCase): 'Future', 'ChannelConnectivity', 'StatusCode', + 'Status', 'RpcError', 'RpcContext', 'Call', @@ -100,6 +101,7 @@ class ChannelTest(unittest.TestCase): def test_secure_channel(self): channel_credentials = grpc.ssl_channel_credentials() channel = grpc.secure_channel('google.com:443', channel_credentials) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index b1b5bbdcab..96c4e9ec76 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -71,8 +71,8 @@ class AuthContextTest(unittest.TestCase): port = server.add_insecure_port('[::]:0') server.start() - channel = grpc.insecure_channel('localhost:%d' % port) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + with grpc.insecure_channel('localhost:%d' % port) as channel: + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) server.stop(None) auth_data = pickle.loads(response) @@ -98,6 +98,7 @@ class AuthContextTest(unittest.TestCase): channel_creds, options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) @@ -132,6 +133,7 @@ class AuthContextTest(unittest.TestCase): options=_PROPERTY_OPTIONS) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + channel.close() server.stop(None) auth_data = pickle.loads(response) diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py index 727fb7d65f..565bd39b3a 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -75,6 +75,8 @@ class ChannelConnectivityTest(unittest.TestCase): channel.unsubscribe(callback.update) fifth_connectivities = callback.connectivities() + channel.close() + self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities) @@ -108,7 +110,8 @@ class ChannelConnectivityTest(unittest.TestCase): _ready_in_connectivities) second_callback.block_until_connectivities_satisfy( _ready_in_connectivities) - del channel + channel.close() + server.stop(None) self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), first_connectivities) @@ -139,6 +142,7 @@ class ChannelConnectivityTest(unittest.TestCase): callback.block_until_connectivities_satisfy( _last_connectivity_is_not_ready) channel.unsubscribe(callback.update) + channel.close() self.assertFalse(thread_pool.was_used()) diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py index 345460ef40..46a4eb9bb6 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -60,6 +60,8 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertTrue(ready_future.done()) self.assertFalse(ready_future.running()) + channel.close() + def test_immediately_connectable_channel_connectivity(self): thread_pool = _thread_pool.RecordingThreadPool(max_workers=None) server = grpc.server(thread_pool, options=(('grpc.so_reuseport', 0),)) @@ -84,6 +86,9 @@ class ChannelReadyFutureTest(unittest.TestCase): self.assertFalse(ready_future.running()) self.assertFalse(thread_pool.was_used()) + channel.close() + server.stop(None) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index 876d8e827e..87884a19dc 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -77,6 +77,9 @@ class CompressionTest(unittest.TestCase): self._port = self._server.add_insecure_port('[::]:0') self._server.start() + def tearDown(self): + self._server.stop(None) + def testUnary(self): request = b'\x00' * 100 @@ -102,6 +105,7 @@ class CompressionTest(unittest.TestCase): response = multi_callable( request, metadata=[('grpc-internal-encoding-request', 'gzip')]) self.assertEqual(request, response) + compressed_channel.close() def testStreaming(self): request = b'\x00' * 100 @@ -115,6 +119,7 @@ class CompressionTest(unittest.TestCase): call = multi_callable(iter([request] * test_constants.STREAM_LENGTH)) for response in call: self.assertEqual(request, response) + compressed_channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_credentials_test.py b/src/python/grpcio_tests/tests/unit/_credentials_test.py index be7378ecbc..187a6f0388 100644 --- a/src/python/grpcio_tests/tests/unit/_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_credentials_test.py @@ -15,6 +15,7 @@ import unittest import logging +import six import grpc @@ -53,6 +54,16 @@ class CredentialsTest(unittest.TestCase): self.assertIsInstance(channel_first_second_and_third, grpc.ChannelCredentials) + @unittest.skipIf(six.PY2, 'only invalid in Python3') + def test_invalid_string_certificate(self): + self.assertRaises( + TypeError, + grpc.ssl_channel_credentials, + root_certificates='A Certificate', + private_key=None, + certificate_chain=None, + ) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index 3e8393b53c..f27ea422d0 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -96,6 +96,7 @@ class EmptyMessageTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index 6c551df3ec..81de1dae1d 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -71,6 +71,7 @@ class ErrorMessageEncodingTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py index f489db12cb..d1263c2c56 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py @@ -87,7 +87,7 @@ def hang_stream_stream(request_iterator, servicer_context): def hang_partial_stream_stream(request_iterator, servicer_context): for _ in range(test_constants.STREAM_LENGTH // 2): - yield next(request_iterator) + yield next(request_iterator) #pylint: disable=stop-iteration-return time.sleep(WAIT_TIME) diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py index 5226537579..b429ee089f 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_test.py +++ b/src/python/grpcio_tests/tests/unit/_exit_test.py @@ -71,7 +71,6 @@ def wait(process): process.wait() -@unittest.skip('https://github.com/grpc/grpc/issues/7311') class ExitTest(unittest.TestCase): def test_unstarted_server(self): @@ -130,6 +129,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_unary_unary_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], @@ -138,6 +139,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_unary_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], @@ -145,6 +148,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_stream_unary_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], @@ -153,6 +158,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_stream_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], @@ -161,6 +168,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_unary_stream_call(self): process = subprocess.Popen( BASE_COMMAND + @@ -169,6 +178,8 @@ class ExitTest(unittest.TestCase): stderr=sys.stderr) interrupt_and_wait(process) + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_stream_unary_call(self): process = subprocess.Popen( BASE_COMMAND + @@ -178,6 +189,8 @@ class ExitTest(unittest.TestCase): interrupt_and_wait(process) @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999') + @unittest.skipIf(os.name == 'nt', + 'os.kill does not have required permission on Windows') def test_in_flight_partial_stream_stream_call(self): process = subprocess.Popen( BASE_COMMAND + diff --git a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py index ad847ae03e..1ada25382d 100644 --- a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py +++ b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py @@ -14,7 +14,7 @@ _BEFORE_IMPORT = tuple(globals()) -from grpc import * # pylint: disable=wildcard-import +from grpc import * # pylint: disable=wildcard-import,unused-wildcard-import _AFTER_IMPORT = tuple(globals()) diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 99db0ac58b..a647e5e720 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -337,6 +337,7 @@ class InterceptorTest(unittest.TestCase): def tearDown(self): self._server.stop(None) self._server_pool.shutdown(wait=True) + self._channel.close() def testTripleRequestMessagesClientInterceptor(self): diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index 9771764635..7ed7c83893 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -62,6 +62,9 @@ class InvalidMetadataTest(unittest.TestCase): self._stream_unary = _stream_unary_multi_callable(self._channel) self._stream_stream = _stream_stream_multi_callable(self._channel) + def tearDown(self): + self._channel.close() + def testUnaryRequestBlockingUnaryResponse(self): request = b'\x07\x08' metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) @@ -130,6 +133,9 @@ class InvalidMetadataTest(unittest.TestCase): self._stream_stream(request_iterator, metadata=metadata) self.assertIn(expected_error_details, str(exception_context.exception)) + def testInvalidMetadata(self): + self.assertRaises(TypeError, self._unary_unary, b'', metadata=42) + if __name__ == '__main__': logging.basicConfig() diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index 00949e2236..e89b521cc5 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -215,6 +215,7 @@ class InvocationDefectsTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testIterableStreamRequestBlockingUnaryResponse(self): requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)] diff --git a/src/python/grpcio_tests/tests/unit/_logging_test.py b/src/python/grpcio_tests/tests/unit/_logging_test.py index 80b1f1b3c1..631b9de9db 100644 --- a/src/python/grpcio_tests/tests/unit/_logging_test.py +++ b/src/python/grpcio_tests/tests/unit/_logging_test.py @@ -68,6 +68,13 @@ class LoggingTest(unittest.TestCase): self.assertEqual(1, len(logging.getLogger().handlers)) self.assertIs(logging.getLogger().handlers[0].stream, intended_stream) + @isolated_logging + def test_grpc_logger(self): + self.assertIn("grpc", logging.Logger.manager.loggerDict) + root_logger = logging.getLogger("grpc") + self.assertEqual(1, len(root_logger.handlers)) + self.assertIsInstance(root_logger.handlers[0], logging.NullHandler) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 0dafab827a..a63664ac5d 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -198,8 +198,8 @@ class MetadataCodeDetailsTest(unittest.TestCase): port = self._server.add_insecure_port('[::]:0') self._server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) - self._unary_unary = channel.unary_unary( + self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = self._channel.unary_unary( '/'.join(( '', _SERVICE, @@ -208,17 +208,17 @@ class MetadataCodeDetailsTest(unittest.TestCase): request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) - self._unary_stream = channel.unary_stream('/'.join(( + self._unary_stream = self._channel.unary_stream('/'.join(( '', _SERVICE, _UNARY_STREAM, )),) - self._stream_unary = channel.stream_unary('/'.join(( + self._stream_unary = self._channel.stream_unary('/'.join(( '', _SERVICE, _STREAM_UNARY, )),) - self._stream_stream = channel.stream_stream( + self._stream_stream = self._channel.stream_stream( '/'.join(( '', _SERVICE, @@ -228,6 +228,10 @@ class MetadataCodeDetailsTest(unittest.TestCase): response_deserializer=_RESPONSE_DESERIALIZER, ) + def tearDown(self): + self._server.stop(None) + self._channel.close() + def testSuccessfulUnaryUnary(self): self._servicer.set_details(_DETAILS) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 2d352e99d4..7b32b5b5f3 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -187,13 +187,14 @@ class MetadataFlagsTest(unittest.TestCase): def test_call_wait_for_ready_default(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast(perform_call, - create_dummy_channel()) + with create_dummy_channel() as channel: + self.check_connection_does_failfast(perform_call, channel) def test_call_wait_for_ready_disabled(self): for perform_call in _ALL_CALL_CASES: - self.check_connection_does_failfast( - perform_call, create_dummy_channel(), wait_for_ready=False) + with create_dummy_channel() as channel: + self.check_connection_does_failfast( + perform_call, channel, wait_for_ready=False) def test_call_wait_for_ready_enabled(self): # To test the wait mechanism, Python thread is required to make @@ -210,16 +211,16 @@ class MetadataFlagsTest(unittest.TestCase): wg.done() def test_call(perform_call): - try: - channel = grpc.insecure_channel(addr) - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The channel - # object can be collected before calling the callback, which - # will result in a deadlock. - wg.done() - unhandled_exceptions.put(e, True) + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) test_threads = [] for perform_call in _ALL_CALL_CASES: diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 777ab683e3..892df3df08 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -186,6 +186,7 @@ class MetadataTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index f6d4fcbd0a..d4ea126e2b 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -98,6 +98,8 @@ class ReconnectTest(unittest.TestCase): server.add_insecure_port('[::]:{}'.format(port)) server.start() self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) + server.stop(None) + channel.close() if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index 4fead8fcd5..517c2d2f97 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -148,6 +148,7 @@ class ResourceExhaustedTest(unittest.TestCase): def tearDown(self): self._server.stop(0) + self._channel.close() def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py index a768d6c7c1..a99121cee5 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py @@ -193,6 +193,7 @@ class RPCTest(unittest.TestCase): def tearDown(self): self._server.stop(None) + self._channel.close() def testUnrecognizedMethod(self): request = b'abc' diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py new file mode 100644 index 0000000000..1d1fdba11e --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -0,0 +1,97 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines a number of module-scope gRPC scenarios to test server shutdown.""" + +import argparse +import os +import threading +import time +import logging + +import grpc +from tests.unit import test_common + +from concurrent import futures +from six.moves import queue + +WAIT_TIME = 1000 + +REQUEST = b'request' +RESPONSE = b'response' + +SERVER_RAISES_EXCEPTION = 'server_raises_exception' +SERVER_DEALLOCATED = 'server_deallocated' +SERVER_FORK_CAN_EXIT = 'server_fork_can_exit' + +FORK_EXIT = '/test/ForkExit' + + +def fork_and_exit(request, servicer_context): + pid = os.fork() + if pid == 0: + os._exit(0) + return RESPONSE + + +class GenericHandler(grpc.GenericRpcHandler): + + def service(self, handler_call_details): + if handler_call_details.method == FORK_EXIT: + return grpc.unary_unary_rpc_method_handler(fork_and_exit) + else: + return None + + +def run_server(port_queue): + server = test_common.test_server() + port = server.add_insecure_port('[::]:0') + port_queue.put(port) + server.add_generic_rpc_handlers((GenericHandler(),)) + server.start() + # threading.Event.wait() does not exhibit the bug identified in + # https://github.com/grpc/grpc/issues/17093, sleep instead + time.sleep(WAIT_TIME) + + +def run_test(args): + if args.scenario == SERVER_RAISES_EXCEPTION: + server = test_common.test_server() + server.start() + raise Exception() + elif args.scenario == SERVER_DEALLOCATED: + server = test_common.test_server() + server.start() + server.__del__() + while server._state.stage != grpc._server._ServerStage.STOPPED: + pass + elif args.scenario == SERVER_FORK_CAN_EXIT: + port_queue = queue.Queue() + thread = threading.Thread(target=run_server, args=(port_queue,)) + thread.daemon = True + thread.start() + port = port_queue.get() + channel = grpc.insecure_channel('localhost:%d' % port) + multi_callable = channel.unary_unary(FORK_EXIT) + result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) + os.wait() + else: + raise ValueError('unknown test scenario') + + +if __name__ == '__main__': + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument('scenario', type=str) + args = parser.parse_args() + run_test(args) diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py new file mode 100644 index 0000000000..47446d65a5 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py @@ -0,0 +1,90 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests clean shutdown of server on various interpreter exit conditions. + +The tests in this module spawn a subprocess for each test case, the +test is considered successful if it doesn't hang/timeout. +""" + +import atexit +import os +import subprocess +import sys +import threading +import unittest +import logging + +from tests.unit import _server_shutdown_scenarios + +SCENARIO_FILE = os.path.abspath( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + '_server_shutdown_scenarios.py')) +INTERPRETER = sys.executable +BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] + +processes = [] +process_lock = threading.Lock() + + +# Make sure we attempt to clean up any +# processes we may have left running +def cleanup_processes(): + with process_lock: + for process in processes: + try: + process.kill() + except Exception: # pylint: disable=broad-except + pass + + +atexit.register(cleanup_processes) + + +def wait(process): + with process_lock: + processes.append(process) + process.wait() + + +class ServerShutdown(unittest.TestCase): + + # Currently we shut down a server (if possible) after the Python server + # instance is garbage collected. This behavior may change in the future. + def test_deallocated_server_stops(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + def test_server_exception_exits(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + @unittest.skipIf(os.name == 'nt', 'fork not supported on windows') + def test_server_fork_can_exit(self): + process = subprocess.Popen( + BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT], + stdout=sys.stdout, + stderr=sys.stderr) + wait(process) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py b/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py index e733a59a5b..9e4bd61816 100644 --- a/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py +++ b/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py @@ -70,18 +70,11 @@ SERVER_CERT_CHAIN_2_PEM = (resources.cert_hier_2_server_1_cert() + Call = collections.namedtuple('Call', ['did_raise', 'returned_cert_config']) -def _create_client_stub( - port, - expect_success, - root_certificates=None, - private_key=None, - certificate_chain=None, -): - channel = grpc.secure_channel('localhost:{}'.format(port), - grpc.ssl_channel_credentials( - root_certificates=root_certificates, - private_key=private_key, - certificate_chain=certificate_chain)) +def _create_channel(port, credentials): + return grpc.secure_channel('localhost:{}'.format(port), credentials) + + +def _create_client_stub(channel, expect_success): if expect_success: # per Nathaniel: there's some robustness issue if we start # using a channel without waiting for it to be actually ready @@ -176,14 +169,13 @@ class _ServerSSLCertReloadTest( root_certificates=None, private_key=None, certificate_chain=None): - client_stub = _create_client_stub( - self.port, - expect_success, + credentials = grpc.ssl_channel_credentials( root_certificates=root_certificates, private_key=private_key, certificate_chain=certificate_chain) - self._perform_rpc(client_stub, expect_success) - del client_stub + with _create_channel(self.port, credentials) as client_channel: + client_stub = _create_client_stub(client_channel, expect_success) + self._perform_rpc(client_stub, expect_success) def _test(self): # things should work... @@ -259,12 +251,13 @@ class _ServerSSLCertReloadTest( # now create the "persistent" clients self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - persistent_client_stub_A = _create_client_stub( + channel_A = _create_channel( self.port, - True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + grpc.ssl_channel_credentials( + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + persistent_client_stub_A = _create_client_stub(channel_A, True) self._perform_rpc(persistent_client_stub_A, True) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) @@ -273,12 +266,13 @@ class _ServerSSLCertReloadTest( self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - persistent_client_stub_B = _create_client_stub( + channel_B = _create_channel( self.port, - True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + grpc.ssl_channel_credentials( + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + persistent_client_stub_B = _create_client_stub(channel_B, True) self._perform_rpc(persistent_client_stub_B, True) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) @@ -359,6 +353,9 @@ class _ServerSSLCertReloadTest( actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 0) + channel_A.close() + channel_B.close() + class ServerSSLCertConfigFetcherParamsChecks(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/beta/BUILD.bazel b/src/python/grpcio_tests/tests/unit/beta/BUILD.bazel deleted file mode 100644 index d3e0fe20eb..0000000000 --- a/src/python/grpcio_tests/tests/unit/beta/BUILD.bazel +++ /dev/null @@ -1,75 +0,0 @@ -load("@grpc_python_dependencies//:requirements.bzl", "requirement") - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "test_utilities", - srcs = ["test_utilities.py"], - deps = [ - "//src/python/grpcio/grpc:grpcio", - ], -) - -py_test( - name = "_beta_features_test", - srcs = ["_beta_features_test.py"], - main = "_beta_features_test.py", - size = "small", - deps = [ - "//src/python/grpcio/grpc:grpcio", - "//src/python/grpcio_tests/tests/unit:resources", - "//src/python/grpcio_tests/tests/unit/framework/common", - ":test_utilities", - ], - imports=["../../../",], -) - -py_test( - name = "_connectivity_channel_test", - srcs = ["_connectivity_channel_test.py"], - main = "_connectivity_channel_test.py", - size = "small", - deps = [ - "//src/python/grpcio/grpc:grpcio", - ], -) - -# TODO(ghostwriternr): To be added later. -#py_test( -# name = "_implementations_test", -# srcs = ["_implementations_test.py"], -# main = "_implementations_test.py", -# size = "small", -# deps = [ -# "//src/python/grpcio/grpc:grpcio", -# "//src/python/grpcio_tests/tests/unit:resources", -# requirement('oauth2client'), -# ], -# imports=["../../../",], -#) - -py_test( - name = "_not_found_test", - srcs = ["_not_found_test.py"], - main = "_not_found_test.py", - size = "small", - deps = [ - "//src/python/grpcio/grpc:grpcio", - "//src/python/grpcio_tests/tests/unit/framework/common", - ], - imports=["../../../",], -) - -py_test( - name = "_utilities_test", - srcs = ["_utilities_test.py"], - main = "_utilities_test.py", - size = "small", - deps = [ - "//src/python/grpcio/grpc:grpcio", - "//src/python/grpcio_tests/tests/unit/framework/common", - ], - imports=["../../../",], -) - - diff --git a/src/ruby/end2end/graceful_sig_handling_client.rb b/src/ruby/end2end/graceful_sig_handling_client.rb new file mode 100755 index 0000000000..14a67a62cc --- /dev/null +++ b/src/ruby/end2end/graceful_sig_handling_client.rb @@ -0,0 +1,61 @@ +#!/usr/bin/env ruby + +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +require_relative './end2end_common' + +# Test client. Sends RPC's as normal but process also has signal handlers +class SigHandlingClientController < ClientControl::ClientController::Service + def initialize(stub) + @stub = stub + end + + def do_echo_rpc(req, _) + response = @stub.echo(Echo::EchoRequest.new(request: req.request)) + fail 'bad response' unless response.response == req.request + ClientControl::Void.new + end +end + +def main + client_control_port = '' + server_port = '' + OptionParser.new do |opts| + opts.on('--client_control_port=P', String) do |p| + client_control_port = p + end + opts.on('--server_port=P', String) do |p| + server_port = p + end + end.parse! + + # Allow a few seconds to be safe. + srv = new_rpc_server_for_testing + srv.add_http2_port("0.0.0.0:#{client_control_port}", + :this_port_is_insecure) + stub = Echo::EchoServer::Stub.new("localhost:#{server_port}", + :this_channel_is_insecure) + control_service = SigHandlingClientController.new(stub) + srv.handle(control_service) + server_thread = Thread.new do + srv.run_till_terminated_or_interrupted(['int']) + end + srv.wait_till_running + # send a first RPC to notify the parent process that we've started + stub.echo(Echo::EchoRequest.new(request: 'client/child started')) + server_thread.join +end + +main diff --git a/src/ruby/end2end/graceful_sig_handling_driver.rb b/src/ruby/end2end/graceful_sig_handling_driver.rb new file mode 100755 index 0000000000..e12ae28485 --- /dev/null +++ b/src/ruby/end2end/graceful_sig_handling_driver.rb @@ -0,0 +1,83 @@ +#!/usr/bin/env ruby + +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# smoke test for a grpc-using app that receives and +# handles process-ending signals + +require_relative './end2end_common' + +# A service that calls back it's received_rpc_callback +# upon receiving an RPC. Used for synchronization/waiting +# for child process to start. +class ClientStartedService < Echo::EchoServer::Service + def initialize(received_rpc_callback) + @received_rpc_callback = received_rpc_callback + end + + def echo(echo_req, _) + @received_rpc_callback.call unless @received_rpc_callback.nil? + @received_rpc_callback = nil + Echo::EchoReply.new(response: echo_req.request) + end +end + +def main + STDERR.puts 'start server' + client_started = false + client_started_mu = Mutex.new + client_started_cv = ConditionVariable.new + received_rpc_callback = proc do + client_started_mu.synchronize do + client_started = true + client_started_cv.signal + end + end + + client_started_service = ClientStartedService.new(received_rpc_callback) + server_runner = ServerRunner.new(client_started_service) + server_port = server_runner.run + STDERR.puts 'start client' + control_stub, client_pid = start_client('graceful_sig_handling_client.rb', server_port) + + client_started_mu.synchronize do + client_started_cv.wait(client_started_mu) until client_started + end + + control_stub.do_echo_rpc( + ClientControl::DoEchoRpcRequest.new(request: 'hello')) + + STDERR.puts 'killing client' + Process.kill('SIGINT', client_pid) + Process.wait(client_pid) + client_exit_status = $CHILD_STATUS + + if client_exit_status.exited? + if client_exit_status.exitstatus != 0 + STDERR.puts 'Client did not close gracefully' + exit(1) + end + else + STDERR.puts 'Client did not close gracefully' + exit(1) + end + + STDERR.puts 'Client ended gracefully' + + # no need to call cleanup, client should already be dead + server_runner.stop +end + +main diff --git a/src/ruby/end2end/graceful_sig_stop_client.rb b/src/ruby/end2end/graceful_sig_stop_client.rb new file mode 100755 index 0000000000..b672dc3f2a --- /dev/null +++ b/src/ruby/end2end/graceful_sig_stop_client.rb @@ -0,0 +1,78 @@ +#!/usr/bin/env ruby + +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +require_relative './end2end_common' + +# Test client. Sends RPC's as normal but process also has signal handlers +class SigHandlingClientController < ClientControl::ClientController::Service + def initialize(srv, stub) + @srv = srv + @stub = stub + end + + def do_echo_rpc(req, _) + response = @stub.echo(Echo::EchoRequest.new(request: req.request)) + fail 'bad response' unless response.response == req.request + ClientControl::Void.new + end + + def shutdown(_, _) + # Spawn a new thread because RpcServer#stop is + # synchronous and blocks until either this RPC has finished, + # or the server's "poll_period" seconds have passed. + @shutdown_thread = Thread.new do + @srv.stop + end + ClientControl::Void.new + end + + def join_shutdown_thread + @shutdown_thread.join + end +end + +def main + client_control_port = '' + server_port = '' + OptionParser.new do |opts| + opts.on('--client_control_port=P', String) do |p| + client_control_port = p + end + opts.on('--server_port=P', String) do |p| + server_port = p + end + end.parse! + + # The "shutdown" RPC should end very quickly. + # Allow a few seconds to be safe. + srv = new_rpc_server_for_testing(poll_period: 3) + srv.add_http2_port("0.0.0.0:#{client_control_port}", + :this_port_is_insecure) + stub = Echo::EchoServer::Stub.new("localhost:#{server_port}", + :this_channel_is_insecure) + control_service = SigHandlingClientController.new(srv, stub) + srv.handle(control_service) + server_thread = Thread.new do + srv.run_till_terminated_or_interrupted(['int']) + end + srv.wait_till_running + # send a first RPC to notify the parent process that we've started + stub.echo(Echo::EchoRequest.new(request: 'client/child started')) + server_thread.join + control_service.join_shutdown_thread +end + +main diff --git a/src/ruby/end2end/graceful_sig_stop_driver.rb b/src/ruby/end2end/graceful_sig_stop_driver.rb new file mode 100755 index 0000000000..7a132403eb --- /dev/null +++ b/src/ruby/end2end/graceful_sig_stop_driver.rb @@ -0,0 +1,62 @@ +#!/usr/bin/env ruby + +# Copyright 2016 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# smoke test for a grpc-using app that receives and +# handles process-ending signals + +require_relative './end2end_common' + +# A service that calls back it's received_rpc_callback +# upon receiving an RPC. Used for synchronization/waiting +# for child process to start. +class ClientStartedService < Echo::EchoServer::Service + def initialize(received_rpc_callback) + @received_rpc_callback = received_rpc_callback + end + + def echo(echo_req, _) + @received_rpc_callback.call unless @received_rpc_callback.nil? + @received_rpc_callback = nil + Echo::EchoReply.new(response: echo_req.request) + end +end + +def main + STDERR.puts 'start server' + client_started = false + client_started_mu = Mutex.new + client_started_cv = ConditionVariable.new + received_rpc_callback = proc do + client_started_mu.synchronize do + client_started = true + client_started_cv.signal + end + end + + client_started_service = ClientStartedService.new(received_rpc_callback) + server_runner = ServerRunner.new(client_started_service) + server_port = server_runner.run + STDERR.puts 'start client' + control_stub, client_pid = start_client('./graceful_sig_stop_client.rb', server_port) + + client_started_mu.synchronize do + client_started_cv.wait(client_started_mu) until client_started + end + + cleanup(control_stub, client_pid, server_runner) +end + +main diff --git a/src/ruby/ext/grpc/rb_grpc_imports.generated.h b/src/ruby/ext/grpc/rb_grpc_imports.generated.h index 56d222c7ec..e61a35d09f 100644 --- a/src/ruby/ext/grpc/rb_grpc_imports.generated.h +++ b/src/ruby/ext/grpc/rb_grpc_imports.generated.h @@ -272,7 +272,7 @@ extern grpc_channelz_get_servers_type grpc_channelz_get_servers_import; typedef char*(*grpc_channelz_get_server_type)(intptr_t server_id); extern grpc_channelz_get_server_type grpc_channelz_get_server_import; #define grpc_channelz_get_server grpc_channelz_get_server_import -typedef char*(*grpc_channelz_get_server_sockets_type)(intptr_t server_id, intptr_t start_socket_id); +typedef char*(*grpc_channelz_get_server_sockets_type)(intptr_t server_id, intptr_t start_socket_id, intptr_t max_results); extern grpc_channelz_get_server_sockets_type grpc_channelz_get_server_sockets_import; #define grpc_channelz_get_server_sockets grpc_channelz_get_server_sockets_import typedef char*(*grpc_channelz_get_channel_type)(intptr_t channel_id); diff --git a/src/ruby/lib/grpc/generic/rpc_server.rb b/src/ruby/lib/grpc/generic/rpc_server.rb index 3b5a0ce27f..f0f73dc56e 100644 --- a/src/ruby/lib/grpc/generic/rpc_server.rb +++ b/src/ruby/lib/grpc/generic/rpc_server.rb @@ -240,6 +240,13 @@ module GRPC # the call has no impact if the server is already stopped, otherwise # server's current call loop is it's last. def stop + # if called via run_till_terminated_or_interrupted, + # signal stop_server_thread and dont do anything + if @stop_server.nil? == false && @stop_server == false + @stop_server = true + @stop_server_cv.broadcast + return + end @run_mutex.synchronize do fail 'Cannot stop before starting' if @running_state == :not_started return if @running_state != :running @@ -354,6 +361,60 @@ module GRPC alias_method :run_till_terminated, :run + # runs the server with signal handlers + # @param signals + # List of String, Integer or both representing signals that the user + # would like to send to the server for graceful shutdown + # @param wait_interval (optional) + # Integer seconds that user would like stop_server_thread to poll + # stop_server + def run_till_terminated_or_interrupted(signals, wait_interval = 60) + @stop_server = false + @stop_server_mu = Mutex.new + @stop_server_cv = ConditionVariable.new + + @stop_server_thread = Thread.new do + loop do + break if @stop_server + @stop_server_mu.synchronize do + @stop_server_cv.wait(@stop_server_mu, wait_interval) + end + end + + # stop is surrounded by mutex, should handle multiple calls to stop + # correctly + stop + end + + valid_signals = Signal.list + + # register signal handlers + signals.each do |sig| + # input validation + if sig.class == String + sig.upcase! + if sig.start_with?('SIG') + # cut out the SIG prefix to see if valid signal + sig = sig[3..-1] + end + end + + # register signal traps for all valid signals + if valid_signals.value?(sig) || valid_signals.key?(sig) + Signal.trap(sig) do + @stop_server = true + @stop_server_cv.broadcast + end + else + fail "#{sig} not a valid signal" + end + end + + run + + @stop_server_thread.join + end + # Sends RESOURCE_EXHAUSTED if there are too many unprocessed jobs def available?(an_rpc) return an_rpc if @pool.ready_for_work? diff --git a/src/ruby/lib/grpc/generic/service.rb b/src/ruby/lib/grpc/generic/service.rb index 4764217406..169a62f11d 100644 --- a/src/ruby/lib/grpc/generic/service.rb +++ b/src/ruby/lib/grpc/generic/service.rb @@ -95,7 +95,7 @@ module GRPC rpc_descs[name] = RpcDesc.new(name, input, output, marshal_class_method, unmarshal_class_method) - define_method(GenericService.underscore(name.to_s).to_sym) do |_, _| + define_method(GenericService.underscore(name.to_s).to_sym) do |*| fail GRPC::BadStatus.new_status_exception( GRPC::Core::StatusCodes::UNIMPLEMENTED) end diff --git a/src/ruby/lib/grpc/version.rb b/src/ruby/lib/grpc/version.rb index 243d566645..a4ed052d85 100644 --- a/src/ruby/lib/grpc/version.rb +++ b/src/ruby/lib/grpc/version.rb @@ -14,5 +14,5 @@ # GRPC contains the General RPC module. module GRPC - VERSION = '1.17.0.dev' + VERSION = '1.18.0.dev' end diff --git a/src/ruby/pb/grpc/health/checker.rb b/src/ruby/pb/grpc/health/checker.rb index c492455d8f..7ad68409dd 100644 --- a/src/ruby/pb/grpc/health/checker.rb +++ b/src/ruby/pb/grpc/health/checker.rb @@ -14,7 +14,6 @@ require 'grpc' require 'grpc/health/v1/health_services_pb' -require 'thread' module Grpc # Health contains classes and modules that support providing a health check @@ -37,9 +36,9 @@ module Grpc @status_mutex.synchronize do status = @statuses["#{req.service}"] end - if status.nil? + if status.nil? fail GRPC::BadStatus.new_status_exception(StatusCodes::NOT_FOUND) - end + end HealthCheckResponse.new(status: status) end diff --git a/src/ruby/pb/test/client.rb b/src/ruby/pb/test/client.rb index b2303c6e14..03f3d9001a 100755 --- a/src/ruby/pb/test/client.rb +++ b/src/ruby/pb/test/client.rb @@ -111,10 +111,13 @@ def create_stub(opts) if opts.secure creds = ssl_creds(opts.use_test_ca) stub_opts = { - channel_args: { - GRPC::Core::Channel::SSL_TARGET => opts.server_host_override - } + channel_args: {} } + unless opts.server_host_override.empty? + stub_opts[:channel_args].merge!({ + GRPC::Core::Channel::SSL_TARGET => opts.server_host_override + }) + end # Add service account creds if specified wants_creds = %w(all compute_engine_creds service_account_creds) @@ -603,7 +606,7 @@ class NamedTests if not op.metadata.has_key?(initial_metadata_key) fail AssertionError, "Expected initial metadata. None received" elsif op.metadata[initial_metadata_key] != metadata[initial_metadata_key] - fail AssertionError, + fail AssertionError, "Expected initial metadata: #{metadata[initial_metadata_key]}. "\ "Received: #{op.metadata[initial_metadata_key]}" end @@ -611,7 +614,7 @@ class NamedTests fail AssertionError, "Expected trailing metadata. None received" elsif op.trailing_metadata[trailing_metadata_key] != metadata[trailing_metadata_key] - fail AssertionError, + fail AssertionError, "Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\ "Received: #{op.trailing_metadata[trailing_metadata_key]}" end @@ -639,7 +642,7 @@ class NamedTests fail AssertionError, "Expected trailing metadata. None received" elsif duplex_op.trailing_metadata[trailing_metadata_key] != metadata[trailing_metadata_key] - fail AssertionError, + fail AssertionError, "Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\ "Received: #{duplex_op.trailing_metadata[trailing_metadata_key]}" end @@ -710,7 +713,7 @@ Args = Struct.new(:default_service_account, :server_host, :server_host_override, # validates the command line options, returning them as a Hash. def parse_args args = Args.new - args.server_host_override = 'foo.test.google.fr' + args.server_host_override = '' OptionParser.new do |opts| opts.on('--oauth_scope scope', 'Scope for OAuth tokens') { |v| args['oauth_scope'] = v } diff --git a/src/ruby/spec/generic/rpc_server_spec.rb b/src/ruby/spec/generic/rpc_server_spec.rb index 44a6134086..924d747a79 100644 --- a/src/ruby/spec/generic/rpc_server_spec.rb +++ b/src/ruby/spec/generic/rpc_server_spec.rb @@ -342,6 +342,28 @@ describe GRPC::RpcServer do t.join end + it 'should return UNIMPLEMENTED on unimplemented ' \ + 'methods for client_streamer', server: true do + @srv.handle(EchoService) + t = Thread.new { @srv.run } + @srv.wait_till_running + blk = proc do + stub = EchoStub.new(@host, :this_channel_is_insecure, **client_opts) + requests = [EchoMsg.new, EchoMsg.new] + stub.a_client_streaming_rpc_unimplemented(requests) + end + + begin + expect(&blk).to raise_error do |error| + expect(error).to be_a(GRPC::BadStatus) + expect(error.code).to eq(GRPC::Core::StatusCodes::UNIMPLEMENTED) + end + ensure + @srv.stop # should be call not to crash + t.join + end + end + it 'should handle multiple sequential requests', server: true do @srv.handle(EchoService) t = Thread.new { @srv.run } diff --git a/src/ruby/spec/support/services.rb b/src/ruby/spec/support/services.rb index 6e693f1cde..438459dfd7 100644 --- a/src/ruby/spec/support/services.rb +++ b/src/ruby/spec/support/services.rb @@ -33,6 +33,7 @@ class EchoService rpc :a_client_streaming_rpc, stream(EchoMsg), EchoMsg rpc :a_server_streaming_rpc, EchoMsg, stream(EchoMsg) rpc :a_bidi_rpc, stream(EchoMsg), stream(EchoMsg) + rpc :a_client_streaming_rpc_unimplemented, stream(EchoMsg), EchoMsg attr_reader :received_md def initialize(**kw) diff --git a/src/ruby/tools/version.rb b/src/ruby/tools/version.rb index 92e85eb882..389fb70684 100644 --- a/src/ruby/tools/version.rb +++ b/src/ruby/tools/version.rb @@ -14,6 +14,6 @@ module GRPC module Tools - VERSION = '1.17.0.dev' + VERSION = '1.18.0.dev' end end |