aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/core/ext/filters/client_channel/client_channel.cc1035
-rw-r--r--src/core/ext/filters/client_channel/lb_policy.h13
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc83
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc10
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc8
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc8
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/xds/xds.cc8
-rw-r--r--src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc10
-rw-r--r--src/core/ext/filters/client_channel/request_routing.cc936
-rw-r--r--src/core/ext/filters/client_channel/request_routing.h177
-rw-r--r--src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc51
-rw-r--r--src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc47
-rw-r--r--src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc70
-rw-r--r--src/core/ext/filters/client_channel/resolver_result_parsing.cc14
-rw-r--r--src/core/ext/filters/client_channel/resolver_result_parsing.h30
-rw-r--r--src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc17
-rw-r--r--src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc19
-rw-r--r--src/core/ext/transport/chttp2/transport/chttp2_transport.cc7
-rw-r--r--src/core/ext/transport/chttp2/transport/context_list.cc38
-rw-r--r--src/core/ext/transport/chttp2/transport/context_list.h35
-rw-r--r--src/core/lib/gprpp/ref_counted_ptr.h8
-rw-r--r--src/core/lib/http/httpcli_security_connector.cc195
-rw-r--r--src/core/lib/http/parser.h10
-rw-r--r--src/core/lib/iomgr/resource_quota.cc1
-rw-r--r--src/core/lib/iomgr/tcp_posix.cc29
-rw-r--r--src/core/lib/security/context/security_context.cc183
-rw-r--r--src/core/lib/security/context/security_context.h94
-rw-r--r--src/core/lib/security/credentials/alts/alts_credentials.cc84
-rw-r--r--src/core/lib/security/credentials/alts/alts_credentials.h47
-rw-r--r--src/core/lib/security/credentials/composite/composite_credentials.cc254
-rw-r--r--src/core/lib/security/credentials/composite/composite_credentials.h84
-rw-r--r--src/core/lib/security/credentials/credentials.cc160
-rw-r--r--src/core/lib/security/credentials/credentials.h214
-rw-r--r--src/core/lib/security/credentials/fake/fake_credentials.cc117
-rw-r--r--src/core/lib/security/credentials/fake/fake_credentials.h28
-rw-r--r--src/core/lib/security/credentials/google_default/google_default_credentials.cc83
-rw-r--r--src/core/lib/security/credentials/google_default/google_default_credentials.h33
-rw-r--r--src/core/lib/security/credentials/iam/iam_credentials.cc62
-rw-r--r--src/core/lib/security/credentials/iam/iam_credentials.h22
-rw-r--r--src/core/lib/security/credentials/jwt/jwt_credentials.cc129
-rw-r--r--src/core/lib/security/credentials/jwt/jwt_credentials.h39
-rw-r--r--src/core/lib/security/credentials/jwt/jwt_verifier.cc2
-rw-r--r--src/core/lib/security/credentials/local/local_credentials.cc51
-rw-r--r--src/core/lib/security/credentials/local/local_credentials.h43
-rw-r--r--src/core/lib/security/credentials/oauth2/oauth2_credentials.cc279
-rw-r--r--src/core/lib/security/credentials/oauth2/oauth2_credentials.h103
-rw-r--r--src/core/lib/security/credentials/plugin/plugin_credentials.cc136
-rw-r--r--src/core/lib/security/credentials/plugin/plugin_credentials.h57
-rw-r--r--src/core/lib/security/credentials/ssl/ssl_credentials.cc149
-rw-r--r--src/core/lib/security/credentials/ssl/ssl_credentials.h73
-rw-r--r--src/core/lib/security/security_connector/alts/alts_security_connector.cc329
-rw-r--r--src/core/lib/security/security_connector/alts/alts_security_connector.h22
-rw-r--r--src/core/lib/security/security_connector/fake/fake_security_connector.cc425
-rw-r--r--src/core/lib/security/security_connector/fake/fake_security_connector.h15
-rw-r--r--src/core/lib/security/security_connector/local/local_security_connector.cc345
-rw-r--r--src/core/lib/security/security_connector/local/local_security_connector.h19
-rw-r--r--src/core/lib/security/security_connector/security_connector.cc165
-rw-r--r--src/core/lib/security/security_connector/security_connector.h207
-rw-r--r--src/core/lib/security/security_connector/ssl/ssl_security_connector.cc718
-rw-r--r--src/core/lib/security/security_connector/ssl/ssl_security_connector.h26
-rw-r--r--src/core/lib/security/security_connector/ssl_utils.cc22
-rw-r--r--src/core/lib/security/security_connector/ssl_utils.h4
-rw-r--r--src/core/lib/security/transport/client_auth_filter.cc100
-rw-r--r--src/core/lib/security/transport/security_handshaker.cc148
-rw-r--r--src/core/lib/security/transport/server_auth_filter.cc28
-rw-r--r--src/core/lib/surface/server.cc104
-rw-r--r--src/core/lib/surface/version.cc2
-rw-r--r--src/core/lib/transport/metadata.cc1
-rw-r--r--src/core/tsi/ssl_transport_security.cc59
-rw-r--r--src/cpp/client/secure_credentials.cc6
-rw-r--r--src/cpp/client/secure_credentials.h9
-rw-r--r--src/cpp/common/alarm.cc13
-rw-r--r--src/cpp/common/channel_arguments.cc2
-rw-r--r--src/cpp/common/secure_auth_context.cc38
-rw-r--r--src/cpp/common/secure_auth_context.h11
-rw-r--r--src/cpp/common/secure_create_auth_context.cc5
-rw-r--r--src/cpp/common/version_cc.cc2
-rw-r--r--src/cpp/ext/filters/census/context.cc6
-rw-r--r--src/cpp/server/secure_server_credentials.cc2
-rw-r--r--src/cpp/server/server_cc.cc4
-rwxr-xr-xsrc/csharp/Grpc.Core/Version.csproj.include2
-rw-r--r--src/csharp/Grpc.Core/VersionInfo.cs4
-rw-r--r--src/csharp/Grpc.IntegrationTesting/InteropClient.cs2
-rw-r--r--src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets3
-rw-r--r--src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets3
-rwxr-xr-xsrc/csharp/build_packages_dotnetcli.bat2
-rw-r--r--src/csharp/build_unitypackage.bat2
-rw-r--r--src/objective-c/!ProtoCompiler-gRPCPlugin.podspec2
-rw-r--r--src/objective-c/GRPCClient/private/version.h2
-rw-r--r--src/objective-c/README.md9
-rw-r--r--src/objective-c/tests/version.h2
-rw-r--r--src/php/composer.json2
-rw-r--r--src/php/ext/grpc/version.h2
-rwxr-xr-xsrc/php/tests/interop/interop_client.php8
-rw-r--r--src/python/grpcio/grpc/__init__.py41
-rw-r--r--src/python/grpcio/grpc/_auth.py2
-rw-r--r--src/python/grpcio/grpc/_channel.py54
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi13
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi37
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi9
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi2
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi1
-rw-r--r--src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi19
-rw-r--r--src/python/grpcio/grpc/_grpcio_metadata.py2
-rw-r--r--src/python/grpcio/grpc/_server.py110
-rw-r--r--src/python/grpcio/grpc/_utilities.py3
-rw-r--r--src/python/grpcio/grpc_core_dependencies.py1
-rw-r--r--src/python/grpcio/grpc_version.py2
-rw-r--r--src/python/grpcio_channelz/grpc_version.py2
-rw-r--r--src/python/grpcio_health_checking/grpc_health/v1/health.py80
-rw-r--r--src/python/grpcio_health_checking/grpc_version.py2
-rw-r--r--src/python/grpcio_reflection/grpc_version.py2
-rw-r--r--src/python/grpcio_status/.gitignore3
-rw-r--r--src/python/grpcio_status/MANIFEST.in4
-rw-r--r--src/python/grpcio_status/README.rst9
-rw-r--r--src/python/grpcio_status/grpc_status/BUILD.bazel14
-rw-r--r--src/python/grpcio_status/grpc_status/__init__.py13
-rw-r--r--src/python/grpcio_status/grpc_status/rpc_status.py92
-rw-r--r--src/python/grpcio_status/grpc_version.py17
-rw-r--r--src/python/grpcio_status/setup.py93
-rw-r--r--src/python/grpcio_status/status_commands.py39
-rw-r--r--src/python/grpcio_testing/grpc_testing/_server/_handler.py2
-rw-r--r--src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py3
-rw-r--r--src/python/grpcio_testing/grpc_version.py2
-rw-r--r--src/python/grpcio_tests/commands.py11
-rw-r--r--src/python/grpcio_tests/grpc_version.py2
-rw-r--r--src/python/grpcio_tests/setup.py13
-rw-r--r--src/python/grpcio_tests/tests/_runner.py2
-rw-r--r--src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py3
-rw-r--r--src/python/grpcio_tests/tests/health_check/BUILD.bazel1
-rw-r--r--src/python/grpcio_tests/tests/health_check/_health_servicer_test.py187
-rw-r--r--src/python/grpcio_tests/tests/interop/client.py12
-rw-r--r--src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py8
-rw-r--r--src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py2
-rw-r--r--src/python/grpcio_tests/tests/qps/benchmark_client.py2
-rw-r--r--src/python/grpcio_tests/tests/qps/client_runner.py2
-rw-r--r--src/python/grpcio_tests/tests/qps/worker_server.py2
-rw-r--r--src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py8
-rw-r--r--src/python/grpcio_tests/tests/status/BUILD.bazel19
-rw-r--r--src/python/grpcio_tests/tests/status/__init__.py13
-rw-r--r--src/python/grpcio_tests/tests/status/_grpc_status_test.py173
-rw-r--r--src/python/grpcio_tests/tests/stress/client.py5
-rw-r--r--src/python/grpcio_tests/tests/testing/_client_application.py4
-rw-r--r--src/python/grpcio_tests/tests/tests.json4
-rw-r--r--src/python/grpcio_tests/tests/unit/BUILD.bazel9
-rw-r--r--src/python/grpcio_tests/tests/unit/_abort_test.py124
-rw-r--r--src/python/grpcio_tests/tests/unit/_api_test.py2
-rw-r--r--src/python/grpcio_tests/tests/unit/_auth_context_test.py6
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py6
-rw-r--r--src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py5
-rw-r--r--src/python/grpcio_tests/tests/unit/_compression_test.py5
-rw-r--r--src/python/grpcio_tests/tests/unit/_cython/_fork_test.py4
-rw-r--r--src/python/grpcio_tests/tests/unit/_empty_message_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_exit_test.py15
-rw-r--r--src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py2
-rw-r--r--src/python/grpcio_tests/tests/unit/_interceptor_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py3
-rw-r--r--src/python/grpcio_tests/tests/unit/_invocation_defects_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_logging_test.py104
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py14
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_flags_test.py29
-rw-r--r--src/python/grpcio_tests/tests/unit/_metadata_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_reconnect_test.py2
-rw-r--r--src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_rpc_test.py1
-rw-r--r--src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py97
-rw-r--r--src/python/grpcio_tests/tests/unit/_server_shutdown_test.py90
-rw-r--r--src/python/grpcio_tests/tests/unit/_version_test.py30
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_handling_client.rb61
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_handling_driver.rb83
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_stop_client.rb78
-rwxr-xr-xsrc/ruby/end2end/graceful_sig_stop_driver.rb62
-rw-r--r--src/ruby/lib/grpc/generic/rpc_server.rb61
-rw-r--r--src/ruby/lib/grpc/version.rb2
-rwxr-xr-xsrc/ruby/pb/test/client.rb17
-rw-r--r--src/ruby/tools/version.rb2
177 files changed, 6133 insertions, 4200 deletions
diff --git a/src/core/ext/filters/client_channel/client_channel.cc b/src/core/ext/filters/client_channel/client_channel.cc
index 70aac47231..dd741f1e2d 100644
--- a/src/core/ext/filters/client_channel/client_channel.cc
+++ b/src/core/ext/filters/client_channel/client_channel.cc
@@ -35,10 +35,10 @@
#include "src/core/ext/filters/client_channel/http_connect_handshaker.h"
#include "src/core/ext/filters/client_channel/lb_policy_registry.h"
#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h"
+#include "src/core/ext/filters/client_channel/request_routing.h"
#include "src/core/ext/filters/client_channel/resolver_registry.h"
#include "src/core/ext/filters/client_channel/resolver_result_parsing.h"
#include "src/core/ext/filters/client_channel/retry_throttle.h"
-#include "src/core/ext/filters/client_channel/server_address.h"
#include "src/core/ext/filters/client_channel/subchannel.h"
#include "src/core/ext/filters/deadline/deadline_filter.h"
#include "src/core/lib/backoff/backoff.h"
@@ -63,7 +63,6 @@
#include "src/core/lib/transport/static_metadata.h"
#include "src/core/lib/transport/status_metadata.h"
-using grpc_core::ServerAddressList;
using grpc_core::internal::ClientChannelMethodParams;
using grpc_core::internal::ClientChannelMethodParamsTable;
using grpc_core::internal::ProcessedResolverResult;
@@ -88,31 +87,18 @@ grpc_core::TraceFlag grpc_client_channel_trace(false, "client_channel");
struct external_connectivity_watcher;
typedef struct client_channel_channel_data {
- grpc_core::OrphanablePtr<grpc_core::Resolver> resolver;
- bool started_resolving;
+ grpc_core::ManualConstructor<grpc_core::RequestRouter> request_router;
+
bool deadline_checking_enabled;
- grpc_client_channel_factory* client_channel_factory;
bool enable_retries;
size_t per_rpc_retry_buffer_size;
/** combiner protecting all variables below in this data structure */
grpc_combiner* combiner;
- /** currently active load balancer */
- grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> lb_policy;
/** retry throttle data */
grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data;
/** maps method names to method_parameters structs */
grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> method_params_table;
- /** incoming resolver result - set by resolver.next() */
- grpc_channel_args* resolver_result;
- /** a list of closures that are all waiting for resolver result to come in */
- grpc_closure_list waiting_for_resolver_result_closures;
- /** resolver callback */
- grpc_closure on_resolver_result_changed;
- /** connectivity state being tracked */
- grpc_connectivity_state_tracker state_tracker;
- /** when an lb_policy arrives, should we try to exit idle */
- bool exit_idle_when_lb_policy_arrives;
/** owning stack */
grpc_channel_stack* owning_stack;
/** interested parties (owned) */
@@ -129,418 +115,40 @@ typedef struct client_channel_channel_data {
grpc_core::UniquePtr<char> info_lb_policy_name;
/** service config in JSON form */
grpc_core::UniquePtr<char> info_service_config_json;
- /* backpointer to grpc_channel's channelz node */
- grpc_core::channelz::ClientChannelNode* channelz_channel;
- /* caches if the last resolution event contained addresses */
- bool previous_resolution_contained_addresses;
} channel_data;
-typedef struct {
- channel_data* chand;
- /** used as an identifier, don't dereference it because the LB policy may be
- * non-existing when the callback is run */
- grpc_core::LoadBalancingPolicy* lb_policy;
- grpc_closure closure;
-} reresolution_request_args;
-
-/** We create one watcher for each new lb_policy that is returned from a
- resolver, to watch for state changes from the lb_policy. When a state
- change is seen, we update the channel, and create a new watcher. */
-typedef struct {
- channel_data* chand;
- grpc_closure on_changed;
- grpc_connectivity_state state;
- grpc_core::LoadBalancingPolicy* lb_policy;
-} lb_policy_connectivity_watcher;
-
-static void watch_lb_policy_locked(channel_data* chand,
- grpc_core::LoadBalancingPolicy* lb_policy,
- grpc_connectivity_state current_state);
-
-static const char* channel_connectivity_state_change_string(
- grpc_connectivity_state state) {
- switch (state) {
- case GRPC_CHANNEL_IDLE:
- return "Channel state change to IDLE";
- case GRPC_CHANNEL_CONNECTING:
- return "Channel state change to CONNECTING";
- case GRPC_CHANNEL_READY:
- return "Channel state change to READY";
- case GRPC_CHANNEL_TRANSIENT_FAILURE:
- return "Channel state change to TRANSIENT_FAILURE";
- case GRPC_CHANNEL_SHUTDOWN:
- return "Channel state change to SHUTDOWN";
- }
- GPR_UNREACHABLE_CODE(return "UNKNOWN");
-}
-
-static void set_channel_connectivity_state_locked(channel_data* chand,
- grpc_connectivity_state state,
- grpc_error* error,
- const char* reason) {
- /* TODO: Improve failure handling:
- * - Make it possible for policies to return GRPC_CHANNEL_TRANSIENT_FAILURE.
- * - Hand over pending picks from old policies during the switch that happens
- * when resolver provides an update. */
- if (chand->lb_policy != nullptr) {
- if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
- /* cancel picks with wait_for_ready=false */
- chand->lb_policy->CancelMatchingPicksLocked(
- /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY,
- /* check= */ 0, GRPC_ERROR_REF(error));
- } else if (state == GRPC_CHANNEL_SHUTDOWN) {
- /* cancel all picks */
- chand->lb_policy->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0,
- GRPC_ERROR_REF(error));
- }
- }
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: setting connectivity state to %s", chand,
- grpc_connectivity_state_name(state));
- }
- if (chand->channelz_channel != nullptr) {
- chand->channelz_channel->AddTraceEvent(
- grpc_core::channelz::ChannelTrace::Severity::Info,
- grpc_slice_from_static_string(
- channel_connectivity_state_change_string(state)));
- }
- grpc_connectivity_state_set(&chand->state_tracker, state, error, reason);
-}
-
-static void on_lb_policy_state_changed_locked(void* arg, grpc_error* error) {
- lb_policy_connectivity_watcher* w =
- static_cast<lb_policy_connectivity_watcher*>(arg);
- /* check if the notification is for the latest policy */
- if (w->lb_policy == w->chand->lb_policy.get()) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: lb_policy=%p state changed to %s", w->chand,
- w->lb_policy, grpc_connectivity_state_name(w->state));
- }
- set_channel_connectivity_state_locked(w->chand, w->state,
- GRPC_ERROR_REF(error), "lb_changed");
- if (w->state != GRPC_CHANNEL_SHUTDOWN) {
- watch_lb_policy_locked(w->chand, w->lb_policy, w->state);
- }
- }
- GRPC_CHANNEL_STACK_UNREF(w->chand->owning_stack, "watch_lb_policy");
- gpr_free(w);
-}
-
-static void watch_lb_policy_locked(channel_data* chand,
- grpc_core::LoadBalancingPolicy* lb_policy,
- grpc_connectivity_state current_state) {
- lb_policy_connectivity_watcher* w =
- static_cast<lb_policy_connectivity_watcher*>(gpr_malloc(sizeof(*w)));
- GRPC_CHANNEL_STACK_REF(chand->owning_stack, "watch_lb_policy");
- w->chand = chand;
- GRPC_CLOSURE_INIT(&w->on_changed, on_lb_policy_state_changed_locked, w,
- grpc_combiner_scheduler(chand->combiner));
- w->state = current_state;
- w->lb_policy = lb_policy;
- lb_policy->NotifyOnStateChangeLocked(&w->state, &w->on_changed);
-}
-
-static void start_resolving_locked(channel_data* chand) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: starting name resolution", chand);
- }
- GPR_ASSERT(!chand->started_resolving);
- chand->started_resolving = true;
- GRPC_CHANNEL_STACK_REF(chand->owning_stack, "resolver");
- chand->resolver->NextLocked(&chand->resolver_result,
- &chand->on_resolver_result_changed);
-}
-
-// Invoked from the resolver NextLocked() callback when the resolver
-// is shutting down.
-static void on_resolver_shutdown_locked(channel_data* chand,
- grpc_error* error) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: shutting down", chand);
- }
- if (chand->lb_policy != nullptr) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand,
- chand->lb_policy.get());
- }
- grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(),
- chand->interested_parties);
- chand->lb_policy.reset();
- }
- if (chand->resolver != nullptr) {
- // This should never happen; it can only be triggered by a resolver
- // implementation spotaneously deciding to report shutdown without
- // being orphaned. This code is included just to be defensive.
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: spontaneous shutdown from resolver %p",
- chand, chand->resolver.get());
- }
- chand->resolver.reset();
- set_channel_connectivity_state_locked(
- chand, GRPC_CHANNEL_SHUTDOWN,
- GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
- "Resolver spontaneous shutdown", &error, 1),
- "resolver_spontaneous_shutdown");
- }
- grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures,
- GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
- "Channel disconnected", &error, 1));
- GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures);
- GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "resolver");
- grpc_channel_args_destroy(chand->resolver_result);
- chand->resolver_result = nullptr;
- GRPC_ERROR_UNREF(error);
-}
-
-static void request_reresolution_locked(void* arg, grpc_error* error) {
- reresolution_request_args* args =
- static_cast<reresolution_request_args*>(arg);
- channel_data* chand = args->chand;
- // If this invocation is for a stale LB policy, treat it as an LB shutdown
- // signal.
- if (args->lb_policy != chand->lb_policy.get() || error != GRPC_ERROR_NONE ||
- chand->resolver == nullptr) {
- GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "re-resolution");
- gpr_free(args);
- return;
- }
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: started name re-resolving", chand);
- }
- chand->resolver->RequestReresolutionLocked();
- // Give back the closure to the LB policy.
- chand->lb_policy->SetReresolutionClosureLocked(&args->closure);
-}
-
-using TraceStringVector = grpc_core::InlinedVector<char*, 3>;
-
-// Creates a new LB policy, replacing any previous one.
-// If the new policy is created successfully, sets *connectivity_state and
-// *connectivity_error to its initial connectivity state; otherwise,
-// leaves them unchanged.
-static void create_new_lb_policy_locked(
- channel_data* chand, char* lb_policy_name, grpc_json* lb_config,
- grpc_connectivity_state* connectivity_state,
- grpc_error** connectivity_error, TraceStringVector* trace_strings) {
- grpc_core::LoadBalancingPolicy::Args lb_policy_args;
- lb_policy_args.combiner = chand->combiner;
- lb_policy_args.client_channel_factory = chand->client_channel_factory;
- lb_policy_args.args = chand->resolver_result;
- lb_policy_args.lb_config = lb_config;
- grpc_core::OrphanablePtr<grpc_core::LoadBalancingPolicy> new_lb_policy =
- grpc_core::LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(
- lb_policy_name, lb_policy_args);
- if (GPR_UNLIKELY(new_lb_policy == nullptr)) {
- gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name);
- if (chand->channelz_channel != nullptr) {
- char* str;
- gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name);
- trace_strings->push_back(str);
- }
- } else {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: created new LB policy \"%s\" (%p)", chand,
- lb_policy_name, new_lb_policy.get());
- }
- if (chand->channelz_channel != nullptr) {
- char* str;
- gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name);
- trace_strings->push_back(str);
- }
- // Swap out the LB policy and update the fds in
- // chand->interested_parties.
- if (chand->lb_policy != nullptr) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: shutting down lb_policy=%p", chand,
- chand->lb_policy.get());
- }
- grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(),
- chand->interested_parties);
- chand->lb_policy->HandOffPendingPicksLocked(new_lb_policy.get());
- }
- chand->lb_policy = std::move(new_lb_policy);
- grpc_pollset_set_add_pollset_set(chand->lb_policy->interested_parties(),
- chand->interested_parties);
- // Set up re-resolution callback.
- reresolution_request_args* args =
- static_cast<reresolution_request_args*>(gpr_zalloc(sizeof(*args)));
- args->chand = chand;
- args->lb_policy = chand->lb_policy.get();
- GRPC_CLOSURE_INIT(&args->closure, request_reresolution_locked, args,
- grpc_combiner_scheduler(chand->combiner));
- GRPC_CHANNEL_STACK_REF(chand->owning_stack, "re-resolution");
- chand->lb_policy->SetReresolutionClosureLocked(&args->closure);
- // Get the new LB policy's initial connectivity state and start a
- // connectivity watch.
- GRPC_ERROR_UNREF(*connectivity_error);
- *connectivity_state =
- chand->lb_policy->CheckConnectivityLocked(connectivity_error);
- if (chand->exit_idle_when_lb_policy_arrives) {
- chand->lb_policy->ExitIdleLocked();
- chand->exit_idle_when_lb_policy_arrives = false;
- }
- watch_lb_policy_locked(chand, chand->lb_policy.get(), *connectivity_state);
- }
-}
-
-static void maybe_add_trace_message_for_address_changes_locked(
- channel_data* chand, TraceStringVector* trace_strings) {
- const ServerAddressList* addresses =
- grpc_core::FindServerAddressListChannelArg(chand->resolver_result);
- const bool resolution_contains_addresses =
- addresses != nullptr && addresses->size() > 0;
- if (!resolution_contains_addresses &&
- chand->previous_resolution_contained_addresses) {
- trace_strings->push_back(gpr_strdup("Address list became empty"));
- } else if (resolution_contains_addresses &&
- !chand->previous_resolution_contained_addresses) {
- trace_strings->push_back(gpr_strdup("Address list became non-empty"));
- }
- chand->previous_resolution_contained_addresses =
- resolution_contains_addresses;
-}
-
-static void concatenate_and_add_channel_trace_locked(
- channel_data* chand, TraceStringVector* trace_strings) {
- if (!trace_strings->empty()) {
- gpr_strvec v;
- gpr_strvec_init(&v);
- gpr_strvec_add(&v, gpr_strdup("Resolution event: "));
- bool is_first = 1;
- for (size_t i = 0; i < trace_strings->size(); ++i) {
- if (!is_first) gpr_strvec_add(&v, gpr_strdup(", "));
- is_first = false;
- gpr_strvec_add(&v, (*trace_strings)[i]);
- }
- char* flat;
- size_t flat_len = 0;
- flat = gpr_strvec_flatten(&v, &flat_len);
- chand->channelz_channel->AddTraceEvent(
- grpc_core::channelz::ChannelTrace::Severity::Info,
- grpc_slice_new(flat, flat_len, gpr_free));
- gpr_strvec_destroy(&v);
- }
-}
-
-// Callback invoked when a resolver result is available.
-static void on_resolver_result_changed_locked(void* arg, grpc_error* error) {
+// Synchronous callback from chand->request_router to process a resolver
+// result update.
+static bool process_resolver_result_locked(void* arg,
+ const grpc_channel_args& args,
+ const char** lb_policy_name,
+ grpc_json** lb_policy_config) {
channel_data* chand = static_cast<channel_data*>(arg);
+ ProcessedResolverResult resolver_result(args, chand->enable_retries);
+ grpc_core::UniquePtr<char> service_config_json =
+ resolver_result.service_config_json();
if (grpc_client_channel_trace.enabled()) {
- const char* disposition =
- chand->resolver_result != nullptr
- ? ""
- : (error == GRPC_ERROR_NONE ? " (transient error)"
- : " (resolver shutdown)");
- gpr_log(GPR_INFO,
- "chand=%p: got resolver result: resolver_result=%p error=%s%s",
- chand, chand->resolver_result, grpc_error_string(error),
- disposition);
+ gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"",
+ chand, service_config_json.get());
}
- // Handle shutdown.
- if (error != GRPC_ERROR_NONE || chand->resolver == nullptr) {
- on_resolver_shutdown_locked(chand, GRPC_ERROR_REF(error));
- return;
- }
- // Data used to set the channel's connectivity state.
- bool set_connectivity_state = true;
- // We only want to trace the address resolution in the follow cases:
- // (a) Address resolution resulted in service config change.
- // (b) Address resolution that causes number of backends to go from
- // zero to non-zero.
- // (c) Address resolution that causes number of backends to go from
- // non-zero to zero.
- // (d) Address resolution that causes a new LB policy to be created.
- //
- // we track a list of strings to eventually be concatenated and traced.
- TraceStringVector trace_strings;
- grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE;
- grpc_error* connectivity_error =
- GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy");
- // chand->resolver_result will be null in the case of a transient
- // resolution error. In that case, we don't have any new result to
- // process, which means that we keep using the previous result (if any).
- if (chand->resolver_result == nullptr) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: resolver transient failure", chand);
- }
- // Don't override connectivity state if we already have an LB policy.
- if (chand->lb_policy != nullptr) set_connectivity_state = false;
- } else {
- // Parse the resolver result.
- ProcessedResolverResult resolver_result(chand->resolver_result,
- chand->enable_retries);
- chand->retry_throttle_data = resolver_result.retry_throttle_data();
- chand->method_params_table = resolver_result.method_params_table();
- grpc_core::UniquePtr<char> service_config_json =
- resolver_result.service_config_json();
- if (service_config_json != nullptr && grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: resolver returned service config: \"%s\"",
- chand, service_config_json.get());
- }
- grpc_core::UniquePtr<char> lb_policy_name =
- resolver_result.lb_policy_name();
- grpc_json* lb_policy_config = resolver_result.lb_policy_config();
- // Check to see if we're already using the right LB policy.
- // Note: It's safe to use chand->info_lb_policy_name here without
- // taking a lock on chand->info_mu, because this function is the
- // only thing that modifies its value, and it can only be invoked
- // once at any given time.
- bool lb_policy_name_changed =
- chand->info_lb_policy_name == nullptr ||
- strcmp(chand->info_lb_policy_name.get(), lb_policy_name.get()) != 0;
- if (chand->lb_policy != nullptr && !lb_policy_name_changed) {
- // Continue using the same LB policy. Update with new addresses.
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p: updating existing LB policy \"%s\" (%p)",
- chand, lb_policy_name.get(), chand->lb_policy.get());
- }
- chand->lb_policy->UpdateLocked(*chand->resolver_result, lb_policy_config);
- // No need to set the channel's connectivity state; the existing
- // watch on the LB policy will take care of that.
- set_connectivity_state = false;
- } else {
- // Instantiate new LB policy.
- create_new_lb_policy_locked(chand, lb_policy_name.get(), lb_policy_config,
- &connectivity_state, &connectivity_error,
- &trace_strings);
- }
- // Note: It's safe to use chand->info_service_config_json here without
- // taking a lock on chand->info_mu, because this function is the
- // only thing that modifies its value, and it can only be invoked
- // once at any given time.
- if (chand->channelz_channel != nullptr) {
- if (((service_config_json == nullptr) !=
- (chand->info_service_config_json == nullptr)) ||
- (service_config_json != nullptr &&
- strcmp(service_config_json.get(),
- chand->info_service_config_json.get()) != 0)) {
- // TODO(ncteisen): might be worth somehow including a snippet of the
- // config in the trace, at the risk of bloating the trace logs.
- trace_strings.push_back(gpr_strdup("Service config changed"));
- }
- maybe_add_trace_message_for_address_changes_locked(chand, &trace_strings);
- concatenate_and_add_channel_trace_locked(chand, &trace_strings);
- }
- // Swap out the data used by cc_get_channel_info().
- gpr_mu_lock(&chand->info_mu);
- chand->info_lb_policy_name = std::move(lb_policy_name);
- chand->info_service_config_json = std::move(service_config_json);
- gpr_mu_unlock(&chand->info_mu);
- // Clean up.
- grpc_channel_args_destroy(chand->resolver_result);
- chand->resolver_result = nullptr;
- }
- // Set the channel's connectivity state if needed.
- if (set_connectivity_state) {
- set_channel_connectivity_state_locked(
- chand, connectivity_state, connectivity_error, "resolver_result");
- } else {
- GRPC_ERROR_UNREF(connectivity_error);
- }
- // Invoke closures that were waiting for results and renew the watch.
- GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures);
- chand->resolver->NextLocked(&chand->resolver_result,
- &chand->on_resolver_result_changed);
+ // Update channel state.
+ chand->retry_throttle_data = resolver_result.retry_throttle_data();
+ chand->method_params_table = resolver_result.method_params_table();
+ // Swap out the data used by cc_get_channel_info().
+ gpr_mu_lock(&chand->info_mu);
+ chand->info_lb_policy_name = resolver_result.lb_policy_name();
+ const bool service_config_changed =
+ ((service_config_json == nullptr) !=
+ (chand->info_service_config_json == nullptr)) ||
+ (service_config_json != nullptr &&
+ strcmp(service_config_json.get(),
+ chand->info_service_config_json.get()) != 0);
+ chand->info_service_config_json = std::move(service_config_json);
+ gpr_mu_unlock(&chand->info_mu);
+ // Return results.
+ *lb_policy_name = chand->info_lb_policy_name.get();
+ *lb_policy_config = resolver_result.lb_policy_config();
+ return service_config_changed;
}
static void start_transport_op_locked(void* arg, grpc_error* error_ignored) {
@@ -550,15 +158,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
if (op->on_connectivity_state_change != nullptr) {
- grpc_connectivity_state_notify_on_state_change(
- &chand->state_tracker, op->connectivity_state,
- op->on_connectivity_state_change);
+ chand->request_router->NotifyOnConnectivityStateChange(
+ op->connectivity_state, op->on_connectivity_state_change);
op->on_connectivity_state_change = nullptr;
op->connectivity_state = nullptr;
}
if (op->send_ping.on_initiate != nullptr || op->send_ping.on_ack != nullptr) {
- if (chand->lb_policy == nullptr) {
+ if (chand->request_router->lb_policy() == nullptr) {
grpc_error* error =
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Ping with no load balancing");
GRPC_CLOSURE_SCHED(op->send_ping.on_initiate, GRPC_ERROR_REF(error));
@@ -567,7 +174,8 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) {
grpc_error* error = GRPC_ERROR_NONE;
grpc_core::LoadBalancingPolicy::PickState pick_state;
// Pick must return synchronously, because pick_state.on_complete is null.
- GPR_ASSERT(chand->lb_policy->PickLocked(&pick_state, &error));
+ GPR_ASSERT(
+ chand->request_router->lb_policy()->PickLocked(&pick_state, &error));
if (pick_state.connected_subchannel != nullptr) {
pick_state.connected_subchannel->Ping(op->send_ping.on_initiate,
op->send_ping.on_ack);
@@ -586,37 +194,14 @@ static void start_transport_op_locked(void* arg, grpc_error* error_ignored) {
}
if (op->disconnect_with_error != GRPC_ERROR_NONE) {
- if (chand->resolver != nullptr) {
- set_channel_connectivity_state_locked(
- chand, GRPC_CHANNEL_SHUTDOWN,
- GRPC_ERROR_REF(op->disconnect_with_error), "disconnect");
- chand->resolver.reset();
- if (!chand->started_resolving) {
- grpc_closure_list_fail_all(&chand->waiting_for_resolver_result_closures,
- GRPC_ERROR_REF(op->disconnect_with_error));
- GRPC_CLOSURE_LIST_SCHED(&chand->waiting_for_resolver_result_closures);
- }
- if (chand->lb_policy != nullptr) {
- grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(),
- chand->interested_parties);
- chand->lb_policy.reset();
- }
- }
- GRPC_ERROR_UNREF(op->disconnect_with_error);
+ chand->request_router->ShutdownLocked(op->disconnect_with_error);
}
if (op->reset_connect_backoff) {
- if (chand->resolver != nullptr) {
- chand->resolver->ResetBackoffLocked();
- chand->resolver->RequestReresolutionLocked();
- }
- if (chand->lb_policy != nullptr) {
- chand->lb_policy->ResetBackoffLocked();
- }
+ chand->request_router->ResetConnectionBackoffLocked();
}
GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "start_transport_op");
-
GRPC_CLOSURE_SCHED(op->on_consumed, GRPC_ERROR_NONE);
}
@@ -667,12 +252,9 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem,
gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu);
chand->owning_stack = args->channel_stack;
- GRPC_CLOSURE_INIT(&chand->on_resolver_result_changed,
- on_resolver_result_changed_locked, chand,
- grpc_combiner_scheduler(chand->combiner));
+ chand->deadline_checking_enabled =
+ grpc_deadline_checking_enabled(args->channel_args);
chand->interested_parties = grpc_pollset_set_create();
- grpc_connectivity_state_init(&chand->state_tracker, GRPC_CHANNEL_IDLE,
- "client_channel");
grpc_client_channel_start_backup_polling(chand->interested_parties);
// Record max per-RPC retry buffer size.
const grpc_arg* arg = grpc_channel_args_find(
@@ -682,8 +264,6 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem,
// Record enable_retries.
arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_ENABLE_RETRIES);
chand->enable_retries = grpc_channel_arg_get_bool(arg, true);
- chand->channelz_channel = nullptr;
- chand->previous_resolution_contained_addresses = false;
// Record client channel factory.
arg = grpc_channel_args_find(args->channel_args,
GRPC_ARG_CLIENT_CHANNEL_FACTORY);
@@ -695,9 +275,7 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem,
return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"client channel factory arg must be a pointer");
}
- grpc_client_channel_factory_ref(
- static_cast<grpc_client_channel_factory*>(arg->value.pointer.p));
- chand->client_channel_factory =
+ grpc_client_channel_factory* client_channel_factory =
static_cast<grpc_client_channel_factory*>(arg->value.pointer.p);
// Get server name to resolve, using proxy mapper if needed.
arg = grpc_channel_args_find(args->channel_args, GRPC_ARG_SERVER_URI);
@@ -713,39 +291,24 @@ static grpc_error* cc_init_channel_elem(grpc_channel_element* elem,
grpc_channel_args* new_args = nullptr;
grpc_proxy_mappers_map_name(arg->value.string, args->channel_args,
&proxy_name, &new_args);
- // Instantiate resolver.
- chand->resolver = grpc_core::ResolverRegistry::CreateResolver(
- proxy_name != nullptr ? proxy_name : arg->value.string,
- new_args != nullptr ? new_args : args->channel_args,
- chand->interested_parties, chand->combiner);
- if (proxy_name != nullptr) gpr_free(proxy_name);
- if (new_args != nullptr) grpc_channel_args_destroy(new_args);
- if (chand->resolver == nullptr) {
- return GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed");
- }
- chand->deadline_checking_enabled =
- grpc_deadline_checking_enabled(args->channel_args);
- return GRPC_ERROR_NONE;
+ // Instantiate request router.
+ grpc_client_channel_factory_ref(client_channel_factory);
+ grpc_error* error = GRPC_ERROR_NONE;
+ chand->request_router.Init(
+ chand->owning_stack, chand->combiner, client_channel_factory,
+ chand->interested_parties, &grpc_client_channel_trace,
+ process_resolver_result_locked, chand,
+ proxy_name != nullptr ? proxy_name : arg->value.string /* target_uri */,
+ new_args != nullptr ? new_args : args->channel_args, &error);
+ gpr_free(proxy_name);
+ grpc_channel_args_destroy(new_args);
+ return error;
}
/* Destructor for channel_data */
static void cc_destroy_channel_elem(grpc_channel_element* elem) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- if (chand->resolver != nullptr) {
- // The only way we can get here is if we never started resolving,
- // because we take a ref to the channel stack when we start
- // resolving and do not release it until the resolver callback is
- // invoked after the resolver shuts down.
- chand->resolver.reset();
- }
- if (chand->client_channel_factory != nullptr) {
- grpc_client_channel_factory_unref(chand->client_channel_factory);
- }
- if (chand->lb_policy != nullptr) {
- grpc_pollset_set_del_pollset_set(chand->lb_policy->interested_parties(),
- chand->interested_parties);
- chand->lb_policy.reset();
- }
+ chand->request_router.Destroy();
// TODO(roth): Once we convert the filter API to C++, there will no
// longer be any need to explicitly reset these smart pointer data members.
chand->info_lb_policy_name.reset();
@@ -753,7 +316,6 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) {
chand->retry_throttle_data.reset();
chand->method_params_table.reset();
grpc_client_channel_stop_backup_polling(chand->interested_parties);
- grpc_connectivity_state_destroy(&chand->state_tracker);
grpc_pollset_set_destroy(chand->interested_parties);
GRPC_COMBINER_UNREF(chand->combiner, "client_channel");
gpr_mu_destroy(&chand->info_mu);
@@ -810,6 +372,7 @@ static void cc_destroy_channel_elem(grpc_channel_element* elem) {
// - add census stats for retries
namespace {
+
struct call_data;
// State used for starting a retryable batch on a subchannel call.
@@ -894,12 +457,12 @@ struct subchannel_call_retry_state {
bool completed_recv_initial_metadata : 1;
bool started_recv_trailing_metadata : 1;
bool completed_recv_trailing_metadata : 1;
+ // State for callback processing.
subchannel_batch_data* recv_initial_metadata_ready_deferred_batch = nullptr;
grpc_error* recv_initial_metadata_error = GRPC_ERROR_NONE;
subchannel_batch_data* recv_message_ready_deferred_batch = nullptr;
grpc_error* recv_message_error = GRPC_ERROR_NONE;
subchannel_batch_data* recv_trailing_metadata_internal_batch = nullptr;
- // State for callback processing.
// NOTE: Do not move this next to the metadata bitfields above. That would
// save space but will also result in a data race because compiler will
// generate a 2 byte store which overwrites the meta-data fields upon
@@ -908,12 +471,12 @@ struct subchannel_call_retry_state {
};
// Pending batches stored in call data.
-typedef struct {
+struct pending_batch {
// The pending batch. If nullptr, this slot is empty.
grpc_transport_stream_op_batch* batch;
// Indicates whether payload for send ops has been cached in call data.
bool send_ops_cached;
-} pending_batch;
+};
/** Call data. Holds a pointer to grpc_subchannel_call and the
associated machinery to create such a pointer.
@@ -950,11 +513,8 @@ struct call_data {
for (size_t i = 0; i < GPR_ARRAY_SIZE(pending_batches); ++i) {
GPR_ASSERT(pending_batches[i].batch == nullptr);
}
- for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
- if (pick.subchannel_call_context[i].value != nullptr) {
- pick.subchannel_call_context[i].destroy(
- pick.subchannel_call_context[i].value);
- }
+ if (have_request) {
+ request.Destroy();
}
}
@@ -981,12 +541,11 @@ struct call_data {
// Set when we get a cancel_stream op.
grpc_error* cancel_error = GRPC_ERROR_NONE;
- grpc_core::LoadBalancingPolicy::PickState pick;
+ grpc_core::ManualConstructor<grpc_core::RequestRouter::Request> request;
+ bool have_request = false;
grpc_closure pick_closure;
- grpc_closure pick_cancel_closure;
grpc_polling_entity* pollent = nullptr;
- bool pollent_added_to_interested_parties = false;
// Batches are added to this list when received from above.
// They are removed when we are done handling the batch (i.e., when
@@ -1036,6 +595,7 @@ struct call_data {
grpc_linked_mdelem* send_trailing_metadata_storage = nullptr;
grpc_metadata_batch send_trailing_metadata;
};
+
} // namespace
// Forward declarations.
@@ -1438,8 +998,9 @@ static void do_retry(grpc_call_element* elem,
"client_channel_call_retry");
calld->subchannel_call = nullptr;
}
- if (calld->pick.connected_subchannel != nullptr) {
- calld->pick.connected_subchannel.reset();
+ if (calld->have_request) {
+ calld->have_request = false;
+ calld->request.Destroy();
}
// Compute backoff delay.
grpc_millis next_attempt_time;
@@ -1588,6 +1149,7 @@ static bool maybe_retry(grpc_call_element* elem,
//
namespace {
+
subchannel_batch_data::subchannel_batch_data(grpc_call_element* elem,
call_data* calld, int refcount,
bool set_on_complete)
@@ -1628,6 +1190,7 @@ void subchannel_batch_data::destroy() {
call_data* calld = static_cast<call_data*>(elem->call_data);
GRPC_CALL_STACK_UNREF(calld->owning_call, "batch_data");
}
+
} // namespace
// Creates a subchannel_batch_data object on the call's arena with the
@@ -2644,17 +2207,18 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) {
const size_t parent_data_size =
calld->enable_retries ? sizeof(subchannel_call_retry_state) : 0;
const grpc_core::ConnectedSubchannel::CallArgs call_args = {
- calld->pollent, // pollent
- calld->path, // path
- calld->call_start_time, // start_time
- calld->deadline, // deadline
- calld->arena, // arena
- calld->pick.subchannel_call_context, // context
- calld->call_combiner, // call_combiner
- parent_data_size // parent_data_size
+ calld->pollent, // pollent
+ calld->path, // path
+ calld->call_start_time, // start_time
+ calld->deadline, // deadline
+ calld->arena, // arena
+ calld->request->pick()->subchannel_call_context, // context
+ calld->call_combiner, // call_combiner
+ parent_data_size // parent_data_size
};
- grpc_error* new_error = calld->pick.connected_subchannel->CreateCall(
- call_args, &calld->subchannel_call);
+ grpc_error* new_error =
+ calld->request->pick()->connected_subchannel->CreateCall(
+ call_args, &calld->subchannel_call);
if (grpc_client_channel_trace.enabled()) {
gpr_log(GPR_INFO, "chand=%p calld=%p: create subchannel_call=%p: error=%s",
chand, calld, calld->subchannel_call, grpc_error_string(new_error));
@@ -2666,7 +2230,8 @@ static void create_subchannel_call(grpc_call_element* elem, grpc_error* error) {
if (parent_data_size > 0) {
new (grpc_connected_subchannel_call_get_parent_data(
calld->subchannel_call))
- subchannel_call_retry_state(calld->pick.subchannel_call_context);
+ subchannel_call_retry_state(
+ calld->request->pick()->subchannel_call_context);
}
pending_batches_resume(elem);
}
@@ -2678,7 +2243,7 @@ static void pick_done(void* arg, grpc_error* error) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
- if (GPR_UNLIKELY(calld->pick.connected_subchannel == nullptr)) {
+ if (GPR_UNLIKELY(calld->request->pick()->connected_subchannel == nullptr)) {
// Failed to create subchannel.
// If there was no error, this is an LB policy drop, in which case
// we return an error; otherwise, we may retry.
@@ -2707,135 +2272,27 @@ static void pick_done(void* arg, grpc_error* error) {
}
}
-static void maybe_add_call_to_channel_interested_parties_locked(
- grpc_call_element* elem) {
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (!calld->pollent_added_to_interested_parties) {
- calld->pollent_added_to_interested_parties = true;
- grpc_polling_entity_add_to_pollset_set(calld->pollent,
- chand->interested_parties);
- }
-}
-
-static void maybe_del_call_from_channel_interested_parties_locked(
- grpc_call_element* elem) {
+// If the channel is in TRANSIENT_FAILURE and the call is not
+// wait_for_ready=true, fails the call and returns true.
+static bool fail_call_if_in_transient_failure(grpc_call_element* elem) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
call_data* calld = static_cast<call_data*>(elem->call_data);
- if (calld->pollent_added_to_interested_parties) {
- calld->pollent_added_to_interested_parties = false;
- grpc_polling_entity_del_from_pollset_set(calld->pollent,
- chand->interested_parties);
+ grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch;
+ if (chand->request_router->GetConnectivityState() ==
+ GRPC_CHANNEL_TRANSIENT_FAILURE &&
+ (batch->payload->send_initial_metadata.send_initial_metadata_flags &
+ GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) {
+ pending_batches_fail(
+ elem,
+ grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "channel is in state TRANSIENT_FAILURE"),
+ GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE),
+ true /* yield_call_combiner */);
+ return true;
}
+ return false;
}
-// Invoked when a pick is completed to leave the client_channel combiner
-// and continue processing in the call combiner.
-// If needed, removes the call's polling entity from chand->interested_parties.
-static void pick_done_locked(grpc_call_element* elem, grpc_error* error) {
- call_data* calld = static_cast<call_data*>(elem->call_data);
- maybe_del_call_from_channel_interested_parties_locked(elem);
- GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem,
- grpc_schedule_on_exec_ctx);
- GRPC_CLOSURE_SCHED(&calld->pick_closure, error);
-}
-
-namespace grpc_core {
-
-// Performs subchannel pick via LB policy.
-class LbPicker {
- public:
- // Starts a pick on chand->lb_policy.
- static void StartLocked(grpc_call_element* elem) {
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: starting pick on lb_policy=%p",
- chand, calld, chand->lb_policy.get());
- }
- // If this is a retry, use the send_initial_metadata payload that
- // we've cached; otherwise, use the pending batch. The
- // send_initial_metadata batch will be the first pending batch in the
- // list, as set by get_batch_index() above.
- calld->pick.initial_metadata =
- calld->seen_send_initial_metadata
- ? &calld->send_initial_metadata
- : calld->pending_batches[0]
- .batch->payload->send_initial_metadata.send_initial_metadata;
- calld->pick.initial_metadata_flags =
- calld->seen_send_initial_metadata
- ? calld->send_initial_metadata_flags
- : calld->pending_batches[0]
- .batch->payload->send_initial_metadata
- .send_initial_metadata_flags;
- GRPC_CLOSURE_INIT(&calld->pick_closure, &LbPicker::DoneLocked, elem,
- grpc_combiner_scheduler(chand->combiner));
- calld->pick.on_complete = &calld->pick_closure;
- GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback");
- grpc_error* error = GRPC_ERROR_NONE;
- const bool pick_done = chand->lb_policy->PickLocked(&calld->pick, &error);
- if (GPR_LIKELY(pick_done)) {
- // Pick completed synchronously.
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed synchronously",
- chand, calld);
- }
- pick_done_locked(elem, error);
- GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback");
- } else {
- // Pick will be returned asynchronously.
- // Add the polling entity from call_data to the channel_data's
- // interested_parties, so that the I/O of the LB policy can be done
- // under it. It will be removed in pick_done_locked().
- maybe_add_call_to_channel_interested_parties_locked(elem);
- // Request notification on call cancellation.
- GRPC_CALL_STACK_REF(calld->owning_call, "pick_callback_cancel");
- grpc_call_combiner_set_notify_on_cancel(
- calld->call_combiner,
- GRPC_CLOSURE_INIT(&calld->pick_cancel_closure,
- &LbPicker::CancelLocked, elem,
- grpc_combiner_scheduler(chand->combiner)));
- }
- }
-
- private:
- // Callback invoked by LoadBalancingPolicy::PickLocked() for async picks.
- // Unrefs the LB policy and invokes pick_done_locked().
- static void DoneLocked(void* arg, grpc_error* error) {
- grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: pick completed asynchronously",
- chand, calld);
- }
- pick_done_locked(elem, GRPC_ERROR_REF(error));
- GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback");
- }
-
- // Note: This runs under the client_channel combiner, but will NOT be
- // holding the call combiner.
- static void CancelLocked(void* arg, grpc_error* error) {
- grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- // Note: chand->lb_policy may have changed since we started our pick,
- // in which case we will be cancelling the pick on a policy other than
- // the one we started it on. However, this will just be a no-op.
- if (GPR_UNLIKELY(error != GRPC_ERROR_NONE && chand->lb_policy != nullptr)) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO,
- "chand=%p calld=%p: cancelling pick from LB policy %p", chand,
- calld, chand->lb_policy.get());
- }
- chand->lb_policy->CancelPickLocked(&calld->pick, GRPC_ERROR_REF(error));
- }
- GRPC_CALL_STACK_UNREF(calld->owning_call, "pick_callback_cancel");
- }
-};
-
-} // namespace grpc_core
-
// Applies service config to the call. Must be invoked once we know
// that the resolver has returned results to the channel.
static void apply_service_config_to_call_locked(grpc_call_element* elem) {
@@ -2892,224 +2349,66 @@ static void apply_service_config_to_call_locked(grpc_call_element* elem) {
}
}
-// If the channel is in TRANSIENT_FAILURE and the call is not
-// wait_for_ready=true, fails the call and returns true.
-static bool fail_call_if_in_transient_failure(grpc_call_element* elem) {
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- grpc_transport_stream_op_batch* batch = calld->pending_batches[0].batch;
- if (grpc_connectivity_state_check(&chand->state_tracker) ==
- GRPC_CHANNEL_TRANSIENT_FAILURE &&
- (batch->payload->send_initial_metadata.send_initial_metadata_flags &
- GRPC_INITIAL_METADATA_WAIT_FOR_READY) == 0) {
- pending_batches_fail(
- elem,
- grpc_error_set_int(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "channel is in state TRANSIENT_FAILURE"),
- GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE),
- true /* yield_call_combiner */);
- return true;
- }
- return false;
-}
-
// Invoked once resolver results are available.
-static void process_service_config_and_start_lb_pick_locked(
- grpc_call_element* elem) {
+static bool maybe_apply_service_config_to_call_locked(void* arg) {
+ grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
call_data* calld = static_cast<call_data*>(elem->call_data);
// Only get service config data on the first attempt.
if (GPR_LIKELY(calld->num_attempts_completed == 0)) {
apply_service_config_to_call_locked(elem);
// Check this after applying service config, since it may have
// affected the call's wait_for_ready value.
- if (fail_call_if_in_transient_failure(elem)) return;
+ if (fail_call_if_in_transient_failure(elem)) return false;
}
- // Start LB pick.
- grpc_core::LbPicker::StartLocked(elem);
+ return true;
}
-namespace grpc_core {
-
-// Handles waiting for a resolver result.
-// Used only for the first call on an idle channel.
-class ResolverResultWaiter {
- public:
- explicit ResolverResultWaiter(grpc_call_element* elem) : elem_(elem) {
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO,
- "chand=%p calld=%p: deferring pick pending resolver result",
- chand, calld);
- }
- // Add closure to be run when a resolver result is available.
- GRPC_CLOSURE_INIT(&done_closure_, &ResolverResultWaiter::DoneLocked, this,
- grpc_combiner_scheduler(chand->combiner));
- AddToWaitingList();
- // Set cancellation closure, so that we abort if the call is cancelled.
- GRPC_CLOSURE_INIT(&cancel_closure_, &ResolverResultWaiter::CancelLocked,
- this, grpc_combiner_scheduler(chand->combiner));
- grpc_call_combiner_set_notify_on_cancel(calld->call_combiner,
- &cancel_closure_);
- }
-
- private:
- // Adds closure_ to chand->waiting_for_resolver_result_closures.
- void AddToWaitingList() {
- channel_data* chand = static_cast<channel_data*>(elem_->channel_data);
- grpc_closure_list_append(&chand->waiting_for_resolver_result_closures,
- &done_closure_, GRPC_ERROR_NONE);
- }
-
- // Invoked when a resolver result is available.
- static void DoneLocked(void* arg, grpc_error* error) {
- ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg);
- // If CancelLocked() has already run, delete ourselves without doing
- // anything. Note that the call stack may have already been destroyed,
- // so it's not safe to access anything in elem_.
- if (GPR_UNLIKELY(self->finished_)) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "call cancelled before resolver result");
- }
- Delete(self);
- return;
- }
- // Otherwise, process the resolver result.
- grpc_call_element* elem = self->elem_;
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: resolver failed to return data",
- chand, calld);
- }
- pick_done_locked(elem, GRPC_ERROR_REF(error));
- } else if (GPR_UNLIKELY(chand->resolver == nullptr)) {
- // Shutting down.
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: resolver disconnected", chand,
- calld);
- }
- pick_done_locked(elem,
- GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected"));
- } else if (GPR_UNLIKELY(chand->lb_policy == nullptr)) {
- // Transient resolver failure.
- // If call has wait_for_ready=true, try again; otherwise, fail.
- uint32_t send_initial_metadata_flags =
- calld->seen_send_initial_metadata
- ? calld->send_initial_metadata_flags
- : calld->pending_batches[0]
- .batch->payload->send_initial_metadata
- .send_initial_metadata_flags;
- if (send_initial_metadata_flags & GRPC_INITIAL_METADATA_WAIT_FOR_READY) {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO,
- "chand=%p calld=%p: resolver returned but no LB policy; "
- "wait_for_ready=true; trying again",
- chand, calld);
- }
- // Re-add ourselves to the waiting list.
- self->AddToWaitingList();
- // Return early so that we don't set finished_ to true below.
- return;
- } else {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO,
- "chand=%p calld=%p: resolver returned but no LB policy; "
- "wait_for_ready=false; failing",
- chand, calld);
- }
- pick_done_locked(
- elem,
- grpc_error_set_int(
- GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"),
- GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE));
- }
- } else {
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO, "chand=%p calld=%p: resolver returned, doing LB pick",
- chand, calld);
- }
- process_service_config_and_start_lb_pick_locked(elem);
- }
- self->finished_ = true;
- }
-
- // Invoked when the call is cancelled.
- // Note: This runs under the client_channel combiner, but will NOT be
- // holding the call combiner.
- static void CancelLocked(void* arg, grpc_error* error) {
- ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg);
- // If DoneLocked() has already run, delete ourselves without doing anything.
- if (GPR_LIKELY(self->finished_)) {
- Delete(self);
- return;
- }
- // If we are being cancelled, immediately invoke pick_done_locked()
- // to propagate the error back to the caller.
- if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) {
- grpc_call_element* elem = self->elem_;
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- call_data* calld = static_cast<call_data*>(elem->call_data);
- if (grpc_client_channel_trace.enabled()) {
- gpr_log(GPR_INFO,
- "chand=%p calld=%p: cancelling call waiting for name "
- "resolution",
- chand, calld);
- }
- // Note: Although we are not in the call combiner here, we are
- // basically stealing the call combiner from the pending pick, so
- // it's safe to call pick_done_locked() here -- we are essentially
- // calling it here instead of calling it in DoneLocked().
- pick_done_locked(elem, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
- "Pick cancelled", &error, 1));
- }
- self->finished_ = true;
- }
-
- grpc_call_element* elem_;
- grpc_closure done_closure_;
- grpc_closure cancel_closure_;
- bool finished_ = false;
-};
-
-} // namespace grpc_core
-
static void start_pick_locked(void* arg, grpc_error* ignored) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
call_data* calld = static_cast<call_data*>(elem->call_data);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- GPR_ASSERT(calld->pick.connected_subchannel == nullptr);
+ GPR_ASSERT(!calld->have_request);
GPR_ASSERT(calld->subchannel_call == nullptr);
- if (GPR_LIKELY(chand->lb_policy != nullptr)) {
- // We already have resolver results, so process the service config
- // and start an LB pick.
- process_service_config_and_start_lb_pick_locked(elem);
- } else if (GPR_UNLIKELY(chand->resolver == nullptr)) {
- pick_done_locked(elem,
- GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected"));
- } else {
- // We do not yet have an LB policy, so wait for a resolver result.
- if (GPR_UNLIKELY(!chand->started_resolving)) {
- start_resolving_locked(chand);
- } else {
- // Normally, we want to do this check in
- // process_service_config_and_start_lb_pick_locked(), so that we
- // can honor the wait_for_ready setting in the service config.
- // However, if the channel is in TRANSIENT_FAILURE at this point, that
- // means that the resolver has returned a failure, so we're not going
- // to get a service config right away. In that case, we fail the
- // call now based on the wait_for_ready value passed in from the
- // application.
- if (fail_call_if_in_transient_failure(elem)) return;
- }
- // Create a new waiter, which will delete itself when done.
- grpc_core::New<grpc_core::ResolverResultWaiter>(elem);
- // Add the polling entity from call_data to the channel_data's
- // interested_parties, so that the I/O of the resolver can be done
- // under it. It will be removed in pick_done_locked().
- maybe_add_call_to_channel_interested_parties_locked(elem);
+ // Normally, we want to do this check until after we've processed the
+ // service config, so that we can honor the wait_for_ready setting in
+ // the service config. However, if the channel is in TRANSIENT_FAILURE
+ // and we don't have an LB policy at this point, that means that the
+ // resolver has returned a failure, so we're not going to get a service
+ // config right away. In that case, we fail the call now based on the
+ // wait_for_ready value passed in from the application.
+ if (chand->request_router->lb_policy() == nullptr &&
+ fail_call_if_in_transient_failure(elem)) {
+ return;
}
+ // If this is a retry, use the send_initial_metadata payload that
+ // we've cached; otherwise, use the pending batch. The
+ // send_initial_metadata batch will be the first pending batch in the
+ // list, as set by get_batch_index() above.
+ // TODO(roth): What if the LB policy needs to add something to the
+ // call's initial metadata, and then there's a retry? We don't want
+ // the new metadata to be added twice. We might need to somehow
+ // allocate the subchannel batch earlier so that we can give the
+ // subchannel's copy of the metadata batch (which is copied for each
+ // attempt) to the LB policy instead the one from the parent channel.
+ grpc_metadata_batch* initial_metadata =
+ calld->seen_send_initial_metadata
+ ? &calld->send_initial_metadata
+ : calld->pending_batches[0]
+ .batch->payload->send_initial_metadata.send_initial_metadata;
+ uint32_t* initial_metadata_flags =
+ calld->seen_send_initial_metadata
+ ? &calld->send_initial_metadata_flags
+ : &calld->pending_batches[0]
+ .batch->payload->send_initial_metadata
+ .send_initial_metadata_flags;
+ GRPC_CLOSURE_INIT(&calld->pick_closure, pick_done, elem,
+ grpc_schedule_on_exec_ctx);
+ calld->request.Init(calld->owning_call, calld->call_combiner, calld->pollent,
+ initial_metadata, initial_metadata_flags,
+ maybe_apply_service_config_to_call_locked, elem,
+ &calld->pick_closure);
+ calld->have_request = true;
+ chand->request_router->RouteCallLocked(calld->request.get());
}
//
@@ -3249,23 +2548,10 @@ const grpc_channel_filter grpc_client_channel_filter = {
"client-channel",
};
-static void try_to_connect_locked(void* arg, grpc_error* error_ignored) {
- channel_data* chand = static_cast<channel_data*>(arg);
- if (chand->lb_policy != nullptr) {
- chand->lb_policy->ExitIdleLocked();
- } else {
- chand->exit_idle_when_lb_policy_arrives = true;
- if (!chand->started_resolving && chand->resolver != nullptr) {
- start_resolving_locked(chand);
- }
- }
- GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect");
-}
-
void grpc_client_channel_set_channelz_node(
grpc_channel_element* elem, grpc_core::channelz::ClientChannelNode* node) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- chand->channelz_channel = node;
+ chand->request_router->set_channelz_node(node);
}
void grpc_client_channel_populate_child_refs(
@@ -3273,17 +2559,22 @@ void grpc_client_channel_populate_child_refs(
grpc_core::channelz::ChildRefsList* child_subchannels,
grpc_core::channelz::ChildRefsList* child_channels) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- if (chand->lb_policy != nullptr) {
- chand->lb_policy->FillChildRefsForChannelz(child_subchannels,
- child_channels);
+ if (chand->request_router->lb_policy() != nullptr) {
+ chand->request_router->lb_policy()->FillChildRefsForChannelz(
+ child_subchannels, child_channels);
}
}
+static void try_to_connect_locked(void* arg, grpc_error* error_ignored) {
+ channel_data* chand = static_cast<channel_data*>(arg);
+ chand->request_router->ExitIdleLocked();
+ GRPC_CHANNEL_STACK_UNREF(chand->owning_stack, "try_to_connect");
+}
+
grpc_connectivity_state grpc_client_channel_check_connectivity_state(
grpc_channel_element* elem, int try_to_connect) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- grpc_connectivity_state out =
- grpc_connectivity_state_check(&chand->state_tracker);
+ grpc_connectivity_state out = chand->request_router->GetConnectivityState();
if (out == GRPC_CHANNEL_IDLE && try_to_connect) {
GRPC_CHANNEL_STACK_REF(chand->owning_stack, "try_to_connect");
GRPC_CLOSURE_SCHED(
@@ -3328,19 +2619,19 @@ static void external_connectivity_watcher_list_append(
}
static void external_connectivity_watcher_list_remove(
- channel_data* chand, external_connectivity_watcher* too_remove) {
+ channel_data* chand, external_connectivity_watcher* to_remove) {
GPR_ASSERT(
- lookup_external_connectivity_watcher(chand, too_remove->on_complete));
+ lookup_external_connectivity_watcher(chand, to_remove->on_complete));
gpr_mu_lock(&chand->external_connectivity_watcher_list_mu);
- if (too_remove == chand->external_connectivity_watcher_list_head) {
- chand->external_connectivity_watcher_list_head = too_remove->next;
+ if (to_remove == chand->external_connectivity_watcher_list_head) {
+ chand->external_connectivity_watcher_list_head = to_remove->next;
gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu);
return;
}
external_connectivity_watcher* w =
chand->external_connectivity_watcher_list_head;
while (w != nullptr) {
- if (w->next == too_remove) {
+ if (w->next == to_remove) {
w->next = w->next->next;
gpr_mu_unlock(&chand->external_connectivity_watcher_list_mu);
return;
@@ -3392,15 +2683,15 @@ static void watch_connectivity_state_locked(void* arg,
GRPC_CLOSURE_RUN(w->watcher_timer_init, GRPC_ERROR_NONE);
GRPC_CLOSURE_INIT(&w->my_closure, on_external_watch_complete_locked, w,
grpc_combiner_scheduler(w->chand->combiner));
- grpc_connectivity_state_notify_on_state_change(&w->chand->state_tracker,
- w->state, &w->my_closure);
+ w->chand->request_router->NotifyOnConnectivityStateChange(w->state,
+ &w->my_closure);
} else {
GPR_ASSERT(w->watcher_timer_init == nullptr);
found = lookup_external_connectivity_watcher(w->chand, w->on_complete);
if (found) {
GPR_ASSERT(found->on_complete == w->on_complete);
- grpc_connectivity_state_notify_on_state_change(
- &found->chand->state_tracker, nullptr, &found->my_closure);
+ found->chand->request_router->NotifyOnConnectivityStateChange(
+ nullptr, &found->my_closure);
}
grpc_polling_entity_del_from_pollset_set(&w->pollent,
w->chand->interested_parties);
diff --git a/src/core/ext/filters/client_channel/lb_policy.h b/src/core/ext/filters/client_channel/lb_policy.h
index 6b76fe5d5d..293d8e960c 100644
--- a/src/core/ext/filters/client_channel/lb_policy.h
+++ b/src/core/ext/filters/client_channel/lb_policy.h
@@ -65,10 +65,10 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> {
struct PickState {
/// Initial metadata associated with the picking call.
grpc_metadata_batch* initial_metadata = nullptr;
- /// Bitmask used for selective cancelling. See
+ /// Pointer to bitmask used for selective cancelling. See
/// \a CancelMatchingPicksLocked() and \a GRPC_INITIAL_METADATA_* in
/// grpc_types.h.
- uint32_t initial_metadata_flags = 0;
+ uint32_t* initial_metadata_flags = nullptr;
/// Storage for LB token in \a initial_metadata, or nullptr if not used.
grpc_linked_mdelem lb_token_mdelem_storage;
/// Closure to run when pick is complete, if not completed synchronously.
@@ -88,6 +88,9 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> {
LoadBalancingPolicy(const LoadBalancingPolicy&) = delete;
LoadBalancingPolicy& operator=(const LoadBalancingPolicy&) = delete;
+ /// Returns the name of the LB policy.
+ virtual const char* name() const GRPC_ABSTRACT;
+
/// Updates the policy with a new set of \a args and a new \a lb_config from
/// the resolver. Note that the LB policy gets the set of addresses from the
/// GRPC_ARG_SERVER_ADDRESS_LIST channel arg.
@@ -205,12 +208,6 @@ class LoadBalancingPolicy : public InternallyRefCounted<LoadBalancingPolicy> {
grpc_pollset_set* interested_parties_;
/// Callback to force a re-resolution.
grpc_closure* request_reresolution_;
-
- // Dummy classes needed for alignment issues.
- // See https://github.com/grpc/grpc/issues/16032 for context.
- // TODO(ncteisen): remove this as soon as the issue is resolved.
- channelz::ChildRefsList dummy_list_foo;
- channelz::ChildRefsList dummy_list_bar;
};
} // namespace grpc_core
diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc
index a9a5965ed1..ba40febd53 100644
--- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.cc
@@ -122,10 +122,14 @@ TraceFlag grpc_lb_glb_trace(false, "glb");
namespace {
+constexpr char kGrpclb[] = "grpclb";
+
class GrpcLb : public LoadBalancingPolicy {
public:
explicit GrpcLb(const Args& args);
+ const char* name() const override { return kGrpclb; }
+
void UpdateLocked(const grpc_channel_args& args,
grpc_json* lb_config) override;
bool PickLocked(PickState* pick, grpc_error** error) override;
@@ -361,7 +365,9 @@ void lb_token_destroy(void* token) {
}
}
int lb_token_cmp(void* token1, void* token2) {
- return GPR_ICMP(token1, token2);
+ // Always indicate a match, since we don't want this channel arg to
+ // affect the subchannel's key in the index.
+ return 0;
}
const grpc_arg_pointer_vtable lb_token_arg_vtable = {
lb_token_copy, lb_token_destroy, lb_token_cmp};
@@ -422,7 +428,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) {
grpc_resolved_address addr;
ParseServer(server, &addr);
// LB token processing.
- void* lb_token;
+ grpc_mdelem lb_token;
if (server->has_load_balance_token) {
const size_t lb_token_max_length =
GPR_ARRAY_SIZE(server->load_balance_token);
@@ -430,9 +436,7 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) {
strnlen(server->load_balance_token, lb_token_max_length);
grpc_slice lb_token_mdstr = grpc_slice_from_copied_buffer(
server->load_balance_token, lb_token_length);
- lb_token =
- (void*)grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr)
- .payload;
+ lb_token = grpc_mdelem_from_slices(GRPC_MDSTR_LB_TOKEN, lb_token_mdstr);
} else {
char* uri = grpc_sockaddr_to_uri(&addr);
gpr_log(GPR_INFO,
@@ -440,14 +444,16 @@ ServerAddressList ProcessServerlist(const grpc_grpclb_serverlist* serverlist) {
"be used instead",
uri);
gpr_free(uri);
- lb_token = (void*)GRPC_MDELEM_LB_TOKEN_EMPTY.payload;
+ lb_token = GRPC_MDELEM_LB_TOKEN_EMPTY;
}
// Add address.
grpc_arg arg = grpc_channel_arg_pointer_create(
- const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN), lb_token,
- &lb_token_arg_vtable);
+ const_cast<char*>(GRPC_ARG_GRPCLB_ADDRESS_LB_TOKEN),
+ (void*)lb_token.payload, &lb_token_arg_vtable);
grpc_channel_args* args = grpc_channel_args_copy_and_add(nullptr, &arg, 1);
addresses.emplace_back(addr, args);
+ // Clean up.
+ GRPC_MDELEM_UNREF(lb_token);
}
return addresses;
}
@@ -525,8 +531,7 @@ void GrpcLb::BalancerCallState::Orphan() {
void GrpcLb::BalancerCallState::StartQuery() {
GPR_ASSERT(lb_call_ != nullptr);
if (grpc_lb_glb_trace.enabled()) {
- gpr_log(GPR_INFO,
- "[grpclb %p] Starting LB call (lb_calld: %p, lb_call: %p)",
+ gpr_log(GPR_INFO, "[grpclb %p] lb_calld=%p: Starting LB call %p",
grpclb_policy_.get(), this, lb_call_);
}
// Create the ops.
@@ -670,8 +675,9 @@ void GrpcLb::BalancerCallState::SendClientLoadReportLocked() {
grpc_call_error call_error = grpc_call_start_batch_and_execute(
lb_call_, &op, 1, &client_load_report_closure_);
if (GPR_UNLIKELY(call_error != GRPC_CALL_OK)) {
- gpr_log(GPR_ERROR, "[grpclb %p] call_error=%d", grpclb_policy_.get(),
- call_error);
+ gpr_log(GPR_ERROR,
+ "[grpclb %p] lb_calld=%p call_error=%d sending client load report",
+ grpclb_policy_.get(), this, call_error);
GPR_ASSERT(GRPC_CALL_OK == call_error);
}
}
@@ -732,15 +738,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked(
&initial_response->client_stats_report_interval));
if (grpc_lb_glb_trace.enabled()) {
gpr_log(GPR_INFO,
- "[grpclb %p] Received initial LB response message; "
- "client load reporting interval = %" PRId64 " milliseconds",
- grpclb_policy, lb_calld->client_stats_report_interval_);
+ "[grpclb %p] lb_calld=%p: Received initial LB response "
+ "message; client load reporting interval = %" PRId64
+ " milliseconds",
+ grpclb_policy, lb_calld,
+ lb_calld->client_stats_report_interval_);
}
} else if (grpc_lb_glb_trace.enabled()) {
gpr_log(GPR_INFO,
- "[grpclb %p] Received initial LB response message; client load "
- "reporting NOT enabled",
- grpclb_policy);
+ "[grpclb %p] lb_calld=%p: Received initial LB response message; "
+ "client load reporting NOT enabled",
+ grpclb_policy, lb_calld);
}
grpc_grpclb_initial_response_destroy(initial_response);
lb_calld->seen_initial_response_ = true;
@@ -750,15 +758,17 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked(
GPR_ASSERT(lb_calld->lb_call_ != nullptr);
if (grpc_lb_glb_trace.enabled()) {
gpr_log(GPR_INFO,
- "[grpclb %p] Serverlist with %" PRIuPTR " servers received",
- grpclb_policy, serverlist->num_servers);
+ "[grpclb %p] lb_calld=%p: Serverlist with %" PRIuPTR
+ " servers received",
+ grpclb_policy, lb_calld, serverlist->num_servers);
for (size_t i = 0; i < serverlist->num_servers; ++i) {
grpc_resolved_address addr;
ParseServer(serverlist->servers[i], &addr);
char* ipport;
grpc_sockaddr_to_string(&ipport, &addr, false);
- gpr_log(GPR_INFO, "[grpclb %p] Serverlist[%" PRIuPTR "]: %s",
- grpclb_policy, i, ipport);
+ gpr_log(GPR_INFO,
+ "[grpclb %p] lb_calld=%p: Serverlist[%" PRIuPTR "]: %s",
+ grpclb_policy, lb_calld, i, ipport);
gpr_free(ipport);
}
}
@@ -778,9 +788,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked(
if (grpc_grpclb_serverlist_equals(grpclb_policy->serverlist_, serverlist)) {
if (grpc_lb_glb_trace.enabled()) {
gpr_log(GPR_INFO,
- "[grpclb %p] Incoming server list identical to current, "
- "ignoring.",
- grpclb_policy);
+ "[grpclb %p] lb_calld=%p: Incoming server list identical to "
+ "current, ignoring.",
+ grpclb_policy, lb_calld);
}
grpc_grpclb_destroy_serverlist(serverlist);
} else { // New serverlist.
@@ -806,8 +816,9 @@ void GrpcLb::BalancerCallState::OnBalancerMessageReceivedLocked(
char* response_slice_str =
grpc_dump_slice(response_slice, GPR_DUMP_ASCII | GPR_DUMP_HEX);
gpr_log(GPR_ERROR,
- "[grpclb %p] Invalid LB response received: '%s'. Ignoring.",
- grpclb_policy, response_slice_str);
+ "[grpclb %p] lb_calld=%p: Invalid LB response received: '%s'. "
+ "Ignoring.",
+ grpclb_policy, lb_calld, response_slice_str);
gpr_free(response_slice_str);
}
grpc_slice_unref_internal(response_slice);
@@ -838,9 +849,9 @@ void GrpcLb::BalancerCallState::OnBalancerStatusReceivedLocked(
char* status_details =
grpc_slice_to_c_string(lb_calld->lb_call_status_details_);
gpr_log(GPR_INFO,
- "[grpclb %p] Status from LB server received. Status = %d, details "
- "= '%s', (lb_calld: %p, lb_call: %p), error '%s'",
- grpclb_policy, lb_calld->lb_call_status_, status_details, lb_calld,
+ "[grpclb %p] lb_calld=%p: Status from LB server received. "
+ "Status = %d, details = '%s', (lb_call: %p), error '%s'",
+ grpclb_policy, lb_calld, lb_calld->lb_call_status_, status_details,
lb_calld->lb_call_, grpc_error_string(error));
gpr_free(status_details);
}
@@ -1129,7 +1140,7 @@ void GrpcLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask,
pending_picks_ = nullptr;
while (pp != nullptr) {
PendingPick* next = pp->next;
- if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) ==
+ if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) ==
initial_metadata_flags_eq) {
// Note: pp is deleted in this callback.
GRPC_CLOSURE_SCHED(&pp->on_complete,
@@ -1592,6 +1603,10 @@ void GrpcLb::CreateRoundRobinPolicyLocked(const Args& args) {
this);
return;
}
+ if (grpc_lb_glb_trace.enabled()) {
+ gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this,
+ rr_policy_.get());
+ }
// TODO(roth): We currently track this ref manually. Once the new
// ClosureRef API is done, pass the RefCountedPtr<> along with the closure.
auto self = Ref(DEBUG_LOCATION, "on_rr_reresolution_requested");
@@ -1685,10 +1700,6 @@ void GrpcLb::CreateOrUpdateRoundRobinPolicyLocked() {
lb_policy_args.client_channel_factory = client_channel_factory();
lb_policy_args.args = args;
CreateRoundRobinPolicyLocked(lb_policy_args);
- if (grpc_lb_glb_trace.enabled()) {
- gpr_log(GPR_INFO, "[grpclb %p] Created new RR policy %p", this,
- rr_policy_.get());
- }
}
grpc_channel_args_destroy(args);
}
@@ -1812,7 +1823,7 @@ class GrpcLbFactory : public LoadBalancingPolicyFactory {
return OrphanablePtr<LoadBalancingPolicy>(New<GrpcLb>(args));
}
- const char* name() const override { return "grpclb"; }
+ const char* name() const override { return kGrpclb; }
};
} // namespace
diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc
index 6e8fbdcab7..657ff69312 100644
--- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc
@@ -88,22 +88,18 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args(
// bearer token credentials.
grpc_channel_credentials* channel_credentials =
grpc_channel_credentials_find_in_args(args);
- grpc_channel_credentials* creds_sans_call_creds = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds;
if (channel_credentials != nullptr) {
creds_sans_call_creds =
- grpc_channel_credentials_duplicate_without_call_credentials(
- channel_credentials);
+ channel_credentials->duplicate_without_call_credentials();
GPR_ASSERT(creds_sans_call_creds != nullptr);
args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS;
args_to_add[num_args_to_add++] =
- grpc_channel_credentials_to_arg(creds_sans_call_creds);
+ grpc_channel_credentials_to_arg(creds_sans_call_creds.get());
}
grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add);
// Clean up.
grpc_channel_args_destroy(args);
- if (creds_sans_call_creds != nullptr) {
- grpc_channel_credentials_unref(creds_sans_call_creds);
- }
return result;
}
diff --git a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc
index 74c17612a2..d6ff74ec7f 100644
--- a/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc
@@ -43,10 +43,14 @@ namespace {
// pick_first LB policy
//
+constexpr char kPickFirst[] = "pick_first";
+
class PickFirst : public LoadBalancingPolicy {
public:
explicit PickFirst(const Args& args);
+ const char* name() const override { return kPickFirst; }
+
void UpdateLocked(const grpc_channel_args& args,
grpc_json* lb_config) override;
bool PickLocked(PickState* pick, grpc_error** error) override;
@@ -234,7 +238,7 @@ void PickFirst::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask,
pending_picks_ = nullptr;
while (pick != nullptr) {
PickState* next = pick->next;
- if ((pick->initial_metadata_flags & initial_metadata_flags_mask) ==
+ if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) ==
initial_metadata_flags_eq) {
GRPC_CLOSURE_SCHED(pick->on_complete,
GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
@@ -622,7 +626,7 @@ class PickFirstFactory : public LoadBalancingPolicyFactory {
return OrphanablePtr<LoadBalancingPolicy>(New<PickFirst>(args));
}
- const char* name() const override { return "pick_first"; }
+ const char* name() const override { return kPickFirst; }
};
} // namespace
diff --git a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc
index 63089afbd7..3bcb33ef11 100644
--- a/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.cc
@@ -53,10 +53,14 @@ namespace {
// round_robin LB policy
//
+constexpr char kRoundRobin[] = "round_robin";
+
class RoundRobin : public LoadBalancingPolicy {
public:
explicit RoundRobin(const Args& args);
+ const char* name() const override { return kRoundRobin; }
+
void UpdateLocked(const grpc_channel_args& args,
grpc_json* lb_config) override;
bool PickLocked(PickState* pick, grpc_error** error) override;
@@ -291,7 +295,7 @@ void RoundRobin::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask,
pending_picks_ = nullptr;
while (pick != nullptr) {
PickState* next = pick->next;
- if ((pick->initial_metadata_flags & initial_metadata_flags_mask) ==
+ if ((*pick->initial_metadata_flags & initial_metadata_flags_mask) ==
initial_metadata_flags_eq) {
pick->connected_subchannel.reset();
GRPC_CLOSURE_SCHED(pick->on_complete,
@@ -700,7 +704,7 @@ class RoundRobinFactory : public LoadBalancingPolicyFactory {
return OrphanablePtr<LoadBalancingPolicy>(New<RoundRobin>(args));
}
- const char* name() const override { return "round_robin"; }
+ const char* name() const override { return kRoundRobin; }
};
} // namespace
diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc
index 3c25de2386..8787f5bcc2 100644
--- a/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds.cc
@@ -115,10 +115,14 @@ TraceFlag grpc_lb_xds_trace(false, "xds");
namespace {
+constexpr char kXds[] = "xds_experimental";
+
class XdsLb : public LoadBalancingPolicy {
public:
explicit XdsLb(const Args& args);
+ const char* name() const override { return kXds; }
+
void UpdateLocked(const grpc_channel_args& args,
grpc_json* lb_config) override;
bool PickLocked(PickState* pick, grpc_error** error) override;
@@ -1053,7 +1057,7 @@ void XdsLb::CancelMatchingPicksLocked(uint32_t initial_metadata_flags_mask,
pending_picks_ = nullptr;
while (pp != nullptr) {
PendingPick* next = pp->next;
- if ((pp->pick->initial_metadata_flags & initial_metadata_flags_mask) ==
+ if ((*pp->pick->initial_metadata_flags & initial_metadata_flags_mask) ==
initial_metadata_flags_eq) {
// Note: pp is deleted in this callback.
GRPC_CLOSURE_SCHED(&pp->on_complete,
@@ -1651,7 +1655,7 @@ class XdsFactory : public LoadBalancingPolicyFactory {
return OrphanablePtr<LoadBalancingPolicy>(New<XdsLb>(args));
}
- const char* name() const override { return "xds_experimental"; }
+ const char* name() const override { return kXds; }
};
} // namespace
diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc
index 9a11f8e39f..55c646e6ee 100644
--- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc
@@ -87,22 +87,18 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args(
// bearer token credentials.
grpc_channel_credentials* channel_credentials =
grpc_channel_credentials_find_in_args(args);
- grpc_channel_credentials* creds_sans_call_creds = nullptr;
+ grpc_core::RefCountedPtr<grpc_channel_credentials> creds_sans_call_creds;
if (channel_credentials != nullptr) {
creds_sans_call_creds =
- grpc_channel_credentials_duplicate_without_call_credentials(
- channel_credentials);
+ channel_credentials->duplicate_without_call_credentials();
GPR_ASSERT(creds_sans_call_creds != nullptr);
args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS;
args_to_add[num_args_to_add++] =
- grpc_channel_credentials_to_arg(creds_sans_call_creds);
+ grpc_channel_credentials_to_arg(creds_sans_call_creds.get());
}
grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add);
// Clean up.
grpc_channel_args_destroy(args);
- if (creds_sans_call_creds != nullptr) {
- grpc_channel_credentials_unref(creds_sans_call_creds);
- }
return result;
}
diff --git a/src/core/ext/filters/client_channel/request_routing.cc b/src/core/ext/filters/client_channel/request_routing.cc
new file mode 100644
index 0000000000..f9a7e164e7
--- /dev/null
+++ b/src/core/ext/filters/client_channel/request_routing.cc
@@ -0,0 +1,936 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/ext/filters/client_channel/request_routing.h"
+
+#include <inttypes.h>
+#include <limits.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <grpc/support/alloc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
+#include <grpc/support/sync.h>
+
+#include "src/core/ext/filters/client_channel/backup_poller.h"
+#include "src/core/ext/filters/client_channel/http_connect_handshaker.h"
+#include "src/core/ext/filters/client_channel/lb_policy_registry.h"
+#include "src/core/ext/filters/client_channel/proxy_mapper_registry.h"
+#include "src/core/ext/filters/client_channel/resolver_registry.h"
+#include "src/core/ext/filters/client_channel/retry_throttle.h"
+#include "src/core/ext/filters/client_channel/server_address.h"
+#include "src/core/ext/filters/client_channel/subchannel.h"
+#include "src/core/ext/filters/deadline/deadline_filter.h"
+#include "src/core/lib/backoff/backoff.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/channel/connected_channel.h"
+#include "src/core/lib/channel/status_util.h"
+#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/inlined_vector.h"
+#include "src/core/lib/gprpp/manual_constructor.h"
+#include "src/core/lib/iomgr/combiner.h"
+#include "src/core/lib/iomgr/iomgr.h"
+#include "src/core/lib/iomgr/polling_entity.h"
+#include "src/core/lib/profiling/timers.h"
+#include "src/core/lib/slice/slice_internal.h"
+#include "src/core/lib/slice/slice_string_helpers.h"
+#include "src/core/lib/surface/channel.h"
+#include "src/core/lib/transport/connectivity_state.h"
+#include "src/core/lib/transport/error_utils.h"
+#include "src/core/lib/transport/metadata.h"
+#include "src/core/lib/transport/metadata_batch.h"
+#include "src/core/lib/transport/service_config.h"
+#include "src/core/lib/transport/static_metadata.h"
+#include "src/core/lib/transport/status_metadata.h"
+
+namespace grpc_core {
+
+//
+// RequestRouter::Request::ResolverResultWaiter
+//
+
+// Handles waiting for a resolver result.
+// Used only for the first call on an idle channel.
+class RequestRouter::Request::ResolverResultWaiter {
+ public:
+ explicit ResolverResultWaiter(Request* request)
+ : request_router_(request->request_router_),
+ request_(request),
+ tracer_enabled_(request_router_->tracer_->enabled()) {
+ if (tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: deferring pick pending resolver "
+ "result",
+ request_router_, request);
+ }
+ // Add closure to be run when a resolver result is available.
+ GRPC_CLOSURE_INIT(&done_closure_, &DoneLocked, this,
+ grpc_combiner_scheduler(request_router_->combiner_));
+ AddToWaitingList();
+ // Set cancellation closure, so that we abort if the call is cancelled.
+ GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this,
+ grpc_combiner_scheduler(request_router_->combiner_));
+ grpc_call_combiner_set_notify_on_cancel(request->call_combiner_,
+ &cancel_closure_);
+ }
+
+ private:
+ // Adds done_closure_ to
+ // request_router_->waiting_for_resolver_result_closures_.
+ void AddToWaitingList() {
+ grpc_closure_list_append(
+ &request_router_->waiting_for_resolver_result_closures_, &done_closure_,
+ GRPC_ERROR_NONE);
+ }
+
+ // Invoked when a resolver result is available.
+ static void DoneLocked(void* arg, grpc_error* error) {
+ ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg);
+ RequestRouter* request_router = self->request_router_;
+ // If CancelLocked() has already run, delete ourselves without doing
+ // anything. Note that the call stack may have already been destroyed,
+ // so it's not safe to access anything in state_.
+ if (GPR_UNLIKELY(self->finished_)) {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p: call cancelled before resolver result",
+ request_router);
+ }
+ Delete(self);
+ return;
+ }
+ // Otherwise, process the resolver result.
+ Request* request = self->request_;
+ if (GPR_UNLIKELY(error != GRPC_ERROR_NONE)) {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: resolver failed to return data",
+ request_router, request);
+ }
+ GRPC_CLOSURE_RUN(request->on_route_done_, GRPC_ERROR_REF(error));
+ } else if (GPR_UNLIKELY(request_router->resolver_ == nullptr)) {
+ // Shutting down.
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO, "request_router=%p request=%p: resolver disconnected",
+ request_router, request);
+ }
+ GRPC_CLOSURE_RUN(request->on_route_done_,
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected"));
+ } else if (GPR_UNLIKELY(request_router->lb_policy_ == nullptr)) {
+ // Transient resolver failure.
+ // If call has wait_for_ready=true, try again; otherwise, fail.
+ if (*request->pick_.initial_metadata_flags &
+ GRPC_INITIAL_METADATA_WAIT_FOR_READY) {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: resolver returned but no LB "
+ "policy; wait_for_ready=true; trying again",
+ request_router, request);
+ }
+ // Re-add ourselves to the waiting list.
+ self->AddToWaitingList();
+ // Return early so that we don't set finished_ to true below.
+ return;
+ } else {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: resolver returned but no LB "
+ "policy; wait_for_ready=false; failing",
+ request_router, request);
+ }
+ GRPC_CLOSURE_RUN(
+ request->on_route_done_,
+ grpc_error_set_int(
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("Name resolution failure"),
+ GRPC_ERROR_INT_GRPC_STATUS, GRPC_STATUS_UNAVAILABLE));
+ }
+ } else {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: resolver returned, doing LB "
+ "pick",
+ request_router, request);
+ }
+ request->ProcessServiceConfigAndStartLbPickLocked();
+ }
+ self->finished_ = true;
+ }
+
+ // Invoked when the call is cancelled.
+ // Note: This runs under the client_channel combiner, but will NOT be
+ // holding the call combiner.
+ static void CancelLocked(void* arg, grpc_error* error) {
+ ResolverResultWaiter* self = static_cast<ResolverResultWaiter*>(arg);
+ RequestRouter* request_router = self->request_router_;
+ // If DoneLocked() has already run, delete ourselves without doing anything.
+ if (self->finished_) {
+ Delete(self);
+ return;
+ }
+ Request* request = self->request_;
+ // If we are being cancelled, immediately invoke on_route_done_
+ // to propagate the error back to the caller.
+ if (error != GRPC_ERROR_NONE) {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: cancelling call waiting for "
+ "name resolution",
+ request_router, request);
+ }
+ // Note: Although we are not in the call combiner here, we are
+ // basically stealing the call combiner from the pending pick, so
+ // it's safe to run on_route_done_ here -- we are essentially
+ // calling it here instead of calling it in DoneLocked().
+ GRPC_CLOSURE_RUN(request->on_route_done_,
+ GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
+ "Pick cancelled", &error, 1));
+ }
+ self->finished_ = true;
+ }
+
+ RequestRouter* request_router_;
+ Request* request_;
+ const bool tracer_enabled_;
+ grpc_closure done_closure_;
+ grpc_closure cancel_closure_;
+ bool finished_ = false;
+};
+
+//
+// RequestRouter::Request::AsyncPickCanceller
+//
+
+// Handles the call combiner cancellation callback for an async LB pick.
+class RequestRouter::Request::AsyncPickCanceller {
+ public:
+ explicit AsyncPickCanceller(Request* request)
+ : request_router_(request->request_router_),
+ request_(request),
+ tracer_enabled_(request_router_->tracer_->enabled()) {
+ GRPC_CALL_STACK_REF(request->owning_call_, "pick_callback_cancel");
+ // Set cancellation closure, so that we abort if the call is cancelled.
+ GRPC_CLOSURE_INIT(&cancel_closure_, &CancelLocked, this,
+ grpc_combiner_scheduler(request_router_->combiner_));
+ grpc_call_combiner_set_notify_on_cancel(request->call_combiner_,
+ &cancel_closure_);
+ }
+
+ void MarkFinishedLocked() {
+ finished_ = true;
+ GRPC_CALL_STACK_UNREF(request_->owning_call_, "pick_callback_cancel");
+ }
+
+ private:
+ // Invoked when the call is cancelled.
+ // Note: This runs under the client_channel combiner, but will NOT be
+ // holding the call combiner.
+ static void CancelLocked(void* arg, grpc_error* error) {
+ AsyncPickCanceller* self = static_cast<AsyncPickCanceller*>(arg);
+ Request* request = self->request_;
+ RequestRouter* request_router = self->request_router_;
+ if (!self->finished_) {
+ // Note: request_router->lb_policy_ may have changed since we started our
+ // pick, in which case we will be cancelling the pick on a policy other
+ // than the one we started it on. However, this will just be a no-op.
+ if (error != GRPC_ERROR_NONE && request_router->lb_policy_ != nullptr) {
+ if (self->tracer_enabled_) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: cancelling pick from LB "
+ "policy %p",
+ request_router, request, request_router->lb_policy_.get());
+ }
+ request_router->lb_policy_->CancelPickLocked(&request->pick_,
+ GRPC_ERROR_REF(error));
+ }
+ request->pick_canceller_ = nullptr;
+ GRPC_CALL_STACK_UNREF(request->owning_call_, "pick_callback_cancel");
+ }
+ Delete(self);
+ }
+
+ RequestRouter* request_router_;
+ Request* request_;
+ const bool tracer_enabled_;
+ grpc_closure cancel_closure_;
+ bool finished_ = false;
+};
+
+//
+// RequestRouter::Request
+//
+
+RequestRouter::Request::Request(grpc_call_stack* owning_call,
+ grpc_call_combiner* call_combiner,
+ grpc_polling_entity* pollent,
+ grpc_metadata_batch* send_initial_metadata,
+ uint32_t* send_initial_metadata_flags,
+ ApplyServiceConfigCallback apply_service_config,
+ void* apply_service_config_user_data,
+ grpc_closure* on_route_done)
+ : owning_call_(owning_call),
+ call_combiner_(call_combiner),
+ pollent_(pollent),
+ apply_service_config_(apply_service_config),
+ apply_service_config_user_data_(apply_service_config_user_data),
+ on_route_done_(on_route_done) {
+ pick_.initial_metadata = send_initial_metadata;
+ pick_.initial_metadata_flags = send_initial_metadata_flags;
+}
+
+RequestRouter::Request::~Request() {
+ if (pick_.connected_subchannel != nullptr) {
+ pick_.connected_subchannel.reset();
+ }
+ for (size_t i = 0; i < GRPC_CONTEXT_COUNT; ++i) {
+ if (pick_.subchannel_call_context[i].destroy != nullptr) {
+ pick_.subchannel_call_context[i].destroy(
+ pick_.subchannel_call_context[i].value);
+ }
+ }
+}
+
+// Invoked once resolver results are available.
+void RequestRouter::Request::ProcessServiceConfigAndStartLbPickLocked() {
+ // Get service config data if needed.
+ if (!apply_service_config_(apply_service_config_user_data_)) return;
+ // Start LB pick.
+ StartLbPickLocked();
+}
+
+void RequestRouter::Request::MaybeAddCallToInterestedPartiesLocked() {
+ if (!pollent_added_to_interested_parties_) {
+ pollent_added_to_interested_parties_ = true;
+ grpc_polling_entity_add_to_pollset_set(
+ pollent_, request_router_->interested_parties_);
+ }
+}
+
+void RequestRouter::Request::MaybeRemoveCallFromInterestedPartiesLocked() {
+ if (pollent_added_to_interested_parties_) {
+ pollent_added_to_interested_parties_ = false;
+ grpc_polling_entity_del_from_pollset_set(
+ pollent_, request_router_->interested_parties_);
+ }
+}
+
+// Starts a pick on the LB policy.
+void RequestRouter::Request::StartLbPickLocked() {
+ if (request_router_->tracer_->enabled()) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: starting pick on lb_policy=%p",
+ request_router_, this, request_router_->lb_policy_.get());
+ }
+ GRPC_CLOSURE_INIT(&on_pick_done_, &LbPickDoneLocked, this,
+ grpc_combiner_scheduler(request_router_->combiner_));
+ pick_.on_complete = &on_pick_done_;
+ GRPC_CALL_STACK_REF(owning_call_, "pick_callback");
+ grpc_error* error = GRPC_ERROR_NONE;
+ const bool pick_done =
+ request_router_->lb_policy_->PickLocked(&pick_, &error);
+ if (pick_done) {
+ // Pick completed synchronously.
+ if (request_router_->tracer_->enabled()) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: pick completed synchronously",
+ request_router_, this);
+ }
+ GRPC_CLOSURE_RUN(on_route_done_, error);
+ GRPC_CALL_STACK_UNREF(owning_call_, "pick_callback");
+ } else {
+ // Pick will be returned asynchronously.
+ // Add the request's polling entity to the request_router's
+ // interested_parties, so that the I/O of the LB policy can be done
+ // under it. It will be removed in LbPickDoneLocked().
+ MaybeAddCallToInterestedPartiesLocked();
+ // Request notification on call cancellation.
+ // We allocate a separate object to track cancellation, since the
+ // cancellation closure might still be pending when we need to reuse
+ // the memory in which this Request object is stored for a subsequent
+ // retry attempt.
+ pick_canceller_ = New<AsyncPickCanceller>(this);
+ }
+}
+
+// Callback invoked by LoadBalancingPolicy::PickLocked() for async picks.
+// Unrefs the LB policy and invokes on_route_done_.
+void RequestRouter::Request::LbPickDoneLocked(void* arg, grpc_error* error) {
+ Request* self = static_cast<Request*>(arg);
+ RequestRouter* request_router = self->request_router_;
+ if (request_router->tracer_->enabled()) {
+ gpr_log(GPR_INFO,
+ "request_router=%p request=%p: pick completed asynchronously",
+ request_router, self);
+ }
+ self->MaybeRemoveCallFromInterestedPartiesLocked();
+ if (self->pick_canceller_ != nullptr) {
+ self->pick_canceller_->MarkFinishedLocked();
+ }
+ GRPC_CLOSURE_RUN(self->on_route_done_, GRPC_ERROR_REF(error));
+ GRPC_CALL_STACK_UNREF(self->owning_call_, "pick_callback");
+}
+
+//
+// RequestRouter::LbConnectivityWatcher
+//
+
+class RequestRouter::LbConnectivityWatcher {
+ public:
+ LbConnectivityWatcher(RequestRouter* request_router,
+ grpc_connectivity_state state,
+ LoadBalancingPolicy* lb_policy,
+ grpc_channel_stack* owning_stack,
+ grpc_combiner* combiner)
+ : request_router_(request_router),
+ state_(state),
+ lb_policy_(lb_policy),
+ owning_stack_(owning_stack) {
+ GRPC_CHANNEL_STACK_REF(owning_stack_, "LbConnectivityWatcher");
+ GRPC_CLOSURE_INIT(&on_changed_, &OnLbPolicyStateChangedLocked, this,
+ grpc_combiner_scheduler(combiner));
+ lb_policy_->NotifyOnStateChangeLocked(&state_, &on_changed_);
+ }
+
+ ~LbConnectivityWatcher() {
+ GRPC_CHANNEL_STACK_UNREF(owning_stack_, "LbConnectivityWatcher");
+ }
+
+ private:
+ static void OnLbPolicyStateChangedLocked(void* arg, grpc_error* error) {
+ LbConnectivityWatcher* self = static_cast<LbConnectivityWatcher*>(arg);
+ // If the notification is not for the current policy, we're stale,
+ // so delete ourselves.
+ if (self->lb_policy_ != self->request_router_->lb_policy_.get()) {
+ Delete(self);
+ return;
+ }
+ // Otherwise, process notification.
+ if (self->request_router_->tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: lb_policy=%p state changed to %s",
+ self->request_router_, self->lb_policy_,
+ grpc_connectivity_state_name(self->state_));
+ }
+ self->request_router_->SetConnectivityStateLocked(
+ self->state_, GRPC_ERROR_REF(error), "lb_changed");
+ // If shutting down, terminate watch.
+ if (self->state_ == GRPC_CHANNEL_SHUTDOWN) {
+ Delete(self);
+ return;
+ }
+ // Renew watch.
+ self->lb_policy_->NotifyOnStateChangeLocked(&self->state_,
+ &self->on_changed_);
+ }
+
+ RequestRouter* request_router_;
+ grpc_connectivity_state state_;
+ // LB policy address. No ref held, so not safe to dereference unless
+ // it happens to match request_router->lb_policy_.
+ LoadBalancingPolicy* lb_policy_;
+ grpc_channel_stack* owning_stack_;
+ grpc_closure on_changed_;
+};
+
+//
+// RequestRounter::ReresolutionRequestHandler
+//
+
+class RequestRouter::ReresolutionRequestHandler {
+ public:
+ ReresolutionRequestHandler(RequestRouter* request_router,
+ LoadBalancingPolicy* lb_policy,
+ grpc_channel_stack* owning_stack,
+ grpc_combiner* combiner)
+ : request_router_(request_router),
+ lb_policy_(lb_policy),
+ owning_stack_(owning_stack) {
+ GRPC_CHANNEL_STACK_REF(owning_stack_, "ReresolutionRequestHandler");
+ GRPC_CLOSURE_INIT(&closure_, &OnRequestReresolutionLocked, this,
+ grpc_combiner_scheduler(combiner));
+ lb_policy_->SetReresolutionClosureLocked(&closure_);
+ }
+
+ private:
+ static void OnRequestReresolutionLocked(void* arg, grpc_error* error) {
+ ReresolutionRequestHandler* self =
+ static_cast<ReresolutionRequestHandler*>(arg);
+ RequestRouter* request_router = self->request_router_;
+ // If this invocation is for a stale LB policy, treat it as an LB shutdown
+ // signal.
+ if (self->lb_policy_ != request_router->lb_policy_.get() ||
+ error != GRPC_ERROR_NONE || request_router->resolver_ == nullptr) {
+ GRPC_CHANNEL_STACK_UNREF(request_router->owning_stack_,
+ "ReresolutionRequestHandler");
+ Delete(self);
+ return;
+ }
+ if (request_router->tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: started name re-resolving",
+ request_router);
+ }
+ request_router->resolver_->RequestReresolutionLocked();
+ // Give back the closure to the LB policy.
+ self->lb_policy_->SetReresolutionClosureLocked(&self->closure_);
+ }
+
+ RequestRouter* request_router_;
+ // LB policy address. No ref held, so not safe to dereference unless
+ // it happens to match request_router->lb_policy_.
+ LoadBalancingPolicy* lb_policy_;
+ grpc_channel_stack* owning_stack_;
+ grpc_closure closure_;
+};
+
+//
+// RequestRouter
+//
+
+RequestRouter::RequestRouter(
+ grpc_channel_stack* owning_stack, grpc_combiner* combiner,
+ grpc_client_channel_factory* client_channel_factory,
+ grpc_pollset_set* interested_parties, TraceFlag* tracer,
+ ProcessResolverResultCallback process_resolver_result,
+ void* process_resolver_result_user_data, const char* target_uri,
+ const grpc_channel_args* args, grpc_error** error)
+ : owning_stack_(owning_stack),
+ combiner_(combiner),
+ client_channel_factory_(client_channel_factory),
+ interested_parties_(interested_parties),
+ tracer_(tracer),
+ process_resolver_result_(process_resolver_result),
+ process_resolver_result_user_data_(process_resolver_result_user_data) {
+ GRPC_CLOSURE_INIT(&on_resolver_result_changed_,
+ &RequestRouter::OnResolverResultChangedLocked, this,
+ grpc_combiner_scheduler(combiner));
+ grpc_connectivity_state_init(&state_tracker_, GRPC_CHANNEL_IDLE,
+ "request_router");
+ grpc_channel_args* new_args = nullptr;
+ if (process_resolver_result == nullptr) {
+ grpc_arg arg = grpc_channel_arg_integer_create(
+ const_cast<char*>(GRPC_ARG_SERVICE_CONFIG_DISABLE_RESOLUTION), 0);
+ new_args = grpc_channel_args_copy_and_add(args, &arg, 1);
+ }
+ resolver_ = ResolverRegistry::CreateResolver(
+ target_uri, (new_args == nullptr ? args : new_args), interested_parties_,
+ combiner_);
+ grpc_channel_args_destroy(new_args);
+ if (resolver_ == nullptr) {
+ *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING("resolver creation failed");
+ }
+}
+
+RequestRouter::~RequestRouter() {
+ if (resolver_ != nullptr) {
+ // The only way we can get here is if we never started resolving,
+ // because we take a ref to the channel stack when we start
+ // resolving and do not release it until the resolver callback is
+ // invoked after the resolver shuts down.
+ resolver_.reset();
+ }
+ if (lb_policy_ != nullptr) {
+ grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(),
+ interested_parties_);
+ lb_policy_.reset();
+ }
+ if (client_channel_factory_ != nullptr) {
+ grpc_client_channel_factory_unref(client_channel_factory_);
+ }
+ grpc_connectivity_state_destroy(&state_tracker_);
+}
+
+namespace {
+
+const char* GetChannelConnectivityStateChangeString(
+ grpc_connectivity_state state) {
+ switch (state) {
+ case GRPC_CHANNEL_IDLE:
+ return "Channel state change to IDLE";
+ case GRPC_CHANNEL_CONNECTING:
+ return "Channel state change to CONNECTING";
+ case GRPC_CHANNEL_READY:
+ return "Channel state change to READY";
+ case GRPC_CHANNEL_TRANSIENT_FAILURE:
+ return "Channel state change to TRANSIENT_FAILURE";
+ case GRPC_CHANNEL_SHUTDOWN:
+ return "Channel state change to SHUTDOWN";
+ }
+ GPR_UNREACHABLE_CODE(return "UNKNOWN");
+}
+
+} // namespace
+
+void RequestRouter::SetConnectivityStateLocked(grpc_connectivity_state state,
+ grpc_error* error,
+ const char* reason) {
+ if (lb_policy_ != nullptr) {
+ if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
+ // Cancel picks with wait_for_ready=false.
+ lb_policy_->CancelMatchingPicksLocked(
+ /* mask= */ GRPC_INITIAL_METADATA_WAIT_FOR_READY,
+ /* check= */ 0, GRPC_ERROR_REF(error));
+ } else if (state == GRPC_CHANNEL_SHUTDOWN) {
+ // Cancel all picks.
+ lb_policy_->CancelMatchingPicksLocked(/* mask= */ 0, /* check= */ 0,
+ GRPC_ERROR_REF(error));
+ }
+ }
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: setting connectivity state to %s",
+ this, grpc_connectivity_state_name(state));
+ }
+ if (channelz_node_ != nullptr) {
+ channelz_node_->AddTraceEvent(
+ channelz::ChannelTrace::Severity::Info,
+ grpc_slice_from_static_string(
+ GetChannelConnectivityStateChangeString(state)));
+ }
+ grpc_connectivity_state_set(&state_tracker_, state, error, reason);
+}
+
+void RequestRouter::StartResolvingLocked() {
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: starting name resolution", this);
+ }
+ GPR_ASSERT(!started_resolving_);
+ started_resolving_ = true;
+ GRPC_CHANNEL_STACK_REF(owning_stack_, "resolver");
+ resolver_->NextLocked(&resolver_result_, &on_resolver_result_changed_);
+}
+
+// Invoked from the resolver NextLocked() callback when the resolver
+// is shutting down.
+void RequestRouter::OnResolverShutdownLocked(grpc_error* error) {
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: shutting down", this);
+ }
+ if (lb_policy_ != nullptr) {
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this,
+ lb_policy_.get());
+ }
+ grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(),
+ interested_parties_);
+ lb_policy_.reset();
+ }
+ if (resolver_ != nullptr) {
+ // This should never happen; it can only be triggered by a resolver
+ // implementation spotaneously deciding to report shutdown without
+ // being orphaned. This code is included just to be defensive.
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO,
+ "request_router=%p: spontaneous shutdown from resolver %p", this,
+ resolver_.get());
+ }
+ resolver_.reset();
+ SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN,
+ GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
+ "Resolver spontaneous shutdown", &error, 1),
+ "resolver_spontaneous_shutdown");
+ }
+ grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_,
+ GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
+ "Channel disconnected", &error, 1));
+ GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_);
+ GRPC_CHANNEL_STACK_UNREF(owning_stack_, "resolver");
+ grpc_channel_args_destroy(resolver_result_);
+ resolver_result_ = nullptr;
+ GRPC_ERROR_UNREF(error);
+}
+
+// Creates a new LB policy, replacing any previous one.
+// If the new policy is created successfully, sets *connectivity_state and
+// *connectivity_error to its initial connectivity state; otherwise,
+// leaves them unchanged.
+void RequestRouter::CreateNewLbPolicyLocked(
+ const char* lb_policy_name, grpc_json* lb_config,
+ grpc_connectivity_state* connectivity_state,
+ grpc_error** connectivity_error, TraceStringVector* trace_strings) {
+ LoadBalancingPolicy::Args lb_policy_args;
+ lb_policy_args.combiner = combiner_;
+ lb_policy_args.client_channel_factory = client_channel_factory_;
+ lb_policy_args.args = resolver_result_;
+ lb_policy_args.lb_config = lb_config;
+ OrphanablePtr<LoadBalancingPolicy> new_lb_policy =
+ LoadBalancingPolicyRegistry::CreateLoadBalancingPolicy(lb_policy_name,
+ lb_policy_args);
+ if (GPR_UNLIKELY(new_lb_policy == nullptr)) {
+ gpr_log(GPR_ERROR, "could not create LB policy \"%s\"", lb_policy_name);
+ if (channelz_node_ != nullptr) {
+ char* str;
+ gpr_asprintf(&str, "Could not create LB policy \'%s\'", lb_policy_name);
+ trace_strings->push_back(str);
+ }
+ } else {
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: created new LB policy \"%s\" (%p)",
+ this, lb_policy_name, new_lb_policy.get());
+ }
+ if (channelz_node_ != nullptr) {
+ char* str;
+ gpr_asprintf(&str, "Created new LB policy \'%s\'", lb_policy_name);
+ trace_strings->push_back(str);
+ }
+ // Swap out the LB policy and update the fds in interested_parties_.
+ if (lb_policy_ != nullptr) {
+ if (tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: shutting down lb_policy=%p", this,
+ lb_policy_.get());
+ }
+ grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(),
+ interested_parties_);
+ lb_policy_->HandOffPendingPicksLocked(new_lb_policy.get());
+ }
+ lb_policy_ = std::move(new_lb_policy);
+ grpc_pollset_set_add_pollset_set(lb_policy_->interested_parties(),
+ interested_parties_);
+ // Create re-resolution request handler for the new LB policy. It
+ // will delete itself when no longer needed.
+ New<ReresolutionRequestHandler>(this, lb_policy_.get(), owning_stack_,
+ combiner_);
+ // Get the new LB policy's initial connectivity state and start a
+ // connectivity watch.
+ GRPC_ERROR_UNREF(*connectivity_error);
+ *connectivity_state =
+ lb_policy_->CheckConnectivityLocked(connectivity_error);
+ if (exit_idle_when_lb_policy_arrives_) {
+ lb_policy_->ExitIdleLocked();
+ exit_idle_when_lb_policy_arrives_ = false;
+ }
+ // Create new watcher. It will delete itself when done.
+ New<LbConnectivityWatcher>(this, *connectivity_state, lb_policy_.get(),
+ owning_stack_, combiner_);
+ }
+}
+
+void RequestRouter::MaybeAddTraceMessagesForAddressChangesLocked(
+ TraceStringVector* trace_strings) {
+ const ServerAddressList* addresses =
+ FindServerAddressListChannelArg(resolver_result_);
+ const bool resolution_contains_addresses =
+ addresses != nullptr && addresses->size() > 0;
+ if (!resolution_contains_addresses &&
+ previous_resolution_contained_addresses_) {
+ trace_strings->push_back(gpr_strdup("Address list became empty"));
+ } else if (resolution_contains_addresses &&
+ !previous_resolution_contained_addresses_) {
+ trace_strings->push_back(gpr_strdup("Address list became non-empty"));
+ }
+ previous_resolution_contained_addresses_ = resolution_contains_addresses;
+}
+
+void RequestRouter::ConcatenateAndAddChannelTraceLocked(
+ TraceStringVector* trace_strings) const {
+ if (!trace_strings->empty()) {
+ gpr_strvec v;
+ gpr_strvec_init(&v);
+ gpr_strvec_add(&v, gpr_strdup("Resolution event: "));
+ bool is_first = 1;
+ for (size_t i = 0; i < trace_strings->size(); ++i) {
+ if (!is_first) gpr_strvec_add(&v, gpr_strdup(", "));
+ is_first = false;
+ gpr_strvec_add(&v, (*trace_strings)[i]);
+ }
+ char* flat;
+ size_t flat_len = 0;
+ flat = gpr_strvec_flatten(&v, &flat_len);
+ channelz_node_->AddTraceEvent(
+ grpc_core::channelz::ChannelTrace::Severity::Info,
+ grpc_slice_new(flat, flat_len, gpr_free));
+ gpr_strvec_destroy(&v);
+ }
+}
+
+// Callback invoked when a resolver result is available.
+void RequestRouter::OnResolverResultChangedLocked(void* arg,
+ grpc_error* error) {
+ RequestRouter* self = static_cast<RequestRouter*>(arg);
+ if (self->tracer_->enabled()) {
+ const char* disposition =
+ self->resolver_result_ != nullptr
+ ? ""
+ : (error == GRPC_ERROR_NONE ? " (transient error)"
+ : " (resolver shutdown)");
+ gpr_log(GPR_INFO,
+ "request_router=%p: got resolver result: resolver_result=%p "
+ "error=%s%s",
+ self, self->resolver_result_, grpc_error_string(error),
+ disposition);
+ }
+ // Handle shutdown.
+ if (error != GRPC_ERROR_NONE || self->resolver_ == nullptr) {
+ self->OnResolverShutdownLocked(GRPC_ERROR_REF(error));
+ return;
+ }
+ // Data used to set the channel's connectivity state.
+ bool set_connectivity_state = true;
+ // We only want to trace the address resolution in the follow cases:
+ // (a) Address resolution resulted in service config change.
+ // (b) Address resolution that causes number of backends to go from
+ // zero to non-zero.
+ // (c) Address resolution that causes number of backends to go from
+ // non-zero to zero.
+ // (d) Address resolution that causes a new LB policy to be created.
+ //
+ // we track a list of strings to eventually be concatenated and traced.
+ TraceStringVector trace_strings;
+ grpc_connectivity_state connectivity_state = GRPC_CHANNEL_TRANSIENT_FAILURE;
+ grpc_error* connectivity_error =
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("No load balancing policy");
+ // resolver_result_ will be null in the case of a transient
+ // resolution error. In that case, we don't have any new result to
+ // process, which means that we keep using the previous result (if any).
+ if (self->resolver_result_ == nullptr) {
+ if (self->tracer_->enabled()) {
+ gpr_log(GPR_INFO, "request_router=%p: resolver transient failure", self);
+ }
+ // Don't override connectivity state if we already have an LB policy.
+ if (self->lb_policy_ != nullptr) set_connectivity_state = false;
+ } else {
+ // Parse the resolver result.
+ const char* lb_policy_name = nullptr;
+ grpc_json* lb_policy_config = nullptr;
+ const bool service_config_changed = self->process_resolver_result_(
+ self->process_resolver_result_user_data_, *self->resolver_result_,
+ &lb_policy_name, &lb_policy_config);
+ GPR_ASSERT(lb_policy_name != nullptr);
+ // Check to see if we're already using the right LB policy.
+ const bool lb_policy_name_changed =
+ self->lb_policy_ == nullptr ||
+ strcmp(self->lb_policy_->name(), lb_policy_name) != 0;
+ if (self->lb_policy_ != nullptr && !lb_policy_name_changed) {
+ // Continue using the same LB policy. Update with new addresses.
+ if (self->tracer_->enabled()) {
+ gpr_log(GPR_INFO,
+ "request_router=%p: updating existing LB policy \"%s\" (%p)",
+ self, lb_policy_name, self->lb_policy_.get());
+ }
+ self->lb_policy_->UpdateLocked(*self->resolver_result_, lb_policy_config);
+ // No need to set the channel's connectivity state; the existing
+ // watch on the LB policy will take care of that.
+ set_connectivity_state = false;
+ } else {
+ // Instantiate new LB policy.
+ self->CreateNewLbPolicyLocked(lb_policy_name, lb_policy_config,
+ &connectivity_state, &connectivity_error,
+ &trace_strings);
+ }
+ // Add channel trace event.
+ if (self->channelz_node_ != nullptr) {
+ if (service_config_changed) {
+ // TODO(ncteisen): might be worth somehow including a snippet of the
+ // config in the trace, at the risk of bloating the trace logs.
+ trace_strings.push_back(gpr_strdup("Service config changed"));
+ }
+ self->MaybeAddTraceMessagesForAddressChangesLocked(&trace_strings);
+ self->ConcatenateAndAddChannelTraceLocked(&trace_strings);
+ }
+ // Clean up.
+ grpc_channel_args_destroy(self->resolver_result_);
+ self->resolver_result_ = nullptr;
+ }
+ // Set the channel's connectivity state if needed.
+ if (set_connectivity_state) {
+ self->SetConnectivityStateLocked(connectivity_state, connectivity_error,
+ "resolver_result");
+ } else {
+ GRPC_ERROR_UNREF(connectivity_error);
+ }
+ // Invoke closures that were waiting for results and renew the watch.
+ GRPC_CLOSURE_LIST_SCHED(&self->waiting_for_resolver_result_closures_);
+ self->resolver_->NextLocked(&self->resolver_result_,
+ &self->on_resolver_result_changed_);
+}
+
+void RequestRouter::RouteCallLocked(Request* request) {
+ GPR_ASSERT(request->pick_.connected_subchannel == nullptr);
+ request->request_router_ = this;
+ if (lb_policy_ != nullptr) {
+ // We already have resolver results, so process the service config
+ // and start an LB pick.
+ request->ProcessServiceConfigAndStartLbPickLocked();
+ } else if (resolver_ == nullptr) {
+ GRPC_CLOSURE_RUN(request->on_route_done_,
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("Disconnected"));
+ } else {
+ // We do not yet have an LB policy, so wait for a resolver result.
+ if (!started_resolving_) {
+ StartResolvingLocked();
+ }
+ // Create a new waiter, which will delete itself when done.
+ New<Request::ResolverResultWaiter>(request);
+ // Add the request's polling entity to the request_router's
+ // interested_parties, so that the I/O of the resolver can be done
+ // under it. It will be removed in LbPickDoneLocked().
+ request->MaybeAddCallToInterestedPartiesLocked();
+ }
+}
+
+void RequestRouter::ShutdownLocked(grpc_error* error) {
+ if (resolver_ != nullptr) {
+ SetConnectivityStateLocked(GRPC_CHANNEL_SHUTDOWN, GRPC_ERROR_REF(error),
+ "disconnect");
+ resolver_.reset();
+ if (!started_resolving_) {
+ grpc_closure_list_fail_all(&waiting_for_resolver_result_closures_,
+ GRPC_ERROR_REF(error));
+ GRPC_CLOSURE_LIST_SCHED(&waiting_for_resolver_result_closures_);
+ }
+ if (lb_policy_ != nullptr) {
+ grpc_pollset_set_del_pollset_set(lb_policy_->interested_parties(),
+ interested_parties_);
+ lb_policy_.reset();
+ }
+ }
+ GRPC_ERROR_UNREF(error);
+}
+
+grpc_connectivity_state RequestRouter::GetConnectivityState() {
+ return grpc_connectivity_state_check(&state_tracker_);
+}
+
+void RequestRouter::NotifyOnConnectivityStateChange(
+ grpc_connectivity_state* state, grpc_closure* closure) {
+ grpc_connectivity_state_notify_on_state_change(&state_tracker_, state,
+ closure);
+}
+
+void RequestRouter::ExitIdleLocked() {
+ if (lb_policy_ != nullptr) {
+ lb_policy_->ExitIdleLocked();
+ } else {
+ exit_idle_when_lb_policy_arrives_ = true;
+ if (!started_resolving_ && resolver_ != nullptr) {
+ StartResolvingLocked();
+ }
+ }
+}
+
+void RequestRouter::ResetConnectionBackoffLocked() {
+ if (resolver_ != nullptr) {
+ resolver_->ResetBackoffLocked();
+ resolver_->RequestReresolutionLocked();
+ }
+ if (lb_policy_ != nullptr) {
+ lb_policy_->ResetBackoffLocked();
+ }
+}
+
+} // namespace grpc_core
diff --git a/src/core/ext/filters/client_channel/request_routing.h b/src/core/ext/filters/client_channel/request_routing.h
new file mode 100644
index 0000000000..0c671229c8
--- /dev/null
+++ b/src/core/ext/filters/client_channel/request_routing.h
@@ -0,0 +1,177 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#ifndef GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H
+#define GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H
+
+#include <grpc/support/port_platform.h>
+
+#include "src/core/ext/filters/client_channel/client_channel_channelz.h"
+#include "src/core/ext/filters/client_channel/client_channel_factory.h"
+#include "src/core/ext/filters/client_channel/lb_policy.h"
+#include "src/core/ext/filters/client_channel/resolver.h"
+#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/channel/channel_stack.h"
+#include "src/core/lib/debug/trace.h"
+#include "src/core/lib/gprpp/inlined_vector.h"
+#include "src/core/lib/gprpp/orphanable.h"
+#include "src/core/lib/iomgr/call_combiner.h"
+#include "src/core/lib/iomgr/closure.h"
+#include "src/core/lib/iomgr/polling_entity.h"
+#include "src/core/lib/iomgr/pollset_set.h"
+#include "src/core/lib/transport/connectivity_state.h"
+#include "src/core/lib/transport/metadata_batch.h"
+
+namespace grpc_core {
+
+class RequestRouter {
+ public:
+ class Request {
+ public:
+ // Synchronous callback that applies the service config to a call.
+ // Returns false if the call should be failed.
+ typedef bool (*ApplyServiceConfigCallback)(void* user_data);
+
+ Request(grpc_call_stack* owning_call, grpc_call_combiner* call_combiner,
+ grpc_polling_entity* pollent,
+ grpc_metadata_batch* send_initial_metadata,
+ uint32_t* send_initial_metadata_flags,
+ ApplyServiceConfigCallback apply_service_config,
+ void* apply_service_config_user_data, grpc_closure* on_route_done);
+
+ ~Request();
+
+ // TODO(roth): It seems a bit ugly to expose this member in a
+ // non-const way. Find a better API to avoid this.
+ LoadBalancingPolicy::PickState* pick() { return &pick_; }
+
+ private:
+ friend class RequestRouter;
+
+ class ResolverResultWaiter;
+ class AsyncPickCanceller;
+
+ void ProcessServiceConfigAndStartLbPickLocked();
+ void StartLbPickLocked();
+ static void LbPickDoneLocked(void* arg, grpc_error* error);
+
+ void MaybeAddCallToInterestedPartiesLocked();
+ void MaybeRemoveCallFromInterestedPartiesLocked();
+
+ // Populated by caller.
+ grpc_call_stack* owning_call_;
+ grpc_call_combiner* call_combiner_;
+ grpc_polling_entity* pollent_;
+ ApplyServiceConfigCallback apply_service_config_;
+ void* apply_service_config_user_data_;
+ grpc_closure* on_route_done_;
+ LoadBalancingPolicy::PickState pick_;
+
+ // Internal state.
+ RequestRouter* request_router_ = nullptr;
+ bool pollent_added_to_interested_parties_ = false;
+ grpc_closure on_pick_done_;
+ AsyncPickCanceller* pick_canceller_ = nullptr;
+ };
+
+ // Synchronous callback that takes the service config JSON string and
+ // LB policy name.
+ // Returns true if the service config has changed since the last result.
+ typedef bool (*ProcessResolverResultCallback)(void* user_data,
+ const grpc_channel_args& args,
+ const char** lb_policy_name,
+ grpc_json** lb_policy_config);
+
+ RequestRouter(grpc_channel_stack* owning_stack, grpc_combiner* combiner,
+ grpc_client_channel_factory* client_channel_factory,
+ grpc_pollset_set* interested_parties, TraceFlag* tracer,
+ ProcessResolverResultCallback process_resolver_result,
+ void* process_resolver_result_user_data, const char* target_uri,
+ const grpc_channel_args* args, grpc_error** error);
+
+ ~RequestRouter();
+
+ void set_channelz_node(channelz::ClientChannelNode* channelz_node) {
+ channelz_node_ = channelz_node;
+ }
+
+ void RouteCallLocked(Request* request);
+
+ // TODO(roth): Add methods to cancel picks.
+
+ void ShutdownLocked(grpc_error* error);
+
+ void ExitIdleLocked();
+ void ResetConnectionBackoffLocked();
+
+ grpc_connectivity_state GetConnectivityState();
+ void NotifyOnConnectivityStateChange(grpc_connectivity_state* state,
+ grpc_closure* closure);
+
+ LoadBalancingPolicy* lb_policy() const { return lb_policy_.get(); }
+
+ private:
+ using TraceStringVector = grpc_core::InlinedVector<char*, 3>;
+
+ class ReresolutionRequestHandler;
+ class LbConnectivityWatcher;
+
+ void StartResolvingLocked();
+ void OnResolverShutdownLocked(grpc_error* error);
+ void CreateNewLbPolicyLocked(const char* lb_policy_name, grpc_json* lb_config,
+ grpc_connectivity_state* connectivity_state,
+ grpc_error** connectivity_error,
+ TraceStringVector* trace_strings);
+ void MaybeAddTraceMessagesForAddressChangesLocked(
+ TraceStringVector* trace_strings);
+ void ConcatenateAndAddChannelTraceLocked(
+ TraceStringVector* trace_strings) const;
+ static void OnResolverResultChangedLocked(void* arg, grpc_error* error);
+
+ void SetConnectivityStateLocked(grpc_connectivity_state state,
+ grpc_error* error, const char* reason);
+
+ // Passed in from caller at construction time.
+ grpc_channel_stack* owning_stack_;
+ grpc_combiner* combiner_;
+ grpc_client_channel_factory* client_channel_factory_;
+ grpc_pollset_set* interested_parties_;
+ TraceFlag* tracer_;
+
+ channelz::ClientChannelNode* channelz_node_ = nullptr;
+
+ // Resolver and associated state.
+ OrphanablePtr<Resolver> resolver_;
+ ProcessResolverResultCallback process_resolver_result_;
+ void* process_resolver_result_user_data_;
+ bool started_resolving_ = false;
+ grpc_channel_args* resolver_result_ = nullptr;
+ bool previous_resolution_contained_addresses_ = false;
+ grpc_closure_list waiting_for_resolver_result_closures_;
+ grpc_closure on_resolver_result_changed_;
+
+ // LB policy and associated state.
+ OrphanablePtr<LoadBalancingPolicy> lb_policy_;
+ bool exit_idle_when_lb_policy_arrives_ = false;
+
+ grpc_connectivity_state_tracker state_tracker_;
+};
+
+} // namespace grpc_core
+
+#endif /* GRPC_CORE_EXT_FILTERS_CLIENT_CHANNEL_REQUEST_ROUTING_H */
diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
index c8425ae336..abacd0c960 100644
--- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
+++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc
@@ -170,7 +170,7 @@ AresDnsResolver::AresDnsResolver(const ResolverArgs& args)
}
AresDnsResolver::~AresDnsResolver() {
- gpr_log(GPR_DEBUG, "destroying AresDnsResolver");
+ GRPC_CARES_TRACE_LOG("resolver:%p destroying AresDnsResolver", this);
if (resolved_result_ != nullptr) {
grpc_channel_args_destroy(resolved_result_);
}
@@ -182,7 +182,8 @@ AresDnsResolver::~AresDnsResolver() {
void AresDnsResolver::NextLocked(grpc_channel_args** target_result,
grpc_closure* on_complete) {
- gpr_log(GPR_DEBUG, "AresDnsResolver::NextLocked() is called.");
+ GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::NextLocked() is called.",
+ this);
GPR_ASSERT(next_completion_ == nullptr);
next_completion_ = on_complete;
target_result_ = target_result;
@@ -225,12 +226,14 @@ void AresDnsResolver::ShutdownLocked() {
void AresDnsResolver::OnNextResolutionLocked(void* arg, grpc_error* error) {
AresDnsResolver* r = static_cast<AresDnsResolver*>(arg);
GRPC_CARES_TRACE_LOG(
- "%p re-resolution timer fired. error: %s. shutdown_initiated_: %d", r,
- grpc_error_string(error), r->shutdown_initiated_);
+ "resolver:%p re-resolution timer fired. error: %s. shutdown_initiated_: "
+ "%d",
+ r, grpc_error_string(error), r->shutdown_initiated_);
r->have_next_resolution_timer_ = false;
if (error == GRPC_ERROR_NONE && !r->shutdown_initiated_) {
if (!r->resolving_) {
- GRPC_CARES_TRACE_LOG("%p start resolving due to re-resolution timer", r);
+ GRPC_CARES_TRACE_LOG(
+ "resolver:%p start resolving due to re-resolution timer", r);
r->StartResolvingLocked();
}
}
@@ -327,8 +330,8 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) {
service_config_string = ChooseServiceConfig(r->service_config_json_);
gpr_free(r->service_config_json_);
if (service_config_string != nullptr) {
- gpr_log(GPR_INFO, "selected service config choice: %s",
- service_config_string);
+ GRPC_CARES_TRACE_LOG("resolver:%p selected service config choice: %s",
+ r, service_config_string);
args_to_remove[num_args_to_remove++] = GRPC_ARG_SERVICE_CONFIG;
args_to_add[num_args_to_add++] = grpc_channel_arg_string_create(
(char*)GRPC_ARG_SERVICE_CONFIG, service_config_string);
@@ -344,11 +347,11 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) {
r->backoff_.Reset();
} else if (!r->shutdown_initiated_) {
const char* msg = grpc_error_string(error);
- gpr_log(GPR_DEBUG, "dns resolution failed: %s", msg);
+ GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed: %s", r, msg);
grpc_millis next_try = r->backoff_.NextAttemptTime();
grpc_millis timeout = next_try - ExecCtx::Get()->Now();
- gpr_log(GPR_INFO, "dns resolution failed (will retry): %s",
- grpc_error_string(error));
+ GRPC_CARES_TRACE_LOG("resolver:%p dns resolution failed (will retry): %s",
+ r, grpc_error_string(error));
GPR_ASSERT(!r->have_next_resolution_timer_);
r->have_next_resolution_timer_ = true;
// TODO(roth): We currently deal with this ref manually. Once the
@@ -357,9 +360,10 @@ void AresDnsResolver::OnResolvedLocked(void* arg, grpc_error* error) {
RefCountedPtr<Resolver> self = r->Ref(DEBUG_LOCATION, "retry-timer");
self.release();
if (timeout > 0) {
- gpr_log(GPR_DEBUG, "retrying in %" PRId64 " milliseconds", timeout);
+ GRPC_CARES_TRACE_LOG("resolver:%p retrying in %" PRId64 " milliseconds",
+ r, timeout);
} else {
- gpr_log(GPR_DEBUG, "retrying immediately");
+ GRPC_CARES_TRACE_LOG("resolver:%p retrying immediately", r);
}
grpc_timer_init(&r->next_resolution_timer_, next_try,
&r->on_next_resolution_);
@@ -385,10 +389,10 @@ void AresDnsResolver::MaybeStartResolvingLocked() {
if (ms_until_next_resolution > 0) {
const grpc_millis last_resolution_ago =
grpc_core::ExecCtx::Get()->Now() - last_resolution_timestamp_;
- gpr_log(GPR_DEBUG,
- "In cooldown from last resolution (from %" PRId64
- " ms ago). Will resolve again in %" PRId64 " ms",
- last_resolution_ago, ms_until_next_resolution);
+ GRPC_CARES_TRACE_LOG(
+ "resolver:%p In cooldown from last resolution (from %" PRId64
+ " ms ago). Will resolve again in %" PRId64 " ms",
+ this, last_resolution_ago, ms_until_next_resolution);
have_next_resolution_timer_ = true;
// TODO(roth): We currently deal with this ref manually. Once the
// new closure API is done, find a way to track this ref with the timer
@@ -405,7 +409,6 @@ void AresDnsResolver::MaybeStartResolvingLocked() {
}
void AresDnsResolver::StartResolvingLocked() {
- gpr_log(GPR_DEBUG, "Start resolving.");
// TODO(roth): We currently deal with this ref manually. Once the
// new closure API is done, find a way to track this ref with the timer
// callback as part of the type system.
@@ -420,6 +423,8 @@ void AresDnsResolver::StartResolvingLocked() {
request_service_config_ ? &service_config_json_ : nullptr,
query_timeout_ms_, combiner());
last_resolution_timestamp_ = grpc_core::ExecCtx::Get()->Now();
+ GRPC_CARES_TRACE_LOG("resolver:%p Started resolving. pending_request_:%p",
+ this, pending_request_);
}
void AresDnsResolver::MaybeFinishNextLocked() {
@@ -427,7 +432,8 @@ void AresDnsResolver::MaybeFinishNextLocked() {
*target_result_ = resolved_result_ == nullptr
? nullptr
: grpc_channel_args_copy(resolved_result_);
- gpr_log(GPR_DEBUG, "AresDnsResolver::MaybeFinishNextLocked()");
+ GRPC_CARES_TRACE_LOG("resolver:%p AresDnsResolver::MaybeFinishNextLocked()",
+ this);
GRPC_CLOSURE_SCHED(next_completion_, GRPC_ERROR_NONE);
next_completion_ = nullptr;
published_version_ = resolved_version_;
@@ -465,11 +471,16 @@ static grpc_error* blocking_resolve_address_ares(
static grpc_address_resolver_vtable ares_resolver = {
grpc_resolve_address_ares, blocking_resolve_address_ares};
+static bool should_use_ares(const char* resolver_env) {
+ return resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0;
+}
+
void grpc_resolver_dns_ares_init() {
char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER");
/* TODO(zyc): Turn on c-ares based resolver by default after the address
sorter and the CNAME support are added. */
- if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) {
+ if (should_use_ares(resolver_env)) {
+ gpr_log(GPR_DEBUG, "Using ares dns resolver");
address_sorting_init();
grpc_error* error = grpc_ares_init();
if (error != GRPC_ERROR_NONE) {
@@ -489,7 +500,7 @@ void grpc_resolver_dns_ares_init() {
void grpc_resolver_dns_ares_shutdown() {
char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER");
- if (resolver_env != nullptr && gpr_stricmp(resolver_env, "ares") == 0) {
+ if (should_use_ares(resolver_env)) {
address_sorting_shutdown();
grpc_ares_cleanup();
}
diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc
index 8abc34c6ed..d99c2e3004 100644
--- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc
+++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver.cc
@@ -90,15 +90,18 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver);
static grpc_ares_ev_driver* grpc_ares_ev_driver_ref(
grpc_ares_ev_driver* ev_driver) {
- gpr_log(GPR_DEBUG, "Ref ev_driver %" PRIuPTR, (uintptr_t)ev_driver);
+ GRPC_CARES_TRACE_LOG("request:%p Ref ev_driver %p", ev_driver->request,
+ ev_driver);
gpr_ref(&ev_driver->refs);
return ev_driver;
}
static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) {
- gpr_log(GPR_DEBUG, "Unref ev_driver %" PRIuPTR, (uintptr_t)ev_driver);
+ GRPC_CARES_TRACE_LOG("request:%p Unref ev_driver %p", ev_driver->request,
+ ev_driver);
if (gpr_unref(&ev_driver->refs)) {
- gpr_log(GPR_DEBUG, "destroy ev_driver %" PRIuPTR, (uintptr_t)ev_driver);
+ GRPC_CARES_TRACE_LOG("request:%p destroy ev_driver %p", ev_driver->request,
+ ev_driver);
GPR_ASSERT(ev_driver->fds == nullptr);
GRPC_COMBINER_UNREF(ev_driver->combiner, "free ares event driver");
ares_destroy(ev_driver->channel);
@@ -108,7 +111,8 @@ static void grpc_ares_ev_driver_unref(grpc_ares_ev_driver* ev_driver) {
}
static void fd_node_destroy_locked(fd_node* fdn) {
- gpr_log(GPR_DEBUG, "delete fd: %s", fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p delete fd: %s", fdn->ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
GPR_ASSERT(!fdn->readable_registered);
GPR_ASSERT(!fdn->writable_registered);
GPR_ASSERT(fdn->already_shutdown);
@@ -136,7 +140,7 @@ grpc_error* grpc_ares_ev_driver_create_locked(grpc_ares_ev_driver** ev_driver,
memset(&opts, 0, sizeof(opts));
opts.flags |= ARES_FLAG_STAYOPEN;
int status = ares_init_options(&(*ev_driver)->channel, &opts, ARES_OPT_FLAGS);
- gpr_log(GPR_DEBUG, "grpc_ares_ev_driver_create_locked");
+ GRPC_CARES_TRACE_LOG("request:%p grpc_ares_ev_driver_create_locked", request);
if (status != ARES_SUCCESS) {
char* err_msg;
gpr_asprintf(&err_msg, "Failed to init ares channel. C-ares error: %s",
@@ -203,8 +207,9 @@ static fd_node* pop_fd_node_locked(fd_node** head, ares_socket_t as) {
static void on_timeout_locked(void* arg, grpc_error* error) {
grpc_ares_ev_driver* driver = static_cast<grpc_ares_ev_driver*>(arg);
GRPC_CARES_TRACE_LOG(
- "ev_driver=%p on_timeout_locked. driver->shutting_down=%d. err=%s",
- driver, driver->shutting_down, grpc_error_string(error));
+ "request:%p ev_driver=%p on_timeout_locked. driver->shutting_down=%d. "
+ "err=%s",
+ driver->request, driver, driver->shutting_down, grpc_error_string(error));
if (!driver->shutting_down && error == GRPC_ERROR_NONE) {
grpc_ares_ev_driver_shutdown_locked(driver);
}
@@ -216,7 +221,8 @@ static void on_readable_locked(void* arg, grpc_error* error) {
grpc_ares_ev_driver* ev_driver = fdn->ev_driver;
const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked();
fdn->readable_registered = false;
- gpr_log(GPR_DEBUG, "readable on %s", fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p readable on %s", fdn->ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
if (error == GRPC_ERROR_NONE) {
do {
ares_process_fd(ev_driver->channel, as, ARES_SOCKET_BAD);
@@ -239,7 +245,8 @@ static void on_writable_locked(void* arg, grpc_error* error) {
grpc_ares_ev_driver* ev_driver = fdn->ev_driver;
const ares_socket_t as = fdn->grpc_polled_fd->GetWrappedAresSocketLocked();
fdn->writable_registered = false;
- gpr_log(GPR_DEBUG, "writable on %s", fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p writable on %s", ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
if (error == GRPC_ERROR_NONE) {
ares_process_fd(ev_driver->channel, ARES_SOCKET_BAD, as);
} else {
@@ -278,7 +285,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) {
fdn->grpc_polled_fd =
ev_driver->polled_fd_factory->NewGrpcPolledFdLocked(
socks[i], ev_driver->pollset_set, ev_driver->combiner);
- gpr_log(GPR_DEBUG, "new fd: %s", fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p new fd: %s", ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
fdn->ev_driver = ev_driver;
fdn->readable_registered = false;
fdn->writable_registered = false;
@@ -295,8 +303,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) {
if (ARES_GETSOCK_READABLE(socks_bitmask, i) &&
!fdn->readable_registered) {
grpc_ares_ev_driver_ref(ev_driver);
- gpr_log(GPR_DEBUG, "notify read on: %s",
- fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p notify read on: %s",
+ ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
fdn->grpc_polled_fd->RegisterForOnReadableLocked(&fdn->read_closure);
fdn->readable_registered = true;
}
@@ -304,8 +313,9 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) {
// has not been registered with this socket.
if (ARES_GETSOCK_WRITABLE(socks_bitmask, i) &&
!fdn->writable_registered) {
- gpr_log(GPR_DEBUG, "notify write on: %s",
- fdn->grpc_polled_fd->GetName());
+ GRPC_CARES_TRACE_LOG("request:%p notify write on: %s",
+ ev_driver->request,
+ fdn->grpc_polled_fd->GetName());
grpc_ares_ev_driver_ref(ev_driver);
fdn->grpc_polled_fd->RegisterForOnWriteableLocked(
&fdn->write_closure);
@@ -332,7 +342,8 @@ static void grpc_ares_notify_on_event_locked(grpc_ares_ev_driver* ev_driver) {
// If the ev driver has no working fd, all the tasks are done.
if (new_list == nullptr) {
ev_driver->working = false;
- gpr_log(GPR_DEBUG, "ev driver stop working");
+ GRPC_CARES_TRACE_LOG("request:%p ev driver stop working",
+ ev_driver->request);
}
}
@@ -345,9 +356,9 @@ void grpc_ares_ev_driver_start_locked(grpc_ares_ev_driver* ev_driver) {
? GRPC_MILLIS_INF_FUTURE
: ev_driver->query_timeout_ms + grpc_core::ExecCtx::Get()->Now();
GRPC_CARES_TRACE_LOG(
- "ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in %" PRId64
- " ms",
- ev_driver, timeout);
+ "request:%p ev_driver=%p grpc_ares_ev_driver_start_locked. timeout in "
+ "%" PRId64 " ms",
+ ev_driver->request, ev_driver, timeout);
grpc_ares_ev_driver_ref(ev_driver);
grpc_timer_init(&ev_driver->query_timeout, timeout,
&ev_driver->on_timeout_locked);
diff --git a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc
index 1b1c2303da..1a7e5d0626 100644
--- a/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc
+++ b/src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc
@@ -96,11 +96,11 @@ static void log_address_sorting_list(const ServerAddressList& addresses,
for (size_t i = 0; i < addresses.size(); i++) {
char* addr_str;
if (grpc_sockaddr_to_string(&addr_str, &addresses[i].address(), true)) {
- gpr_log(GPR_DEBUG, "c-ares address sorting: %s[%" PRIuPTR "]=%s",
+ gpr_log(GPR_INFO, "c-ares address sorting: %s[%" PRIuPTR "]=%s",
input_output_str, i, addr_str);
gpr_free(addr_str);
} else {
- gpr_log(GPR_DEBUG,
+ gpr_log(GPR_INFO,
"c-ares address sorting: %s[%" PRIuPTR "]=<unprintable>",
input_output_str, i);
}
@@ -209,10 +209,10 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts,
addresses.emplace_back(&addr, addr_len, args);
char output[INET6_ADDRSTRLEN];
ares_inet_ntop(AF_INET6, &addr.sin6_addr, output, INET6_ADDRSTRLEN);
- gpr_log(GPR_DEBUG,
- "c-ares resolver gets a AF_INET6 result: \n"
- " addr: %s\n port: %d\n sin6_scope_id: %d\n",
- output, ntohs(hr->port), addr.sin6_scope_id);
+ GRPC_CARES_TRACE_LOG(
+ "request:%p c-ares resolver gets a AF_INET6 result: \n"
+ " addr: %s\n port: %d\n sin6_scope_id: %d\n",
+ r, output, ntohs(hr->port), addr.sin6_scope_id);
break;
}
case AF_INET: {
@@ -226,10 +226,10 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts,
addresses.emplace_back(&addr, addr_len, args);
char output[INET_ADDRSTRLEN];
ares_inet_ntop(AF_INET, &addr.sin_addr, output, INET_ADDRSTRLEN);
- gpr_log(GPR_DEBUG,
- "c-ares resolver gets a AF_INET result: \n"
- " addr: %s\n port: %d\n",
- output, ntohs(hr->port));
+ GRPC_CARES_TRACE_LOG(
+ "request:%p c-ares resolver gets a AF_INET result: \n"
+ " addr: %s\n port: %d\n",
+ r, output, ntohs(hr->port));
break;
}
}
@@ -252,9 +252,9 @@ static void on_hostbyname_done_locked(void* arg, int status, int timeouts,
static void on_srv_query_done_locked(void* arg, int status, int timeouts,
unsigned char* abuf, int alen) {
grpc_ares_request* r = static_cast<grpc_ares_request*>(arg);
- gpr_log(GPR_DEBUG, "on_query_srv_done_locked");
+ GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked", r);
if (status == ARES_SUCCESS) {
- gpr_log(GPR_DEBUG, "on_query_srv_done_locked ARES_SUCCESS");
+ GRPC_CARES_TRACE_LOG("request:%p on_query_srv_done_locked ARES_SUCCESS", r);
struct ares_srv_reply* reply;
const int parse_status = ares_parse_srv_reply(abuf, alen, &reply);
if (parse_status == ARES_SUCCESS) {
@@ -297,9 +297,9 @@ static const char g_service_config_attribute_prefix[] = "grpc_config=";
static void on_txt_done_locked(void* arg, int status, int timeouts,
unsigned char* buf, int len) {
- gpr_log(GPR_DEBUG, "on_txt_done_locked");
char* error_msg;
grpc_ares_request* r = static_cast<grpc_ares_request*>(arg);
+ GRPC_CARES_TRACE_LOG("request:%p on_txt_done_locked", r);
const size_t prefix_len = sizeof(g_service_config_attribute_prefix) - 1;
struct ares_txt_ext* result = nullptr;
struct ares_txt_ext* reply = nullptr;
@@ -332,7 +332,8 @@ static void on_txt_done_locked(void* arg, int status, int timeouts,
service_config_len += result->length;
}
(*r->service_config_json_out)[service_config_len] = '\0';
- gpr_log(GPR_INFO, "found service config: %s", *r->service_config_json_out);
+ GRPC_CARES_TRACE_LOG("request:%p found service config: %s", r,
+ *r->service_config_json_out);
}
// Clean up.
ares_free_data(reply);
@@ -358,12 +359,6 @@ void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked(
grpc_error* error = GRPC_ERROR_NONE;
grpc_ares_hostbyname_request* hr = nullptr;
ares_channel* channel = nullptr;
- /* TODO(zyc): Enable tracing after #9603 is checked in */
- /* if (grpc_dns_trace) {
- gpr_log(GPR_DEBUG, "resolve_address (blocking): name=%s, default_port=%s",
- name, default_port);
- } */
-
/* parse name, splitting it into host and port parts */
char* host;
char* port;
@@ -388,7 +383,7 @@ void grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked(
channel = grpc_ares_ev_driver_get_channel_locked(r->ev_driver);
// If dns_server is specified, use it.
if (dns_server != nullptr) {
- gpr_log(GPR_INFO, "Using DNS server %s", dns_server);
+ GRPC_CARES_TRACE_LOG("request:%p Using DNS server %s", r, dns_server);
grpc_resolved_address addr;
if (grpc_parse_ipv4_hostport(dns_server, &addr, false /* log_errors */)) {
r->dns_server_addr.family = AF_INET;
@@ -510,6 +505,28 @@ static bool resolve_as_ip_literal_locked(
return out;
}
+static bool target_matches_localhost_inner(const char* name, char** host,
+ char** port) {
+ if (!gpr_split_host_port(name, host, port)) {
+ gpr_log(GPR_ERROR, "Unable to split host and port for name: %s", name);
+ return false;
+ }
+ if (gpr_stricmp(*host, "localhost") == 0) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+static bool target_matches_localhost(const char* name) {
+ char* host = nullptr;
+ char* port = nullptr;
+ bool out = target_matches_localhost_inner(name, &host, &port);
+ gpr_free(host);
+ gpr_free(port);
+ return out;
+}
+
static grpc_ares_request* grpc_dns_lookup_ares_locked_impl(
const char* dns_server, const char* name, const char* default_port,
grpc_pollset_set* interested_parties, grpc_closure* on_done,
@@ -525,6 +542,10 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl(
r->success = false;
r->error = GRPC_ERROR_NONE;
r->pending_queries = 0;
+ GRPC_CARES_TRACE_LOG(
+ "request:%p c-ares grpc_dns_lookup_ares_locked_impl name=%s, "
+ "default_port=%s",
+ r, name, default_port);
// Early out if the target is an ipv4 or ipv6 literal.
if (resolve_as_ip_literal_locked(name, default_port, addrs)) {
GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE);
@@ -536,6 +557,13 @@ static grpc_ares_request* grpc_dns_lookup_ares_locked_impl(
GRPC_CLOSURE_SCHED(on_done, GRPC_ERROR_NONE);
return r;
}
+ // Don't query for SRV and TXT records if the target is "localhost", so
+ // as to cut down on lookups over the network, especially in tests:
+ // https://github.com/grpc/proposal/pull/79
+ if (target_matches_localhost(name)) {
+ check_grpclb = false;
+ r->service_config_json_out = nullptr;
+ }
// Look up name using c-ares lib.
grpc_dns_lookup_ares_continue_after_check_localhost_and_ip_literals_locked(
r, dns_server, name, default_port, interested_parties, check_grpclb,
diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.cc b/src/core/ext/filters/client_channel/resolver_result_parsing.cc
index 22b06db45c..9a0122e8ec 100644
--- a/src/core/ext/filters/client_channel/resolver_result_parsing.cc
+++ b/src/core/ext/filters/client_channel/resolver_result_parsing.cc
@@ -43,16 +43,16 @@ namespace grpc_core {
namespace internal {
ProcessedResolverResult::ProcessedResolverResult(
- const grpc_channel_args* resolver_result, bool parse_retry) {
+ const grpc_channel_args& resolver_result, bool parse_retry) {
ProcessServiceConfig(resolver_result, parse_retry);
// If no LB config was found above, just find the LB policy name then.
if (lb_policy_name_ == nullptr) ProcessLbPolicyName(resolver_result);
}
void ProcessedResolverResult::ProcessServiceConfig(
- const grpc_channel_args* resolver_result, bool parse_retry) {
+ const grpc_channel_args& resolver_result, bool parse_retry) {
const grpc_arg* channel_arg =
- grpc_channel_args_find(resolver_result, GRPC_ARG_SERVICE_CONFIG);
+ grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVICE_CONFIG);
const char* service_config_json = grpc_channel_arg_get_string(channel_arg);
if (service_config_json != nullptr) {
service_config_json_.reset(gpr_strdup(service_config_json));
@@ -60,7 +60,7 @@ void ProcessedResolverResult::ProcessServiceConfig(
if (service_config_ != nullptr) {
if (parse_retry) {
channel_arg =
- grpc_channel_args_find(resolver_result, GRPC_ARG_SERVER_URI);
+ grpc_channel_args_find(&resolver_result, GRPC_ARG_SERVER_URI);
const char* server_uri = grpc_channel_arg_get_string(channel_arg);
GPR_ASSERT(server_uri != nullptr);
grpc_uri* uri = grpc_uri_parse(server_uri, true);
@@ -78,7 +78,7 @@ void ProcessedResolverResult::ProcessServiceConfig(
}
void ProcessedResolverResult::ProcessLbPolicyName(
- const grpc_channel_args* resolver_result) {
+ const grpc_channel_args& resolver_result) {
// Prefer the LB policy name found in the service config. Note that this is
// checking the deprecated loadBalancingPolicy field, rather than the new
// loadBalancingConfig field.
@@ -96,13 +96,13 @@ void ProcessedResolverResult::ProcessLbPolicyName(
// Otherwise, find the LB policy name set by the client API.
if (lb_policy_name_ == nullptr) {
const grpc_arg* channel_arg =
- grpc_channel_args_find(resolver_result, GRPC_ARG_LB_POLICY_NAME);
+ grpc_channel_args_find(&resolver_result, GRPC_ARG_LB_POLICY_NAME);
lb_policy_name_.reset(gpr_strdup(grpc_channel_arg_get_string(channel_arg)));
}
// Special case: If at least one balancer address is present, we use
// the grpclb policy, regardless of what the resolver has returned.
const ServerAddressList* addresses =
- FindServerAddressListChannelArg(resolver_result);
+ FindServerAddressListChannelArg(&resolver_result);
if (addresses != nullptr) {
bool found_balancer_address = false;
for (size_t i = 0; i < addresses->size(); ++i) {
diff --git a/src/core/ext/filters/client_channel/resolver_result_parsing.h b/src/core/ext/filters/client_channel/resolver_result_parsing.h
index f1fb7406bc..98a9d26c46 100644
--- a/src/core/ext/filters/client_channel/resolver_result_parsing.h
+++ b/src/core/ext/filters/client_channel/resolver_result_parsing.h
@@ -36,8 +36,7 @@ namespace internal {
class ClientChannelMethodParams;
// A table mapping from a method name to its method parameters.
-typedef grpc_core::SliceHashTable<
- grpc_core::RefCountedPtr<ClientChannelMethodParams>>
+typedef SliceHashTable<RefCountedPtr<ClientChannelMethodParams>>
ClientChannelMethodParamsTable;
// A container of processed fields from the resolver result. Simplifies the
@@ -47,33 +46,30 @@ class ProcessedResolverResult {
// Processes the resolver result and populates the relative members
// for later consumption. Tries to parse retry parameters only if parse_retry
// is true.
- ProcessedResolverResult(const grpc_channel_args* resolver_result,
+ ProcessedResolverResult(const grpc_channel_args& resolver_result,
bool parse_retry);
// Getters. Any managed object's ownership is transferred.
- grpc_core::UniquePtr<char> service_config_json() {
+ UniquePtr<char> service_config_json() {
return std::move(service_config_json_);
}
- grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data() {
+ RefCountedPtr<ServerRetryThrottleData> retry_throttle_data() {
return std::move(retry_throttle_data_);
}
- grpc_core::RefCountedPtr<ClientChannelMethodParamsTable>
- method_params_table() {
+ RefCountedPtr<ClientChannelMethodParamsTable> method_params_table() {
return std::move(method_params_table_);
}
- grpc_core::UniquePtr<char> lb_policy_name() {
- return std::move(lb_policy_name_);
- }
+ UniquePtr<char> lb_policy_name() { return std::move(lb_policy_name_); }
grpc_json* lb_policy_config() { return lb_policy_config_; }
private:
// Finds the service config; extracts LB config and (maybe) retry throttle
// params from it.
- void ProcessServiceConfig(const grpc_channel_args* resolver_result,
+ void ProcessServiceConfig(const grpc_channel_args& resolver_result,
bool parse_retry);
// Finds the LB policy name (when no LB config was found).
- void ProcessLbPolicyName(const grpc_channel_args* resolver_result);
+ void ProcessLbPolicyName(const grpc_channel_args& resolver_result);
// Parses the service config. Intended to be used by
// ServiceConfig::ParseGlobalParams.
@@ -85,16 +81,16 @@ class ProcessedResolverResult {
void ParseRetryThrottleParamsFromServiceConfig(const grpc_json* field);
// Service config.
- grpc_core::UniquePtr<char> service_config_json_;
- grpc_core::UniquePtr<grpc_core::ServiceConfig> service_config_;
+ UniquePtr<char> service_config_json_;
+ UniquePtr<grpc_core::ServiceConfig> service_config_;
// LB policy.
grpc_json* lb_policy_config_ = nullptr;
- grpc_core::UniquePtr<char> lb_policy_name_;
+ UniquePtr<char> lb_policy_name_;
// Retry throttle data.
char* server_name_ = nullptr;
- grpc_core::RefCountedPtr<ServerRetryThrottleData> retry_throttle_data_;
+ RefCountedPtr<ServerRetryThrottleData> retry_throttle_data_;
// Method params table.
- grpc_core::RefCountedPtr<ClientChannelMethodParamsTable> method_params_table_;
+ RefCountedPtr<ClientChannelMethodParamsTable> method_params_table_;
};
// The parameters of a method.
diff --git a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc
index e73eee4353..9612698e96 100644
--- a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc
+++ b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc
@@ -110,14 +110,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args(
grpc_channel_args* args_with_authority =
grpc_channel_args_copy_and_add(args->args, args_to_add, num_args_to_add);
grpc_uri_destroy(server_uri);
- grpc_channel_security_connector* subchannel_security_connector = nullptr;
// Create the security connector using the credentials and target name.
grpc_channel_args* new_args_from_connector = nullptr;
- const grpc_security_status security_status =
- grpc_channel_credentials_create_security_connector(
- channel_credentials, authority.get(), args_with_authority,
- &subchannel_security_connector, &new_args_from_connector);
- if (security_status != GRPC_SECURITY_OK) {
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ subchannel_security_connector =
+ channel_credentials->create_security_connector(
+ /*call_creds=*/nullptr, authority.get(), args_with_authority,
+ &new_args_from_connector);
+ if (subchannel_security_connector == nullptr) {
gpr_log(GPR_ERROR,
"Failed to create secure subchannel for secure name '%s'",
authority.get());
@@ -125,15 +125,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args(
return nullptr;
}
grpc_arg new_security_connector_arg =
- grpc_security_connector_to_arg(&subchannel_security_connector->base);
+ grpc_security_connector_to_arg(subchannel_security_connector.get());
grpc_channel_args* new_args = grpc_channel_args_copy_and_add(
new_args_from_connector != nullptr ? new_args_from_connector
: args_with_authority,
&new_security_connector_arg, 1);
- GRPC_SECURITY_CONNECTOR_UNREF(&subchannel_security_connector->base,
- "lb_channel_create");
+ subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create");
if (new_args_from_connector != nullptr) {
grpc_channel_args_destroy(new_args_from_connector);
}
diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc
index 6689a17da6..98fdb62070 100644
--- a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc
+++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc
@@ -31,6 +31,7 @@
#include "src/core/ext/transport/chttp2/transport/chttp2_transport.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/surface/api_trace.h"
@@ -40,9 +41,8 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr,
grpc_server_credentials* creds) {
grpc_core::ExecCtx exec_ctx;
grpc_error* err = GRPC_ERROR_NONE;
- grpc_server_security_connector* sc = nullptr;
+ grpc_core::RefCountedPtr<grpc_server_security_connector> sc;
int port_num = 0;
- grpc_security_status status;
grpc_channel_args* args = nullptr;
GRPC_API_TRACE(
"grpc_server_add_secure_http2_port("
@@ -54,30 +54,27 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr,
"No credentials specified for secure server port (creds==NULL)");
goto done;
}
- status = grpc_server_credentials_create_security_connector(creds, &sc);
- if (status != GRPC_SECURITY_OK) {
+ sc = creds->create_security_connector();
+ if (sc == nullptr) {
char* msg;
gpr_asprintf(&msg,
"Unable to create secure server with credentials of type %s.",
- creds->type);
- err = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg),
- GRPC_ERROR_INT_SECURITY_STATUS, status);
+ creds->type());
+ err = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
gpr_free(msg);
goto done;
}
// Create channel args.
grpc_arg args_to_add[2];
args_to_add[0] = grpc_server_credentials_to_arg(creds);
- args_to_add[1] = grpc_security_connector_to_arg(&sc->base);
+ args_to_add[1] = grpc_security_connector_to_arg(sc.get());
args =
grpc_channel_args_copy_and_add(grpc_server_get_channel_args(server),
args_to_add, GPR_ARRAY_SIZE(args_to_add));
// Add server port.
err = grpc_chttp2_server_add_port(server, addr, args, &port_num);
done:
- if (sc != nullptr) {
- GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server");
- }
+ sc.reset(DEBUG_LOCATION, "server");
if (err != GRPC_ERROR_NONE) {
const char* msg = grpc_error_string(err);
diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
index 9b6574b612..7f4627fa77 100644
--- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
+++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
@@ -170,7 +170,12 @@ grpc_chttp2_transport::~grpc_chttp2_transport() {
grpc_slice_buffer_destroy_internal(&outbuf);
grpc_chttp2_hpack_compressor_destroy(&hpack_compressor);
- grpc_core::ContextList::Execute(cl, nullptr, GRPC_ERROR_NONE);
+ grpc_error* error =
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("Transport destroyed");
+ // ContextList::Execute follows semantics of a callback function and does not
+ // take a ref on error
+ grpc_core::ContextList::Execute(cl, nullptr, error);
+ GRPC_ERROR_UNREF(error);
cl = nullptr;
grpc_slice_buffer_destroy_internal(&read_buffer);
diff --git a/src/core/ext/transport/chttp2/transport/context_list.cc b/src/core/ext/transport/chttp2/transport/context_list.cc
index f30d41c332..df09809067 100644
--- a/src/core/ext/transport/chttp2/transport/context_list.cc
+++ b/src/core/ext/transport/chttp2/transport/context_list.cc
@@ -21,31 +21,47 @@
#include "src/core/ext/transport/chttp2/transport/context_list.h"
namespace {
-void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*) = nullptr;
-}
+void (*write_timestamps_callback_g)(void*, grpc_core::Timestamps*,
+ grpc_error* error) = nullptr;
+void* (*get_copied_context_fn_g)(void*) = nullptr;
+} // namespace
namespace grpc_core {
+void ContextList::Append(ContextList** head, grpc_chttp2_stream* s) {
+ if (get_copied_context_fn_g == nullptr ||
+ write_timestamps_callback_g == nullptr) {
+ return;
+ }
+ /* Create a new element in the list and add it at the front */
+ ContextList* elem = grpc_core::New<ContextList>();
+ elem->trace_context_ = get_copied_context_fn_g(s->context);
+ elem->byte_offset_ = s->byte_counter;
+ elem->next_ = *head;
+ *head = elem;
+}
+
void ContextList::Execute(void* arg, grpc_core::Timestamps* ts,
grpc_error* error) {
ContextList* head = static_cast<ContextList*>(arg);
ContextList* to_be_freed;
while (head != nullptr) {
- if (error == GRPC_ERROR_NONE && ts != nullptr) {
- if (write_timestamps_callback_g) {
- ts->byte_offset = static_cast<uint32_t>(head->byte_offset_);
- write_timestamps_callback_g(head->s_->context, ts);
- }
+ if (write_timestamps_callback_g) {
+ ts->byte_offset = static_cast<uint32_t>(head->byte_offset_);
+ write_timestamps_callback_g(head->trace_context_, ts, error);
}
- GRPC_CHTTP2_STREAM_UNREF(static_cast<grpc_chttp2_stream*>(head->s_),
- "timestamp");
to_be_freed = head;
head = head->next_;
grpc_core::Delete(to_be_freed);
}
}
-void grpc_http2_set_write_timestamps_callback(
- void (*fn)(void*, grpc_core::Timestamps*)) {
+void grpc_http2_set_write_timestamps_callback(void (*fn)(void*,
+ grpc_core::Timestamps*,
+ grpc_error* error)) {
write_timestamps_callback_g = fn;
}
+
+void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*)) {
+ get_copied_context_fn_g = fn;
+}
} /* namespace grpc_core */
diff --git a/src/core/ext/transport/chttp2/transport/context_list.h b/src/core/ext/transport/chttp2/transport/context_list.h
index d870107749..5b9d2ab378 100644
--- a/src/core/ext/transport/chttp2/transport/context_list.h
+++ b/src/core/ext/transport/chttp2/transport/context_list.h
@@ -31,42 +31,23 @@ class ContextList {
public:
/* Creates a new element with \a context as the value and appends it to the
* list. */
- static void Append(ContextList** head, grpc_chttp2_stream* s) {
- /* Make sure context is not already present */
- GRPC_CHTTP2_STREAM_REF(s, "timestamp");
-
-#ifndef NDEBUG
- ContextList* ptr = *head;
- while (ptr != nullptr) {
- if (ptr->s_ == s) {
- GPR_ASSERT(
- false &&
- "Trying to append a stream that is already present in the list");
- }
- ptr = ptr->next_;
- }
-#endif
-
- /* Create a new element in the list and add it at the front */
- ContextList* elem = grpc_core::New<ContextList>();
- elem->s_ = s;
- elem->byte_offset_ = s->byte_counter;
- elem->next_ = *head;
- *head = elem;
- }
+ static void Append(ContextList** head, grpc_chttp2_stream* s);
/* Executes a function \a fn with each context in the list and \a ts. It also
- * frees up the entire list after this operation. */
+ * frees up the entire list after this operation. It is intended as a callback
+ * and hence does not take a ref on \a error */
static void Execute(void* arg, grpc_core::Timestamps* ts, grpc_error* error);
private:
- grpc_chttp2_stream* s_ = nullptr;
+ void* trace_context_ = nullptr;
ContextList* next_ = nullptr;
size_t byte_offset_ = 0;
};
-void grpc_http2_set_write_timestamps_callback(
- void (*fn)(void*, grpc_core::Timestamps*));
+void grpc_http2_set_write_timestamps_callback(void (*fn)(void*,
+ grpc_core::Timestamps*,
+ grpc_error* error));
+void grpc_http2_set_fn_get_copied_context(void* (*fn)(void*));
} /* namespace grpc_core */
#endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_CONTEXT_LIST_H */
diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h
index 1ed5d584c7..19f38d7f01 100644
--- a/src/core/lib/gprpp/ref_counted_ptr.h
+++ b/src/core/lib/gprpp/ref_counted_ptr.h
@@ -50,7 +50,7 @@ class RefCountedPtr {
}
template <typename Y>
RefCountedPtr(RefCountedPtr<Y>&& other) {
- value_ = other.value_;
+ value_ = static_cast<T*>(other.value_);
other.value_ = nullptr;
}
@@ -77,7 +77,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (other.value_ != nullptr) other.value_->IncrementRefCount();
- value_ = other.value_;
+ value_ = static_cast<T*>(other.value_);
}
// Copy assignment.
@@ -118,7 +118,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (value_ != nullptr) value_->Unref();
- value_ = value;
+ value_ = static_cast<T*>(value);
}
template <typename Y>
void reset(const DebugLocation& location, const char* reason,
@@ -126,7 +126,7 @@ class RefCountedPtr {
static_assert(std::has_virtual_destructor<T>::value,
"T does not have a virtual dtor");
if (value_ != nullptr) value_->Unref(location, reason);
- value_ = value;
+ value_ = static_cast<T*>(value);
}
// TODO(roth): This method exists solely as a transition mechanism to allow
diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc
index 1c798d368b..fdea7511cc 100644
--- a/src/core/lib/http/httpcli_security_connector.cc
+++ b/src/core/lib/http/httpcli_security_connector.cc
@@ -29,119 +29,125 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker_registry.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
+#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/ssl_utils.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/tsi/ssl_transport_security.h"
-typedef struct {
- grpc_channel_security_connector base;
- tsi_ssl_client_handshaker_factory* handshaker_factory;
- char* secure_peer_name;
-} grpc_httpcli_ssl_channel_security_connector;
-
-static void httpcli_ssl_destroy(grpc_security_connector* sc) {
- grpc_httpcli_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
- if (c->handshaker_factory != nullptr) {
- tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory);
- c->handshaker_factory = nullptr;
+class grpc_httpcli_ssl_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name)
+ : grpc_channel_security_connector(
+ /*url_scheme=*/nullptr,
+ /*channel_creds=*/nullptr,
+ /*request_metadata_creds=*/nullptr),
+ secure_peer_name_(secure_peer_name) {}
+
+ ~grpc_httpcli_ssl_channel_security_connector() override {
+ if (handshaker_factory_ != nullptr) {
+ tsi_ssl_client_handshaker_factory_unref(handshaker_factory_);
+ }
+ if (secure_peer_name_ != nullptr) {
+ gpr_free(secure_peer_name_);
+ }
+ }
+
+ tsi_result InitHandshakerFactory(const char* pem_root_certs,
+ const tsi_ssl_root_certs_store* root_store) {
+ tsi_ssl_client_handshaker_options options;
+ memset(&options, 0, sizeof(options));
+ options.pem_root_certs = pem_root_certs;
+ options.root_store = root_store;
+ return tsi_create_ssl_client_handshaker_factory_with_options(
+ &options, &handshaker_factory_);
}
- if (c->secure_peer_name != nullptr) gpr_free(c->secure_peer_name);
- gpr_free(sc);
-}
-static void httpcli_ssl_add_handshakers(grpc_channel_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- grpc_httpcli_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
- tsi_handshaker* handshaker = nullptr;
- if (c->handshaker_factory != nullptr) {
- tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
- c->handshaker_factory, c->secure_peer_name, &handshaker);
- if (result != TSI_OK) {
- gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
- tsi_result_to_string(result));
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr) override {
+ tsi_handshaker* handshaker = nullptr;
+ if (handshaker_factory_ != nullptr) {
+ tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
+ handshaker_factory_, secure_peer_name_, &handshaker);
+ if (result != TSI_OK) {
+ gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
+ tsi_result_to_string(result));
+ }
}
+ grpc_handshake_manager_add(
+ handshake_mgr, grpc_security_handshaker_create(handshaker, this));
}
- grpc_handshake_manager_add(
- handshake_mgr, grpc_security_handshaker_create(handshaker, &sc->base));
-}
-static void httpcli_ssl_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_httpcli_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc);
- grpc_error* error = GRPC_ERROR_NONE;
-
- /* Check the peer name. */
- if (c->secure_peer_name != nullptr &&
- !tsi_ssl_peer_matches_name(&peer, c->secure_peer_name)) {
- char* msg;
- gpr_asprintf(&msg, "Peer name %s is not in peer certificate",
- c->secure_peer_name);
- error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
- gpr_free(msg);
+ tsi_ssl_client_handshaker_factory* handshaker_factory() const {
+ return handshaker_factory_;
}
- GRPC_CLOSURE_SCHED(on_peer_checked, error);
- tsi_peer_destruct(&peer);
-}
-static int httpcli_ssl_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_httpcli_ssl_channel_security_connector* c1 =
- reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc1);
- grpc_httpcli_ssl_channel_security_connector* c2 =
- reinterpret_cast<grpc_httpcli_ssl_channel_security_connector*>(sc2);
- return strcmp(c1->secure_peer_name, c2->secure_peer_name);
-}
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* /*auth_context*/,
+ grpc_closure* on_peer_checked) override {
+ grpc_error* error = GRPC_ERROR_NONE;
+
+ /* Check the peer name. */
+ if (secure_peer_name_ != nullptr &&
+ !tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) {
+ char* msg;
+ gpr_asprintf(&msg, "Peer name %s is not in peer certificate",
+ secure_peer_name_);
+ error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
+ gpr_free(msg);
+ }
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
+ tsi_peer_destruct(&peer);
+ }
-static grpc_security_connector_vtable httpcli_ssl_vtable = {
- httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp};
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_httpcli_ssl_channel_security_connector*>(
+ other_sc);
+ return strcmp(secure_peer_name_, other->secure_peer_name_);
+ }
-static grpc_security_status httpcli_ssl_channel_security_connector_create(
- const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store,
- const char* secure_peer_name, grpc_channel_security_connector** sc) {
- tsi_result result = TSI_OK;
- grpc_httpcli_ssl_channel_security_connector* c;
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ *error = GRPC_ERROR_NONE;
+ return true;
+ }
- if (secure_peer_name != nullptr && pem_root_certs == nullptr) {
- gpr_log(GPR_ERROR,
- "Cannot assert a secure peer name without a trust root.");
- return GRPC_SECURITY_ERROR;
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
}
- c = static_cast<grpc_httpcli_ssl_channel_security_connector*>(
- gpr_zalloc(sizeof(grpc_httpcli_ssl_channel_security_connector)));
+ const char* secure_peer_name() const { return secure_peer_name_; }
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &httpcli_ssl_vtable;
- if (secure_peer_name != nullptr) {
- c->secure_peer_name = gpr_strdup(secure_peer_name);
+ private:
+ tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr;
+ char* secure_peer_name_;
+};
+
+static grpc_core::RefCountedPtr<grpc_channel_security_connector>
+httpcli_ssl_channel_security_connector_create(
+ const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store,
+ const char* secure_peer_name) {
+ if (secure_peer_name != nullptr && pem_root_certs == nullptr) {
+ gpr_log(GPR_ERROR,
+ "Cannot assert a secure peer name without a trust root.");
+ return nullptr;
}
- tsi_ssl_client_handshaker_options options;
- memset(&options, 0, sizeof(options));
- options.pem_root_certs = pem_root_certs;
- options.root_store = root_store;
- result = tsi_create_ssl_client_handshaker_factory_with_options(
- &options, &c->handshaker_factory);
+ grpc_core::RefCountedPtr<grpc_httpcli_ssl_channel_security_connector> c =
+ grpc_core::MakeRefCounted<grpc_httpcli_ssl_channel_security_connector>(
+ secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name));
+ tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store);
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
- httpcli_ssl_destroy(&c->base.base);
- *sc = nullptr;
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- // We don't actually need a channel credentials object in this case,
- // but we set it to a non-nullptr address so that we don't trigger
- // assertions in grpc_channel_security_connector_cmp().
- c->base.channel_creds = (grpc_channel_credentials*)1;
- c->base.add_handshakers = httpcli_ssl_add_handshakers;
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return c;
}
/* handshaker */
@@ -186,10 +192,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
}
c->func = on_done;
c->arg = arg;
- grpc_channel_security_connector* sc = nullptr;
- GPR_ASSERT(httpcli_ssl_channel_security_connector_create(
- pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK);
- grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base);
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
+ httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store,
+ host);
+ GPR_ASSERT(sc != nullptr);
+ grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get());
grpc_channel_args args = {1, &channel_arg};
c->handshake_mgr = grpc_handshake_manager_create();
grpc_handshakers_add(HANDSHAKER_CLIENT, &args,
@@ -197,7 +204,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host,
grpc_handshake_manager_do_handshake(
c->handshake_mgr, tcp, nullptr /* channel_args */, deadline,
nullptr /* acceptor */, on_handshake_done, c /* user_data */);
- GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "httpcli");
+ sc.reset(DEBUG_LOCATION, "httpcli");
}
const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake};
diff --git a/src/core/lib/http/parser.h b/src/core/lib/http/parser.h
index 1d2e13e831..a8f47c96c8 100644
--- a/src/core/lib/http/parser.h
+++ b/src/core/lib/http/parser.h
@@ -70,13 +70,13 @@ typedef struct grpc_http_request {
/* A response */
typedef struct grpc_http_response {
/* HTTP status code */
- int status;
+ int status = 0;
/* Headers: count and key/values */
- size_t hdr_count;
- grpc_http_header* hdrs;
+ size_t hdr_count = 0;
+ grpc_http_header* hdrs = nullptr;
/* Body: length and contents; contents are NOT null-terminated */
- size_t body_length;
- char* body;
+ size_t body_length = 0;
+ char* body = nullptr;
} grpc_http_response;
typedef struct {
diff --git a/src/core/lib/iomgr/resource_quota.cc b/src/core/lib/iomgr/resource_quota.cc
index 7e4b3c9b2f..61c366098e 100644
--- a/src/core/lib/iomgr/resource_quota.cc
+++ b/src/core/lib/iomgr/resource_quota.cc
@@ -665,6 +665,7 @@ void grpc_resource_quota_unref_internal(grpc_resource_quota* resource_quota) {
GPR_ASSERT(resource_quota->num_threads_allocated == 0);
GRPC_COMBINER_UNREF(resource_quota->combiner, "resource_quota");
gpr_free(resource_quota->name);
+ gpr_mu_destroy(&resource_quota->thread_count_mu);
gpr_free(resource_quota);
}
}
diff --git a/src/core/lib/iomgr/tcp_posix.cc b/src/core/lib/iomgr/tcp_posix.cc
index cfcb190d60..c268c18664 100644
--- a/src/core/lib/iomgr/tcp_posix.cc
+++ b/src/core/lib/iomgr/tcp_posix.cc
@@ -126,6 +126,7 @@ struct grpc_tcp {
int bytes_counter;
bool socket_ts_enabled; /* True if timestamping options are set on the socket
*/
+ bool ts_capable; /* Cache whether we can set timestamping options */
gpr_atm
stop_error_notification; /* Set to 1 if we do not want to be notified on
errors anymore */
@@ -589,7 +590,7 @@ ssize_t tcp_send(int fd, const struct msghdr* msg) {
*/
static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
size_t sending_length,
- ssize_t* sent_length, grpc_error** error);
+ ssize_t* sent_length);
/** The callback function to be invoked when we get an error on the socket. */
static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error);
@@ -597,13 +598,11 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error);
#ifdef GRPC_LINUX_ERRQUEUE
static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
size_t sending_length,
- ssize_t* sent_length,
- grpc_error** error) {
+ ssize_t* sent_length) {
if (!tcp->socket_ts_enabled) {
uint32_t opt = grpc_core::kTimestampingSocketOptions;
if (setsockopt(tcp->fd, SOL_SOCKET, SO_TIMESTAMPING,
static_cast<void*>(&opt), sizeof(opt)) != 0) {
- *error = tcp_annotate_error(GRPC_OS_ERROR(errno, "setsockopt"), tcp);
grpc_slice_buffer_reset_and_unref_internal(tcp->outgoing_buffer);
if (grpc_tcp_trace.enabled()) {
gpr_log(GPR_ERROR, "Failed to set timestamping options on the socket.");
@@ -784,8 +783,7 @@ static void tcp_handle_error(void* arg /* grpc_tcp */, grpc_error* error) {
#else /* GRPC_LINUX_ERRQUEUE */
static bool tcp_write_with_timestamps(grpc_tcp* tcp, struct msghdr* msg,
size_t sending_length,
- ssize_t* sent_length,
- grpc_error** error) {
+ ssize_t* sent_length) {
gpr_log(GPR_ERROR, "Write with timestamps not supported for this platform");
GPR_ASSERT(0);
return false;
@@ -804,7 +802,7 @@ void tcp_shutdown_buffer_list(grpc_tcp* tcp) {
gpr_mu_lock(&tcp->tb_mu);
grpc_core::TracedBuffer::Shutdown(
&tcp->tb_head, tcp->outgoing_buffer_arg,
- GRPC_ERROR_CREATE_FROM_STATIC_STRING("endpoint destroyed"));
+ GRPC_ERROR_CREATE_FROM_STATIC_STRING("TracedBuffer list shutdown"));
gpr_mu_unlock(&tcp->tb_mu);
tcp->outgoing_buffer_arg = nullptr;
}
@@ -820,7 +818,7 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) {
struct msghdr msg;
struct iovec iov[MAX_WRITE_IOVEC];
msg_iovlen_type iov_size;
- ssize_t sent_length;
+ ssize_t sent_length = 0;
size_t sending_length;
size_t trailing;
size_t unwind_slice_idx;
@@ -855,13 +853,19 @@ static bool tcp_flush(grpc_tcp* tcp, grpc_error** error) {
msg.msg_iov = iov;
msg.msg_iovlen = iov_size;
msg.msg_flags = 0;
+ bool tried_sending_message = false;
if (tcp->outgoing_buffer_arg != nullptr) {
- if (!tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length,
- error)) {
+ if (!tcp->ts_capable ||
+ !tcp_write_with_timestamps(tcp, &msg, sending_length, &sent_length)) {
+ /* We could not set socket options to collect Fathom timestamps.
+ * Fallback on writing without timestamps. */
+ tcp->ts_capable = false;
tcp_shutdown_buffer_list(tcp);
- return true; /* something went wrong with timestamps */
+ } else {
+ tried_sending_message = true;
}
- } else {
+ }
+ if (!tried_sending_message) {
msg.msg_control = nullptr;
msg.msg_controllen = 0;
@@ -1117,6 +1121,7 @@ grpc_endpoint* grpc_tcp_create(grpc_fd* em_fd,
tcp->is_first_read = true;
tcp->bytes_counter = -1;
tcp->socket_ts_enabled = false;
+ tcp->ts_capable = true;
tcp->outgoing_buffer_arg = nullptr;
/* paired with unref in grpc_tcp_destroy */
gpr_ref_init(&tcp->refcount, 1);
diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc
index 16f40b4f55..8443ee0695 100644
--- a/src/core/lib/security/context/security_context.cc
+++ b/src/core/lib/security/context/security_context.cc
@@ -23,6 +23,8 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gpr/arena.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/surface/api_trace.h"
#include "src/core/lib/surface/call.h"
@@ -50,13 +52,11 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call,
ctx = static_cast<grpc_client_security_context*>(
grpc_call_context_get(call, GRPC_CONTEXT_SECURITY));
if (ctx == nullptr) {
- ctx = grpc_client_security_context_create(grpc_call_get_arena(call));
- ctx->creds = grpc_call_credentials_ref(creds);
+ ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds);
grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx,
grpc_client_security_context_destroy);
} else {
- grpc_call_credentials_unref(ctx->creds);
- ctx->creds = grpc_call_credentials_ref(creds);
+ ctx->creds = creds != nullptr ? creds->Ref() : nullptr;
}
return GRPC_CALL_OK;
@@ -66,33 +66,45 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) {
void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY);
GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call));
if (sec_ctx == nullptr) return nullptr;
- return grpc_call_is_client(call)
- ? GRPC_AUTH_CONTEXT_REF(
- ((grpc_client_security_context*)sec_ctx)->auth_context,
- "grpc_call_auth_context client")
- : GRPC_AUTH_CONTEXT_REF(
- ((grpc_server_security_context*)sec_ctx)->auth_context,
- "grpc_call_auth_context server");
+ if (grpc_call_is_client(call)) {
+ auto* sc = static_cast<grpc_client_security_context*>(sec_ctx);
+ if (sc->auth_context == nullptr) {
+ return nullptr;
+ } else {
+ return sc->auth_context
+ ->Ref(DEBUG_LOCATION, "grpc_call_auth_context client")
+ .release();
+ }
+ } else {
+ auto* sc = static_cast<grpc_server_security_context*>(sec_ctx);
+ if (sc->auth_context == nullptr) {
+ return nullptr;
+ } else {
+ return sc->auth_context
+ ->Ref(DEBUG_LOCATION, "grpc_call_auth_context server")
+ .release();
+ }
+ }
}
void grpc_auth_context_release(grpc_auth_context* context) {
GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context));
- GRPC_AUTH_CONTEXT_UNREF(context, "grpc_auth_context_unref");
+ if (context == nullptr) return;
+ context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref");
}
/* --- grpc_client_security_context --- */
grpc_client_security_context::~grpc_client_security_context() {
- grpc_call_credentials_unref(creds);
- GRPC_AUTH_CONTEXT_UNREF(auth_context, "client_security_context");
+ auth_context.reset(DEBUG_LOCATION, "client_security_context");
if (extension.instance != nullptr && extension.destroy != nullptr) {
extension.destroy(extension.instance);
}
}
grpc_client_security_context* grpc_client_security_context_create(
- gpr_arena* arena) {
+ gpr_arena* arena, grpc_call_credentials* creds) {
return new (gpr_arena_alloc(arena, sizeof(grpc_client_security_context)))
- grpc_client_security_context();
+ grpc_client_security_context(creds != nullptr ? creds->Ref() : nullptr);
}
void grpc_client_security_context_destroy(void* ctx) {
@@ -104,7 +116,7 @@ void grpc_client_security_context_destroy(void* ctx) {
/* --- grpc_server_security_context --- */
grpc_server_security_context::~grpc_server_security_context() {
- GRPC_AUTH_CONTEXT_UNREF(auth_context, "server_security_context");
+ auth_context.reset(DEBUG_LOCATION, "server_security_context");
if (extension.instance != nullptr && extension.destroy != nullptr) {
extension.destroy(extension.instance);
}
@@ -126,69 +138,11 @@ void grpc_server_security_context_destroy(void* ctx) {
static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr};
-grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained) {
- grpc_auth_context* ctx =
- static_cast<grpc_auth_context*>(gpr_zalloc(sizeof(grpc_auth_context)));
- gpr_ref_init(&ctx->refcount, 1);
- if (chained != nullptr) {
- ctx->chained = GRPC_AUTH_CONTEXT_REF(chained, "chained");
- ctx->peer_identity_property_name =
- ctx->chained->peer_identity_property_name;
- }
- return ctx;
-}
-
-#ifndef NDEBUG
-grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx,
- const char* file, int line,
- const char* reason) {
- if (ctx == nullptr) return nullptr;
- if (grpc_trace_auth_context_refcount.enabled()) {
- gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count);
- gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
- "AUTH_CONTEXT:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val,
- val + 1, reason);
- }
-#else
-grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx) {
- if (ctx == nullptr) return nullptr;
-#endif
- gpr_ref(&ctx->refcount);
- return ctx;
-}
-
-#ifndef NDEBUG
-void grpc_auth_context_unref(grpc_auth_context* ctx, const char* file, int line,
- const char* reason) {
- if (ctx == nullptr) return;
- if (grpc_trace_auth_context_refcount.enabled()) {
- gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count);
- gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
- "AUTH_CONTEXT:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val,
- val - 1, reason);
- }
-#else
-void grpc_auth_context_unref(grpc_auth_context* ctx) {
- if (ctx == nullptr) return;
-#endif
- if (gpr_unref(&ctx->refcount)) {
- size_t i;
- GRPC_AUTH_CONTEXT_UNREF(ctx->chained, "chained");
- if (ctx->properties.array != nullptr) {
- for (i = 0; i < ctx->properties.count; i++) {
- grpc_auth_property_reset(&ctx->properties.array[i]);
- }
- gpr_free(ctx->properties.array);
- }
- gpr_free(ctx);
- }
-}
-
const char* grpc_auth_context_peer_identity_property_name(
const grpc_auth_context* ctx) {
GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1,
(ctx));
- return ctx->peer_identity_property_name;
+ return ctx->peer_identity_property_name();
}
int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx,
@@ -204,13 +158,13 @@ int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx,
name != nullptr ? name : "NULL");
return 0;
}
- ctx->peer_identity_property_name = prop->name;
+ ctx->set_peer_identity_property_name(prop->name);
return 1;
}
int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) {
GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx));
- return ctx->peer_identity_property_name == nullptr ? 0 : 1;
+ return ctx->is_authenticated();
}
grpc_auth_property_iterator grpc_auth_context_property_iterator(
@@ -226,16 +180,17 @@ const grpc_auth_property* grpc_auth_property_iterator_next(
grpc_auth_property_iterator* it) {
GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it));
if (it == nullptr || it->ctx == nullptr) return nullptr;
- while (it->index == it->ctx->properties.count) {
- if (it->ctx->chained == nullptr) return nullptr;
- it->ctx = it->ctx->chained;
+ while (it->index == it->ctx->properties().count) {
+ if (it->ctx->chained() == nullptr) return nullptr;
+ it->ctx = it->ctx->chained();
it->index = 0;
}
if (it->name == nullptr) {
- return &it->ctx->properties.array[it->index++];
+ return &it->ctx->properties().array[it->index++];
} else {
- while (it->index < it->ctx->properties.count) {
- const grpc_auth_property* prop = &it->ctx->properties.array[it->index++];
+ while (it->index < it->ctx->properties().count) {
+ const grpc_auth_property* prop =
+ &it->ctx->properties().array[it->index++];
GPR_ASSERT(prop->name != nullptr);
if (strcmp(it->name, prop->name) == 0) {
return prop;
@@ -262,49 +217,56 @@ grpc_auth_property_iterator grpc_auth_context_peer_identity(
GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx));
if (ctx == nullptr) return empty_iterator;
return grpc_auth_context_find_properties_by_name(
- ctx, ctx->peer_identity_property_name);
+ ctx, ctx->peer_identity_property_name());
}
-static void ensure_auth_context_capacity(grpc_auth_context* ctx) {
- if (ctx->properties.count == ctx->properties.capacity) {
- ctx->properties.capacity =
- GPR_MAX(ctx->properties.capacity + 8, ctx->properties.capacity * 2);
- ctx->properties.array = static_cast<grpc_auth_property*>(
- gpr_realloc(ctx->properties.array,
- ctx->properties.capacity * sizeof(grpc_auth_property)));
+void grpc_auth_context::ensure_capacity() {
+ if (properties_.count == properties_.capacity) {
+ properties_.capacity =
+ GPR_MAX(properties_.capacity + 8, properties_.capacity * 2);
+ properties_.array = static_cast<grpc_auth_property*>(gpr_realloc(
+ properties_.array, properties_.capacity * sizeof(grpc_auth_property)));
}
}
+void grpc_auth_context::add_property(const char* name, const char* value,
+ size_t value_length) {
+ ensure_capacity();
+ grpc_auth_property* prop = &properties_.array[properties_.count++];
+ prop->name = gpr_strdup(name);
+ prop->value = static_cast<char*>(gpr_malloc(value_length + 1));
+ memcpy(prop->value, value, value_length);
+ prop->value[value_length] = '\0';
+ prop->value_length = value_length;
+}
+
void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name,
const char* value, size_t value_length) {
- grpc_auth_property* prop;
GRPC_API_TRACE(
"grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, "
"value_length=%lu)",
6,
(ctx, name, (int)value_length, (int)value_length, value,
(unsigned long)value_length));
- ensure_auth_context_capacity(ctx);
- prop = &ctx->properties.array[ctx->properties.count++];
+ ctx->add_property(name, value, value_length);
+}
+
+void grpc_auth_context::add_cstring_property(const char* name,
+ const char* value) {
+ ensure_capacity();
+ grpc_auth_property* prop = &properties_.array[properties_.count++];
prop->name = gpr_strdup(name);
- prop->value = static_cast<char*>(gpr_malloc(value_length + 1));
- memcpy(prop->value, value, value_length);
- prop->value[value_length] = '\0';
- prop->value_length = value_length;
+ prop->value = gpr_strdup(value);
+ prop->value_length = strlen(value);
}
void grpc_auth_context_add_cstring_property(grpc_auth_context* ctx,
const char* name,
const char* value) {
- grpc_auth_property* prop;
GRPC_API_TRACE(
"grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3,
(ctx, name, value));
- ensure_auth_context_capacity(ctx);
- prop = &ctx->properties.array[ctx->properties.count++];
- prop->name = gpr_strdup(name);
- prop->value = gpr_strdup(value);
- prop->value_length = strlen(value);
+ ctx->add_cstring_property(name, value);
}
void grpc_auth_property_reset(grpc_auth_property* property) {
@@ -314,12 +276,17 @@ void grpc_auth_property_reset(grpc_auth_property* property) {
}
static void auth_context_pointer_arg_destroy(void* p) {
- GRPC_AUTH_CONTEXT_UNREF((grpc_auth_context*)p, "auth_context_pointer_arg");
+ if (p != nullptr) {
+ static_cast<grpc_auth_context*>(p)->Unref(DEBUG_LOCATION,
+ "auth_context_pointer_arg");
+ }
}
static void* auth_context_pointer_arg_copy(void* p) {
- return GRPC_AUTH_CONTEXT_REF((grpc_auth_context*)p,
- "auth_context_pointer_arg");
+ auto* ctx = static_cast<grpc_auth_context*>(p);
+ return ctx == nullptr
+ ? nullptr
+ : ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release();
}
static int auth_context_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); }
diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h
index e45415f63b..b43ee5e62d 100644
--- a/src/core/lib/security/context/security_context.h
+++ b/src/core/lib/security/context/security_context.h
@@ -21,6 +21,8 @@
#include <grpc/support/port_platform.h>
+#include "src/core/lib/gprpp/ref_counted.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/security/credentials/credentials.h"
@@ -40,39 +42,59 @@ struct grpc_auth_property_array {
size_t capacity = 0;
};
-struct grpc_auth_context {
- grpc_auth_context() { gpr_ref_init(&refcount, 0); }
+void grpc_auth_property_reset(grpc_auth_property* property);
- struct grpc_auth_context* chained = nullptr;
- grpc_auth_property_array properties;
- gpr_refcount refcount;
- const char* peer_identity_property_name = nullptr;
- grpc_pollset* pollset = nullptr;
+// This type is forward declared as a C struct and we cannot define it as a
+// class. Otherwise, compiler will complain about type mismatch due to
+// -Wmismatched-tags.
+struct grpc_auth_context
+ : public grpc_core::RefCounted<grpc_auth_context,
+ grpc_core::NonPolymorphicRefCount> {
+ public:
+ explicit grpc_auth_context(
+ grpc_core::RefCountedPtr<grpc_auth_context> chained)
+ : grpc_core::RefCounted<grpc_auth_context,
+ grpc_core::NonPolymorphicRefCount>(
+ &grpc_trace_auth_context_refcount),
+ chained_(std::move(chained)) {
+ if (chained_ != nullptr) {
+ peer_identity_property_name_ = chained_->peer_identity_property_name_;
+ }
+ }
+
+ ~grpc_auth_context() {
+ chained_.reset(DEBUG_LOCATION, "chained");
+ if (properties_.array != nullptr) {
+ for (size_t i = 0; i < properties_.count; i++) {
+ grpc_auth_property_reset(&properties_.array[i]);
+ }
+ gpr_free(properties_.array);
+ }
+ }
+
+ const grpc_auth_context* chained() const { return chained_.get(); }
+ const grpc_auth_property_array& properties() const { return properties_; }
+
+ bool is_authenticated() const {
+ return peer_identity_property_name_ != nullptr;
+ }
+ const char* peer_identity_property_name() const {
+ return peer_identity_property_name_;
+ }
+ void set_peer_identity_property_name(const char* name) {
+ peer_identity_property_name_ = name;
+ }
+
+ void ensure_capacity();
+ void add_property(const char* name, const char* value, size_t value_length);
+ void add_cstring_property(const char* name, const char* value);
+
+ private:
+ grpc_core::RefCountedPtr<grpc_auth_context> chained_;
+ grpc_auth_property_array properties_;
+ const char* peer_identity_property_name_ = nullptr;
};
-/* Creation. */
-grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained);
-
-/* Refcounting. */
-#ifndef NDEBUG
-#define GRPC_AUTH_CONTEXT_REF(p, r) \
- grpc_auth_context_ref((p), __FILE__, __LINE__, (r))
-#define GRPC_AUTH_CONTEXT_UNREF(p, r) \
- grpc_auth_context_unref((p), __FILE__, __LINE__, (r))
-grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy,
- const char* file, int line,
- const char* reason);
-void grpc_auth_context_unref(grpc_auth_context* policy, const char* file,
- int line, const char* reason);
-#else
-#define GRPC_AUTH_CONTEXT_REF(p, r) grpc_auth_context_ref((p))
-#define GRPC_AUTH_CONTEXT_UNREF(p, r) grpc_auth_context_unref((p))
-grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy);
-void grpc_auth_context_unref(grpc_auth_context* policy);
-#endif
-
-void grpc_auth_property_reset(grpc_auth_property* property);
-
/* --- grpc_security_context_extension ---
Extension to the security context that may be set in a filter and accessed
@@ -88,16 +110,18 @@ struct grpc_security_context_extension {
Internal client-side security context. */
struct grpc_client_security_context {
- grpc_client_security_context() = default;
+ explicit grpc_client_security_context(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds)
+ : creds(std::move(creds)) {}
~grpc_client_security_context();
- grpc_call_credentials* creds = nullptr;
- grpc_auth_context* auth_context = nullptr;
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds;
+ grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_security_context_extension extension;
};
grpc_client_security_context* grpc_client_security_context_create(
- gpr_arena* arena);
+ gpr_arena* arena, grpc_call_credentials* creds);
void grpc_client_security_context_destroy(void* ctx);
/* --- grpc_server_security_context ---
@@ -108,7 +132,7 @@ struct grpc_server_security_context {
grpc_server_security_context() = default;
~grpc_server_security_context();
- grpc_auth_context* auth_context = nullptr;
+ grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
grpc_security_context_extension extension;
};
diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc
index 1fbef4ae0c..06546492bc 100644
--- a/src/core/lib/security/credentials/alts/alts_credentials.cc
+++ b/src/core/lib/security/credentials/alts/alts_credentials.cc
@@ -33,40 +33,47 @@
#define GRPC_CREDENTIALS_TYPE_ALTS "Alts"
#define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal:8080"
-static void alts_credentials_destruct(grpc_channel_credentials* creds) {
- grpc_alts_credentials* alts_creds =
- reinterpret_cast<grpc_alts_credentials*>(creds);
- grpc_alts_credentials_options_destroy(alts_creds->options);
- gpr_free(alts_creds->handshaker_service_url);
-}
-
-static void alts_server_credentials_destruct(grpc_server_credentials* creds) {
- grpc_alts_server_credentials* alts_creds =
- reinterpret_cast<grpc_alts_server_credentials*>(creds);
- grpc_alts_credentials_options_destroy(alts_creds->options);
- gpr_free(alts_creds->handshaker_service_url);
+grpc_alts_credentials::grpc_alts_credentials(
+ const grpc_alts_credentials_options* options,
+ const char* handshaker_service_url)
+ : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS),
+ options_(grpc_alts_credentials_options_copy(options)),
+ handshaker_service_url_(handshaker_service_url == nullptr
+ ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
+ : gpr_strdup(handshaker_service_url)) {}
+
+grpc_alts_credentials::~grpc_alts_credentials() {
+ grpc_alts_credentials_options_destroy(options_);
+ gpr_free(handshaker_service_url_);
}
-static grpc_security_status alts_create_security_connector(
- grpc_channel_credentials* creds,
- grpc_call_credentials* request_metadata_creds, const char* target_name,
- const grpc_channel_args* args, grpc_channel_security_connector** sc,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_alts_credentials::create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) {
return grpc_alts_channel_security_connector_create(
- creds, request_metadata_creds, target_name, sc);
+ this->Ref(), std::move(call_creds), target_name);
}
-static grpc_security_status alts_server_create_security_connector(
- grpc_server_credentials* creds, grpc_server_security_connector** sc) {
- return grpc_alts_server_security_connector_create(creds, sc);
+grpc_alts_server_credentials::grpc_alts_server_credentials(
+ const grpc_alts_credentials_options* options,
+ const char* handshaker_service_url)
+ : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS),
+ options_(grpc_alts_credentials_options_copy(options)),
+ handshaker_service_url_(handshaker_service_url == nullptr
+ ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
+ : gpr_strdup(handshaker_service_url)) {}
+
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_alts_server_credentials::create_security_connector() {
+ return grpc_alts_server_security_connector_create(this->Ref());
}
-static const grpc_channel_credentials_vtable alts_credentials_vtable = {
- alts_credentials_destruct, alts_create_security_connector,
- /*duplicate_without_call_credentials=*/nullptr};
-
-static const grpc_server_credentials_vtable alts_server_credentials_vtable = {
- alts_server_credentials_destruct, alts_server_create_security_connector};
+grpc_alts_server_credentials::~grpc_alts_server_credentials() {
+ grpc_alts_credentials_options_destroy(options_);
+ gpr_free(handshaker_service_url_);
+}
grpc_channel_credentials* grpc_alts_credentials_create_customized(
const grpc_alts_credentials_options* options,
@@ -74,17 +81,7 @@ grpc_channel_credentials* grpc_alts_credentials_create_customized(
if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) {
return nullptr;
}
- auto creds = static_cast<grpc_alts_credentials*>(
- gpr_zalloc(sizeof(grpc_alts_credentials)));
- creds->options = grpc_alts_credentials_options_copy(options);
- creds->handshaker_service_url =
- handshaker_service_url == nullptr
- ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
- : gpr_strdup(handshaker_service_url);
- creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS;
- creds->base.vtable = &alts_credentials_vtable;
- gpr_ref_init(&creds->base.refcount, 1);
- return &creds->base;
+ return grpc_core::New<grpc_alts_credentials>(options, handshaker_service_url);
}
grpc_server_credentials* grpc_alts_server_credentials_create_customized(
@@ -93,17 +90,8 @@ grpc_server_credentials* grpc_alts_server_credentials_create_customized(
if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) {
return nullptr;
}
- auto creds = static_cast<grpc_alts_server_credentials*>(
- gpr_zalloc(sizeof(grpc_alts_server_credentials)));
- creds->options = grpc_alts_credentials_options_copy(options);
- creds->handshaker_service_url =
- handshaker_service_url == nullptr
- ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL)
- : gpr_strdup(handshaker_service_url);
- creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS;
- creds->base.vtable = &alts_server_credentials_vtable;
- gpr_ref_init(&creds->base.refcount, 1);
- return &creds->base;
+ return grpc_core::New<grpc_alts_server_credentials>(options,
+ handshaker_service_url);
}
grpc_channel_credentials* grpc_alts_credentials_create(
diff --git a/src/core/lib/security/credentials/alts/alts_credentials.h b/src/core/lib/security/credentials/alts/alts_credentials.h
index 810117f2be..cc6d5222b1 100644
--- a/src/core/lib/security/credentials/alts/alts_credentials.h
+++ b/src/core/lib/security/credentials/alts/alts_credentials.h
@@ -27,18 +27,45 @@
#include "src/core/lib/security/credentials/credentials.h"
/* Main struct for grpc ALTS channel credential. */
-typedef struct grpc_alts_credentials {
- grpc_channel_credentials base;
- grpc_alts_credentials_options* options;
- char* handshaker_service_url;
-} grpc_alts_credentials;
+class grpc_alts_credentials final : public grpc_channel_credentials {
+ public:
+ grpc_alts_credentials(const grpc_alts_credentials_options* options,
+ const char* handshaker_service_url);
+ ~grpc_alts_credentials() override;
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target_name, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override;
+
+ const grpc_alts_credentials_options* options() const { return options_; }
+ grpc_alts_credentials_options* mutable_options() { return options_; }
+ const char* handshaker_service_url() const { return handshaker_service_url_; }
+
+ private:
+ grpc_alts_credentials_options* options_;
+ char* handshaker_service_url_;
+};
/* Main struct for grpc ALTS server credential. */
-typedef struct grpc_alts_server_credentials {
- grpc_server_credentials base;
- grpc_alts_credentials_options* options;
- char* handshaker_service_url;
-} grpc_alts_server_credentials;
+class grpc_alts_server_credentials final : public grpc_server_credentials {
+ public:
+ grpc_alts_server_credentials(const grpc_alts_credentials_options* options,
+ const char* handshaker_service_url);
+ ~grpc_alts_server_credentials() override;
+
+ grpc_core::RefCountedPtr<grpc_server_security_connector>
+ create_security_connector() override;
+
+ const grpc_alts_credentials_options* options() const { return options_; }
+ grpc_alts_credentials_options* mutable_options() { return options_; }
+ const char* handshaker_service_url() const { return handshaker_service_url_; }
+
+ private:
+ grpc_alts_credentials_options* options_;
+ char* handshaker_service_url_;
+};
/**
* This method creates an ALTS channel credential object with customized
diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc
index b8f409260f..586bbed778 100644
--- a/src/core/lib/security/credentials/composite/composite_credentials.cc
+++ b/src/core/lib/security/credentials/composite/composite_credentials.cc
@@ -20,8 +20,10 @@
#include "src/core/lib/security/credentials/composite/composite_credentials.h"
-#include <string.h>
+#include <cstring>
+#include <new>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/surface/api_trace.h"
@@ -31,35 +33,44 @@
/* -- Composite call credentials. -- */
-typedef struct {
+static void composite_call_metadata_cb(void* arg, grpc_error* error);
+
+namespace {
+struct grpc_composite_call_credentials_metadata_context {
+ grpc_composite_call_credentials_metadata_context(
+ grpc_composite_call_credentials* composite_creds,
+ grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata)
+ : composite_creds(composite_creds),
+ pollent(pollent),
+ auth_md_context(auth_md_context),
+ md_array(md_array),
+ on_request_metadata(on_request_metadata) {
+ GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb,
+ this, grpc_schedule_on_exec_ctx);
+ }
+
grpc_composite_call_credentials* composite_creds;
- size_t creds_index;
+ size_t creds_index = 0;
grpc_polling_entity* pollent;
grpc_auth_metadata_context auth_md_context;
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
grpc_closure internal_on_request_metadata;
-} grpc_composite_call_credentials_metadata_context;
-
-static void composite_call_destruct(grpc_call_credentials* creds) {
- grpc_composite_call_credentials* c =
- reinterpret_cast<grpc_composite_call_credentials*>(creds);
- for (size_t i = 0; i < c->inner.num_creds; i++) {
- grpc_call_credentials_unref(c->inner.creds_array[i]);
- }
- gpr_free(c->inner.creds_array);
-}
+};
+} // namespace
static void composite_call_metadata_cb(void* arg, grpc_error* error) {
grpc_composite_call_credentials_metadata_context* ctx =
static_cast<grpc_composite_call_credentials_metadata_context*>(arg);
if (error == GRPC_ERROR_NONE) {
+ const grpc_composite_call_credentials::CallCredentialsList& inner =
+ ctx->composite_creds->inner();
/* See if we need to get some more metadata. */
- if (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
- grpc_call_credentials* inner_creds =
- ctx->composite_creds->inner.creds_array[ctx->creds_index++];
- if (grpc_call_credentials_get_request_metadata(
- inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
+ if (ctx->creds_index < inner.size()) {
+ if (inner[ctx->creds_index++]->get_request_metadata(
+ ctx->pollent, ctx->auth_md_context, ctx->md_array,
&ctx->internal_on_request_metadata, &error)) {
// Synchronous response, so call ourselves recursively.
composite_call_metadata_cb(arg, error);
@@ -73,29 +84,18 @@ static void composite_call_metadata_cb(void* arg, grpc_error* error) {
gpr_free(ctx);
}
-static bool composite_call_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context auth_md_context,
+bool grpc_composite_call_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context,
grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
grpc_error** error) {
- grpc_composite_call_credentials* c =
- reinterpret_cast<grpc_composite_call_credentials*>(creds);
grpc_composite_call_credentials_metadata_context* ctx;
- ctx = static_cast<grpc_composite_call_credentials_metadata_context*>(
- gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context)));
- ctx->composite_creds = c;
- ctx->pollent = pollent;
- ctx->auth_md_context = auth_md_context;
- ctx->md_array = md_array;
- ctx->on_request_metadata = on_request_metadata;
- GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata,
- composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx);
+ ctx = grpc_core::New<grpc_composite_call_credentials_metadata_context>(
+ this, pollent, auth_md_context, md_array, on_request_metadata);
bool synchronous = true;
- while (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
- grpc_call_credentials* inner_creds =
- ctx->composite_creds->inner.creds_array[ctx->creds_index++];
- if (grpc_call_credentials_get_request_metadata(
- inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
+ const CallCredentialsList& inner = ctx->composite_creds->inner();
+ while (ctx->creds_index < inner.size()) {
+ if (inner[ctx->creds_index++]->get_request_metadata(
+ ctx->pollent, ctx->auth_md_context, ctx->md_array,
&ctx->internal_on_request_metadata, error)) {
if (*error != GRPC_ERROR_NONE) break;
} else {
@@ -103,46 +103,66 @@ static bool composite_call_get_request_metadata(
break;
}
}
- if (synchronous) gpr_free(ctx);
+ if (synchronous) grpc_core::Delete(ctx);
return synchronous;
}
-static void composite_call_cancel_get_request_metadata(
- grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
- grpc_composite_call_credentials* c =
- reinterpret_cast<grpc_composite_call_credentials*>(creds);
- for (size_t i = 0; i < c->inner.num_creds; ++i) {
- grpc_call_credentials_cancel_get_request_metadata(
- c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error));
+void grpc_composite_call_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
+ for (size_t i = 0; i < inner_.size(); ++i) {
+ inner_[i]->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error));
}
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable composite_call_credentials_vtable = {
- composite_call_destruct, composite_call_get_request_metadata,
- composite_call_cancel_get_request_metadata};
+static size_t get_creds_array_size(const grpc_call_credentials* creds,
+ bool is_composite) {
+ return is_composite
+ ? static_cast<const grpc_composite_call_credentials*>(creds)
+ ->inner()
+ .size()
+ : 1;
+}
-static grpc_call_credentials_array get_creds_array(
- grpc_call_credentials** creds_addr) {
- grpc_call_credentials_array result;
- grpc_call_credentials* creds = *creds_addr;
- result.creds_array = creds_addr;
- result.num_creds = 1;
- if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
- result = *grpc_composite_call_credentials_get_credentials(creds);
+void grpc_composite_call_credentials::push_to_inner(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds, bool is_composite) {
+ if (!is_composite) {
+ inner_.push_back(std::move(creds));
+ return;
+ }
+ auto composite_creds =
+ static_cast<grpc_composite_call_credentials*>(creds.get());
+ for (size_t i = 0; i < composite_creds->inner().size(); ++i) {
+ inner_.push_back(std::move(composite_creds->inner_[i]));
}
- return result;
+}
+
+grpc_composite_call_credentials::grpc_composite_call_credentials(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds2)
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) {
+ const bool creds1_is_composite =
+ strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0;
+ const bool creds2_is_composite =
+ strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0;
+ const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) +
+ get_creds_array_size(creds2.get(), creds2_is_composite);
+ inner_.reserve(size);
+ push_to_inner(std::move(creds1), creds1_is_composite);
+ push_to_inner(std::move(creds2), creds2_is_composite);
+}
+
+static grpc_core::RefCountedPtr<grpc_call_credentials>
+composite_call_credentials_create(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds2) {
+ return grpc_core::MakeRefCounted<grpc_composite_call_credentials>(
+ std::move(creds1), std::move(creds2));
}
grpc_call_credentials* grpc_composite_call_credentials_create(
grpc_call_credentials* creds1, grpc_call_credentials* creds2,
void* reserved) {
- size_t i;
- size_t creds_array_byte_size;
- grpc_call_credentials_array creds1_array;
- grpc_call_credentials_array creds2_array;
- grpc_composite_call_credentials* c;
GRPC_API_TRACE(
"grpc_composite_call_credentials_create(creds1=%p, creds2=%p, "
"reserved=%p)",
@@ -150,120 +170,40 @@ grpc_call_credentials* grpc_composite_call_credentials_create(
GPR_ASSERT(reserved == nullptr);
GPR_ASSERT(creds1 != nullptr);
GPR_ASSERT(creds2 != nullptr);
- c = static_cast<grpc_composite_call_credentials*>(
- gpr_zalloc(sizeof(grpc_composite_call_credentials)));
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE;
- c->base.vtable = &composite_call_credentials_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- creds1_array = get_creds_array(&creds1);
- creds2_array = get_creds_array(&creds2);
- c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds;
- creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*);
- c->inner.creds_array =
- static_cast<grpc_call_credentials**>(gpr_zalloc(creds_array_byte_size));
- for (i = 0; i < creds1_array.num_creds; i++) {
- grpc_call_credentials* cur_creds = creds1_array.creds_array[i];
- c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds);
- }
- for (i = 0; i < creds2_array.num_creds; i++) {
- grpc_call_credentials* cur_creds = creds2_array.creds_array[i];
- c->inner.creds_array[i + creds1_array.num_creds] =
- grpc_call_credentials_ref(cur_creds);
- }
- return &c->base;
-}
-
-const grpc_call_credentials_array*
-grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) {
- const grpc_composite_call_credentials* c =
- reinterpret_cast<const grpc_composite_call_credentials*>(creds);
- GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
- return &c->inner;
-}
-grpc_call_credentials* grpc_credentials_contains_type(
- grpc_call_credentials* creds, const char* type,
- grpc_call_credentials** composite_creds) {
- size_t i;
- if (strcmp(creds->type, type) == 0) {
- if (composite_creds != nullptr) *composite_creds = nullptr;
- return creds;
- } else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
- const grpc_call_credentials_array* inner_creds_array =
- grpc_composite_call_credentials_get_credentials(creds);
- for (i = 0; i < inner_creds_array->num_creds; i++) {
- if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) {
- if (composite_creds != nullptr) *composite_creds = creds;
- return inner_creds_array->creds_array[i];
- }
- }
- }
- return nullptr;
+ return composite_call_credentials_create(creds1->Ref(), creds2->Ref())
+ .release();
}
/* -- Composite channel credentials. -- */
-static void composite_channel_destruct(grpc_channel_credentials* creds) {
- grpc_composite_channel_credentials* c =
- reinterpret_cast<grpc_composite_channel_credentials*>(creds);
- grpc_channel_credentials_unref(c->inner_creds);
- grpc_call_credentials_unref(c->call_creds);
-}
-
-static grpc_security_status composite_channel_create_security_connector(
- grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_composite_channel_credentials::create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
- grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
- grpc_composite_channel_credentials* c =
- reinterpret_cast<grpc_composite_channel_credentials*>(creds);
- grpc_security_status status = GRPC_SECURITY_ERROR;
-
- GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr &&
- c->inner_creds->vtable != nullptr &&
- c->inner_creds->vtable->create_security_connector != nullptr);
+ grpc_channel_args** new_args) {
+ GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr);
/* If we are passed a call_creds, create a call composite to pass it
downstream. */
if (call_creds != nullptr) {
- grpc_call_credentials* composite_call_creds =
- grpc_composite_call_credentials_create(c->call_creds, call_creds,
- nullptr);
- status = c->inner_creds->vtable->create_security_connector(
- c->inner_creds, composite_call_creds, target, args, sc, new_args);
- grpc_call_credentials_unref(composite_call_creds);
+ return inner_creds_->create_security_connector(
+ composite_call_credentials_create(call_creds_, std::move(call_creds)),
+ target, args, new_args);
} else {
- status = c->inner_creds->vtable->create_security_connector(
- c->inner_creds, c->call_creds, target, args, sc, new_args);
+ return inner_creds_->create_security_connector(call_creds_, target, args,
+ new_args);
}
- return status;
}
-static grpc_channel_credentials*
-composite_channel_duplicate_without_call_credentials(
- grpc_channel_credentials* creds) {
- grpc_composite_channel_credentials* c =
- reinterpret_cast<grpc_composite_channel_credentials*>(creds);
- return grpc_channel_credentials_ref(c->inner_creds);
-}
-
-static grpc_channel_credentials_vtable composite_channel_credentials_vtable = {
- composite_channel_destruct, composite_channel_create_security_connector,
- composite_channel_duplicate_without_call_credentials};
-
grpc_channel_credentials* grpc_composite_channel_credentials_create(
grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds,
void* reserved) {
- grpc_composite_channel_credentials* c =
- static_cast<grpc_composite_channel_credentials*>(gpr_zalloc(sizeof(*c)));
GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr &&
reserved == nullptr);
GRPC_API_TRACE(
"grpc_composite_channel_credentials_create(channel_creds=%p, "
"call_creds=%p, reserved=%p)",
3, (channel_creds, call_creds, reserved));
- c->base.type = channel_creds->type;
- c->base.vtable = &composite_channel_credentials_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- c->inner_creds = grpc_channel_credentials_ref(channel_creds);
- c->call_creds = grpc_call_credentials_ref(call_creds);
- return &c->base;
+ return grpc_core::New<grpc_composite_channel_credentials>(
+ channel_creds->Ref(), call_creds->Ref());
}
diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h
index a952ad57f1..7a1c7d5e42 100644
--- a/src/core/lib/security/credentials/composite/composite_credentials.h
+++ b/src/core/lib/security/credentials/composite/composite_credentials.h
@@ -21,39 +21,75 @@
#include <grpc/support/port_platform.h>
+#include "src/core/lib/gprpp/inlined_vector.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/credentials.h"
-typedef struct {
- grpc_call_credentials** creds_array;
- size_t num_creds;
-} grpc_call_credentials_array;
+/* -- Composite channel credentials. -- */
-const grpc_call_credentials_array*
-grpc_composite_call_credentials_get_credentials(
- grpc_call_credentials* composite_creds);
+class grpc_composite_channel_credentials : public grpc_channel_credentials {
+ public:
+ grpc_composite_channel_credentials(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds)
+ : grpc_channel_credentials(channel_creds->type()),
+ inner_creds_(std::move(channel_creds)),
+ call_creds_(std::move(call_creds)) {}
-/* Returns creds if creds is of the specified type or the inner creds of the
- specified type (if found), if the creds is of type COMPOSITE.
- If composite_creds is not NULL, *composite_creds will point to creds if of
- type COMPOSITE in case of success. */
-grpc_call_credentials* grpc_credentials_contains_type(
- grpc_call_credentials* creds, const char* type,
- grpc_call_credentials** composite_creds);
+ ~grpc_composite_channel_credentials() override = default;
-/* -- Composite channel credentials. -- */
+ grpc_core::RefCountedPtr<grpc_channel_credentials>
+ duplicate_without_call_credentials() override {
+ return inner_creds_;
+ }
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override;
-typedef struct {
- grpc_channel_credentials base;
- grpc_channel_credentials* inner_creds;
- grpc_call_credentials* call_creds;
-} grpc_composite_channel_credentials;
+ const grpc_channel_credentials* inner_creds() const {
+ return inner_creds_.get();
+ }
+ const grpc_call_credentials* call_creds() const { return call_creds_.get(); }
+ grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); }
+
+ private:
+ grpc_core::RefCountedPtr<grpc_channel_credentials> inner_creds_;
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds_;
+};
/* -- Composite call credentials. -- */
-typedef struct {
- grpc_call_credentials base;
- grpc_call_credentials_array inner;
-} grpc_composite_call_credentials;
+class grpc_composite_call_credentials : public grpc_call_credentials {
+ public:
+ using CallCredentialsList =
+ grpc_core::InlinedVector<grpc_core::RefCountedPtr<grpc_call_credentials>,
+ 2>;
+
+ grpc_composite_call_credentials(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds1,
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds2);
+ ~grpc_composite_call_credentials() override = default;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ const CallCredentialsList& inner() const { return inner_; }
+
+ private:
+ void push_to_inner(grpc_core::RefCountedPtr<grpc_call_credentials> creds,
+ bool is_composite);
+
+ CallCredentialsList inner_;
+};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_COMPOSITE_COMPOSITE_CREDENTIALS_H \
*/
diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc
index c43cb440eb..90452d68d6 100644
--- a/src/core/lib/security/credentials/credentials.cc
+++ b/src/core/lib/security/credentials/credentials.cc
@@ -39,120 +39,24 @@
/* -- Common. -- */
-grpc_credentials_metadata_request* grpc_credentials_metadata_request_create(
- grpc_call_credentials* creds) {
- grpc_credentials_metadata_request* r =
- static_cast<grpc_credentials_metadata_request*>(
- gpr_zalloc(sizeof(grpc_credentials_metadata_request)));
- r->creds = grpc_call_credentials_ref(creds);
- return r;
-}
-
-void grpc_credentials_metadata_request_destroy(
- grpc_credentials_metadata_request* r) {
- grpc_call_credentials_unref(r->creds);
- grpc_http_response_destroy(&r->response);
- gpr_free(r);
-}
-
-grpc_channel_credentials* grpc_channel_credentials_ref(
- grpc_channel_credentials* creds) {
- if (creds == nullptr) return nullptr;
- gpr_ref(&creds->refcount);
- return creds;
-}
-
-void grpc_channel_credentials_unref(grpc_channel_credentials* creds) {
- if (creds == nullptr) return;
- if (gpr_unref(&creds->refcount)) {
- if (creds->vtable->destruct != nullptr) {
- creds->vtable->destruct(creds);
- }
- gpr_free(creds);
- }
-}
-
void grpc_channel_credentials_release(grpc_channel_credentials* creds) {
GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
- grpc_channel_credentials_unref(creds);
-}
-
-grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds) {
- if (creds == nullptr) return nullptr;
- gpr_ref(&creds->refcount);
- return creds;
-}
-
-void grpc_call_credentials_unref(grpc_call_credentials* creds) {
- if (creds == nullptr) return;
- if (gpr_unref(&creds->refcount)) {
- if (creds->vtable->destruct != nullptr) {
- creds->vtable->destruct(creds);
- }
- gpr_free(creds);
- }
+ if (creds) creds->Unref();
}
void grpc_call_credentials_release(grpc_call_credentials* creds) {
GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
- grpc_call_credentials_unref(creds);
-}
-
-bool grpc_call_credentials_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata, grpc_error** error) {
- if (creds == nullptr || creds->vtable->get_request_metadata == nullptr) {
- return true;
- }
- return creds->vtable->get_request_metadata(creds, pollent, context, md_array,
- on_request_metadata, error);
-}
-
-void grpc_call_credentials_cancel_get_request_metadata(
- grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
- if (creds == nullptr ||
- creds->vtable->cancel_get_request_metadata == nullptr) {
- return;
- }
- creds->vtable->cancel_get_request_metadata(creds, md_array, error);
-}
-
-grpc_security_status grpc_channel_credentials_create_security_connector(
- grpc_channel_credentials* channel_creds, const char* target,
- const grpc_channel_args* args, grpc_channel_security_connector** sc,
- grpc_channel_args** new_args) {
- *new_args = nullptr;
- if (channel_creds == nullptr) {
- return GRPC_SECURITY_ERROR;
- }
- GPR_ASSERT(channel_creds->vtable->create_security_connector != nullptr);
- return channel_creds->vtable->create_security_connector(
- channel_creds, nullptr, target, args, sc, new_args);
-}
-
-grpc_channel_credentials*
-grpc_channel_credentials_duplicate_without_call_credentials(
- grpc_channel_credentials* channel_creds) {
- if (channel_creds != nullptr && channel_creds->vtable != nullptr &&
- channel_creds->vtable->duplicate_without_call_credentials != nullptr) {
- return channel_creds->vtable->duplicate_without_call_credentials(
- channel_creds);
- } else {
- return grpc_channel_credentials_ref(channel_creds);
- }
+ if (creds) creds->Unref();
}
static void credentials_pointer_arg_destroy(void* p) {
- grpc_channel_credentials_unref(static_cast<grpc_channel_credentials*>(p));
+ static_cast<grpc_channel_credentials*>(p)->Unref();
}
static void* credentials_pointer_arg_copy(void* p) {
- return grpc_channel_credentials_ref(
- static_cast<grpc_channel_credentials*>(p));
+ return static_cast<grpc_channel_credentials*>(p)->Ref().release();
}
static int credentials_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); }
@@ -191,63 +95,35 @@ grpc_channel_credentials* grpc_channel_credentials_find_in_args(
return nullptr;
}
-grpc_server_credentials* grpc_server_credentials_ref(
- grpc_server_credentials* creds) {
- if (creds == nullptr) return nullptr;
- gpr_ref(&creds->refcount);
- return creds;
-}
-
-void grpc_server_credentials_unref(grpc_server_credentials* creds) {
- if (creds == nullptr) return;
- if (gpr_unref(&creds->refcount)) {
- if (creds->vtable->destruct != nullptr) {
- creds->vtable->destruct(creds);
- }
- if (creds->processor.destroy != nullptr &&
- creds->processor.state != nullptr) {
- creds->processor.destroy(creds->processor.state);
- }
- gpr_free(creds);
- }
-}
-
void grpc_server_credentials_release(grpc_server_credentials* creds) {
GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds));
grpc_core::ExecCtx exec_ctx;
- grpc_server_credentials_unref(creds);
+ if (creds) creds->Unref();
}
-grpc_security_status grpc_server_credentials_create_security_connector(
- grpc_server_credentials* creds, grpc_server_security_connector** sc) {
- if (creds == nullptr || creds->vtable->create_security_connector == nullptr) {
- gpr_log(GPR_ERROR, "Server credentials cannot create security context.");
- return GRPC_SECURITY_ERROR;
- }
- return creds->vtable->create_security_connector(creds, sc);
-}
-
-void grpc_server_credentials_set_auth_metadata_processor(
- grpc_server_credentials* creds, grpc_auth_metadata_processor processor) {
+void grpc_server_credentials::set_auth_metadata_processor(
+ const grpc_auth_metadata_processor& processor) {
GRPC_API_TRACE(
"grpc_server_credentials_set_auth_metadata_processor("
"creds=%p, "
"processor=grpc_auth_metadata_processor { process: %p, state: %p })",
- 3, (creds, (void*)(intptr_t)processor.process, processor.state));
- if (creds == nullptr) return;
- if (creds->processor.destroy != nullptr &&
- creds->processor.state != nullptr) {
- creds->processor.destroy(creds->processor.state);
- }
- creds->processor = processor;
+ 3, (this, (void*)(intptr_t)processor.process, processor.state));
+ DestroyProcessor();
+ processor_ = processor;
+}
+
+void grpc_server_credentials_set_auth_metadata_processor(
+ grpc_server_credentials* creds, grpc_auth_metadata_processor processor) {
+ GPR_DEBUG_ASSERT(creds != nullptr);
+ creds->set_auth_metadata_processor(processor);
}
static void server_credentials_pointer_arg_destroy(void* p) {
- grpc_server_credentials_unref(static_cast<grpc_server_credentials*>(p));
+ static_cast<grpc_server_credentials*>(p)->Unref();
}
static void* server_credentials_pointer_arg_copy(void* p) {
- return grpc_server_credentials_ref(static_cast<grpc_server_credentials*>(p));
+ return static_cast<grpc_server_credentials*>(p)->Ref().release();
}
static int server_credentials_pointer_cmp(void* a, void* b) {
diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h
index 3878958b38..4091ef3dfb 100644
--- a/src/core/lib/security/credentials/credentials.h
+++ b/src/core/lib/security/credentials/credentials.h
@@ -26,6 +26,7 @@
#include <grpc/support/sync.h>
#include "src/core/lib/transport/metadata_batch.h"
+#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/http/httpcli.h"
#include "src/core/lib/http/parser.h"
#include "src/core/lib/iomgr/polling_entity.h"
@@ -90,44 +91,46 @@ void grpc_override_well_known_credentials_path_getter(
#define GRPC_ARG_CHANNEL_CREDENTIALS "grpc.channel_credentials"
-typedef struct {
- void (*destruct)(grpc_channel_credentials* c);
-
- grpc_security_status (*create_security_connector)(
- grpc_channel_credentials* c, grpc_call_credentials* call_creds,
+// This type is forward declared as a C struct and we cannot define it as a
+// class. Otherwise, compiler will complain about type mismatch due to
+// -Wmismatched-tags.
+struct grpc_channel_credentials
+ : grpc_core::RefCounted<grpc_channel_credentials> {
+ public:
+ explicit grpc_channel_credentials(const char* type) : type_(type) {}
+ virtual ~grpc_channel_credentials() = default;
+
+ // Creates a security connector for the channel. May also create new channel
+ // args for the channel to be used in place of the passed in const args if
+ // returned non NULL. In that case the caller is responsible for destroying
+ // new_args after channel creation.
+ virtual grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
- grpc_channel_security_connector** sc, grpc_channel_args** new_args);
-
- grpc_channel_credentials* (*duplicate_without_call_credentials)(
- grpc_channel_credentials* c);
-} grpc_channel_credentials_vtable;
-
-struct grpc_channel_credentials {
- const grpc_channel_credentials_vtable* vtable;
- const char* type;
- gpr_refcount refcount;
+ grpc_channel_args** new_args) {
+ // Tell clang-tidy that call_creds cannot be passed as const-ref.
+ call_creds.reset();
+ GRPC_ABSTRACT;
+ }
+
+ // Creates a version of the channel credentials without any attached call
+ // credentials. This can be used in order to open a channel to a non-trusted
+ // gRPC load balancer.
+ virtual grpc_core::RefCountedPtr<grpc_channel_credentials>
+ duplicate_without_call_credentials() {
+ // By default we just increment the refcount.
+ return Ref();
+ }
+
+ const char* type() const { return type_; }
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ private:
+ const char* type_;
};
-grpc_channel_credentials* grpc_channel_credentials_ref(
- grpc_channel_credentials* creds);
-void grpc_channel_credentials_unref(grpc_channel_credentials* creds);
-
-/* Creates a security connector for the channel. May also create new channel
- args for the channel to be used in place of the passed in const args if
- returned non NULL. In that case the caller is responsible for destroying
- new_args after channel creation. */
-grpc_security_status grpc_channel_credentials_create_security_connector(
- grpc_channel_credentials* creds, const char* target,
- const grpc_channel_args* args, grpc_channel_security_connector** sc,
- grpc_channel_args** new_args);
-
-/* Creates a version of the channel credentials without any attached call
- credentials. This can be used in order to open a channel to a non-trusted
- gRPC load balancer. */
-grpc_channel_credentials*
-grpc_channel_credentials_duplicate_without_call_credentials(
- grpc_channel_credentials* creds);
-
/* Util to encapsulate the channel credentials in a channel arg. */
grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials);
@@ -158,44 +161,39 @@ void grpc_credentials_mdelem_array_destroy(grpc_credentials_mdelem_array* list);
/* --- grpc_call_credentials. --- */
-typedef struct {
- void (*destruct)(grpc_call_credentials* c);
- bool (*get_request_metadata)(grpc_call_credentials* c,
- grpc_polling_entity* pollent,
- grpc_auth_metadata_context context,
- grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata,
- grpc_error** error);
- void (*cancel_get_request_metadata)(grpc_call_credentials* c,
- grpc_credentials_mdelem_array* md_array,
- grpc_error* error);
-} grpc_call_credentials_vtable;
-
-struct grpc_call_credentials {
- const grpc_call_credentials_vtable* vtable;
- const char* type;
- gpr_refcount refcount;
+// This type is forward declared as a C struct and we cannot define it as a
+// class. Otherwise, compiler will complain about type mismatch due to
+// -Wmismatched-tags.
+struct grpc_call_credentials
+ : public grpc_core::RefCounted<grpc_call_credentials> {
+ public:
+ explicit grpc_call_credentials(const char* type) : type_(type) {}
+ virtual ~grpc_call_credentials() = default;
+
+ // Returns true if completed synchronously, in which case \a error will
+ // be set to indicate the result. Otherwise, \a on_request_metadata will
+ // be invoked asynchronously when complete. \a md_array will be populated
+ // with the resulting metadata once complete.
+ virtual bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) GRPC_ABSTRACT;
+
+ // Cancels a pending asynchronous operation started by
+ // grpc_call_credentials_get_request_metadata() with the corresponding
+ // value of \a md_array.
+ virtual void cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) GRPC_ABSTRACT;
+
+ const char* type() const { return type_; }
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ private:
+ const char* type_;
};
-grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds);
-void grpc_call_credentials_unref(grpc_call_credentials* creds);
-
-/// Returns true if completed synchronously, in which case \a error will
-/// be set to indicate the result. Otherwise, \a on_request_metadata will
-/// be invoked asynchronously when complete. \a md_array will be populated
-/// with the resulting metadata once complete.
-bool grpc_call_credentials_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata, grpc_error** error);
-
-/// Cancels a pending asynchronous operation started by
-/// grpc_call_credentials_get_request_metadata() with the corresponding
-/// value of \a md_array.
-void grpc_call_credentials_cancel_get_request_metadata(
- grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
- grpc_error* error);
-
/* Metadata-only credentials with the specified key and value where
asynchronicity can be simulated for testing. */
grpc_call_credentials* grpc_md_only_test_credentials_create(
@@ -203,26 +201,40 @@ grpc_call_credentials* grpc_md_only_test_credentials_create(
/* --- grpc_server_credentials. --- */
-typedef struct {
- void (*destruct)(grpc_server_credentials* c);
- grpc_security_status (*create_security_connector)(
- grpc_server_credentials* c, grpc_server_security_connector** sc);
-} grpc_server_credentials_vtable;
-
-struct grpc_server_credentials {
- const grpc_server_credentials_vtable* vtable;
- const char* type;
- gpr_refcount refcount;
- grpc_auth_metadata_processor processor;
-};
+// This type is forward declared as a C struct and we cannot define it as a
+// class. Otherwise, compiler will complain about type mismatch due to
+// -Wmismatched-tags.
+struct grpc_server_credentials
+ : public grpc_core::RefCounted<grpc_server_credentials> {
+ public:
+ explicit grpc_server_credentials(const char* type) : type_(type) {}
-grpc_security_status grpc_server_credentials_create_security_connector(
- grpc_server_credentials* creds, grpc_server_security_connector** sc);
+ virtual ~grpc_server_credentials() { DestroyProcessor(); }
-grpc_server_credentials* grpc_server_credentials_ref(
- grpc_server_credentials* creds);
+ virtual grpc_core::RefCountedPtr<grpc_server_security_connector>
+ create_security_connector() GRPC_ABSTRACT;
-void grpc_server_credentials_unref(grpc_server_credentials* creds);
+ const char* type() const { return type_; }
+
+ const grpc_auth_metadata_processor& auth_metadata_processor() const {
+ return processor_;
+ }
+ void set_auth_metadata_processor(
+ const grpc_auth_metadata_processor& processor);
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ private:
+ void DestroyProcessor() {
+ if (processor_.destroy != nullptr && processor_.state != nullptr) {
+ processor_.destroy(processor_.state);
+ }
+ }
+
+ const char* type_;
+ grpc_auth_metadata_processor processor_ =
+ grpc_auth_metadata_processor(); // Zero-initialize the C struct.
+};
#define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials"
@@ -233,15 +245,27 @@ grpc_server_credentials* grpc_find_server_credentials_in_args(
/* -- Credentials Metadata Request. -- */
-typedef struct {
- grpc_call_credentials* creds;
+struct grpc_credentials_metadata_request {
+ explicit grpc_credentials_metadata_request(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds)
+ : creds(std::move(creds)) {}
+ ~grpc_credentials_metadata_request() {
+ grpc_http_response_destroy(&response);
+ }
+
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_http_response response;
-} grpc_credentials_metadata_request;
+};
-grpc_credentials_metadata_request* grpc_credentials_metadata_request_create(
- grpc_call_credentials* creds);
+inline grpc_credentials_metadata_request*
+grpc_credentials_metadata_request_create(
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds) {
+ return grpc_core::New<grpc_credentials_metadata_request>(std::move(creds));
+}
-void grpc_credentials_metadata_request_destroy(
- grpc_credentials_metadata_request* r);
+inline void grpc_credentials_metadata_request_destroy(
+ grpc_credentials_metadata_request* r) {
+ grpc_core::Delete(r);
+}
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_CREDENTIALS_H */
diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc
index d3e0e8c816..337dd7679f 100644
--- a/src/core/lib/security/credentials/fake/fake_credentials.cc
+++ b/src/core/lib/security/credentials/fake/fake_credentials.cc
@@ -33,49 +33,45 @@
/* -- Fake transport security credentials. -- */
-static grpc_security_status fake_transport_security_create_security_connector(
- grpc_channel_credentials* c, grpc_call_credentials* call_creds,
- const char* target, const grpc_channel_args* args,
- grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
- *sc =
- grpc_fake_channel_security_connector_create(c, call_creds, target, args);
- return GRPC_SECURITY_OK;
-}
-
-static grpc_security_status
-fake_transport_security_server_create_security_connector(
- grpc_server_credentials* c, grpc_server_security_connector** sc) {
- *sc = grpc_fake_server_security_connector_create(c);
- return GRPC_SECURITY_OK;
-}
+namespace {
+class grpc_fake_channel_credentials final : public grpc_channel_credentials {
+ public:
+ grpc_fake_channel_credentials()
+ : grpc_channel_credentials(
+ GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {}
+ ~grpc_fake_channel_credentials() override = default;
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override {
+ return grpc_fake_channel_security_connector_create(
+ this->Ref(), std::move(call_creds), target, args);
+ }
+};
+
+class grpc_fake_server_credentials final : public grpc_server_credentials {
+ public:
+ grpc_fake_server_credentials()
+ : grpc_server_credentials(
+ GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {}
+ ~grpc_fake_server_credentials() override = default;
+
+ grpc_core::RefCountedPtr<grpc_server_security_connector>
+ create_security_connector() override {
+ return grpc_fake_server_security_connector_create(this->Ref());
+ }
+};
+} // namespace
-static grpc_channel_credentials_vtable
- fake_transport_security_credentials_vtable = {
- nullptr, fake_transport_security_create_security_connector, nullptr};
-
-static grpc_server_credentials_vtable
- fake_transport_security_server_credentials_vtable = {
- nullptr, fake_transport_security_server_create_security_connector};
-
-grpc_channel_credentials* grpc_fake_transport_security_credentials_create(
- void) {
- grpc_channel_credentials* c = static_cast<grpc_channel_credentials*>(
- gpr_zalloc(sizeof(grpc_channel_credentials)));
- c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY;
- c->vtable = &fake_transport_security_credentials_vtable;
- gpr_ref_init(&c->refcount, 1);
- return c;
+grpc_channel_credentials* grpc_fake_transport_security_credentials_create() {
+ return grpc_core::New<grpc_fake_channel_credentials>();
}
-grpc_server_credentials* grpc_fake_transport_security_server_credentials_create(
- void) {
- grpc_server_credentials* c = static_cast<grpc_server_credentials*>(
- gpr_malloc(sizeof(grpc_server_credentials)));
- memset(c, 0, sizeof(grpc_server_credentials));
- c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY;
- gpr_ref_init(&c->refcount, 1);
- c->vtable = &fake_transport_security_server_credentials_vtable;
- return c;
+grpc_server_credentials*
+grpc_fake_transport_security_server_credentials_create() {
+ return grpc_core::New<grpc_fake_server_credentials>();
}
grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) {
@@ -92,46 +88,25 @@ const char* grpc_fake_transport_get_expected_targets(
/* -- Metadata-only test credentials. -- */
-static void md_only_test_destruct(grpc_call_credentials* creds) {
- grpc_md_only_test_credentials* c =
- reinterpret_cast<grpc_md_only_test_credentials*>(creds);
- GRPC_MDELEM_UNREF(c->md);
-}
-
-static bool md_only_test_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata, grpc_error** error) {
- grpc_md_only_test_credentials* c =
- reinterpret_cast<grpc_md_only_test_credentials*>(creds);
- grpc_credentials_mdelem_array_add(md_array, c->md);
- if (c->is_async) {
+bool grpc_md_only_test_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
+ grpc_credentials_mdelem_array_add(md_array, md_);
+ if (is_async_) {
GRPC_CLOSURE_SCHED(on_request_metadata, GRPC_ERROR_NONE);
return false;
}
return true;
}
-static void md_only_test_cancel_get_request_metadata(
- grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
+void grpc_md_only_test_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable md_only_test_vtable = {
- md_only_test_destruct, md_only_test_get_request_metadata,
- md_only_test_cancel_get_request_metadata};
-
grpc_call_credentials* grpc_md_only_test_credentials_create(
const char* md_key, const char* md_value, bool is_async) {
- grpc_md_only_test_credentials* c =
- static_cast<grpc_md_only_test_credentials*>(
- gpr_zalloc(sizeof(grpc_md_only_test_credentials)));
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
- c->base.vtable = &md_only_test_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- c->md = grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key),
- grpc_slice_from_copied_string(md_value));
- c->is_async = is_async;
- return &c->base;
+ return grpc_core::New<grpc_md_only_test_credentials>(md_key, md_value,
+ is_async);
}
diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h
index e89e6e24cc..b7f6a1909f 100644
--- a/src/core/lib/security/credentials/fake/fake_credentials.h
+++ b/src/core/lib/security/credentials/fake/fake_credentials.h
@@ -55,10 +55,28 @@ const char* grpc_fake_transport_get_expected_targets(
/* -- Metadata-only Test credentials. -- */
-typedef struct {
- grpc_call_credentials base;
- grpc_mdelem md;
- bool is_async;
-} grpc_md_only_test_credentials;
+class grpc_md_only_test_credentials : public grpc_call_credentials {
+ public:
+ grpc_md_only_test_credentials(const char* md_key, const char* md_value,
+ bool is_async)
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2),
+ md_(grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key),
+ grpc_slice_from_copied_string(md_value))),
+ is_async_(is_async) {}
+ ~grpc_md_only_test_credentials() override { GRPC_MDELEM_UNREF(md_); }
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ private:
+ grpc_mdelem md_;
+ bool is_async_;
+};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_FAKE_FAKE_CREDENTIALS_H */
diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc
index 0674540d01..a86a17d586 100644
--- a/src/core/lib/security/credentials/google_default/google_default_credentials.cc
+++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc
@@ -30,6 +30,7 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/http/httpcli.h"
#include "src/core/lib/http/parser.h"
#include "src/core/lib/iomgr/load_file.h"
@@ -72,20 +73,11 @@ typedef struct {
grpc_http_response response;
} metadata_server_detector;
-static void google_default_credentials_destruct(
- grpc_channel_credentials* creds) {
- grpc_google_default_channel_credentials* c =
- reinterpret_cast<grpc_google_default_channel_credentials*>(creds);
- grpc_channel_credentials_unref(c->alts_creds);
- grpc_channel_credentials_unref(c->ssl_creds);
-}
-
-static grpc_security_status google_default_create_security_connector(
- grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_google_default_channel_credentials::create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
- grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
- grpc_google_default_channel_credentials* c =
- reinterpret_cast<grpc_google_default_channel_credentials*>(creds);
+ grpc_channel_args** new_args) {
bool is_grpclb_load_balancer = grpc_channel_arg_get_bool(
grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER),
false);
@@ -95,22 +87,22 @@ static grpc_security_status google_default_create_security_connector(
false);
bool use_alts =
is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer;
- grpc_security_status status = GRPC_SECURITY_ERROR;
/* Return failure if ALTS is selected but not running on GCE. */
if (use_alts && !g_is_on_gce) {
gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE.");
- goto end;
+ return nullptr;
}
- status = use_alts ? c->alts_creds->vtable->create_security_connector(
- c->alts_creds, call_creds, target, args, sc, new_args)
- : c->ssl_creds->vtable->create_security_connector(
- c->ssl_creds, call_creds, target, args, sc, new_args);
-/* grpclb-specific channel args are removed from the channel args set
- * to ensure backends and fallback adresses will have the same set of channel
- * args. By doing that, it guarantees the connections to backends will not be
- * torn down and re-connected when switching in and out of fallback mode.
- */
-end:
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
+ use_alts ? alts_creds_->create_security_connector(call_creds, target,
+ args, new_args)
+ : ssl_creds_->create_security_connector(call_creds, target, args,
+ new_args);
+ /* grpclb-specific channel args are removed from the channel args set
+ * to ensure backends and fallback adresses will have the same set of channel
+ * args. By doing that, it guarantees the connections to backends will not be
+ * torn down and re-connected when switching in and out of fallback mode.
+ */
if (use_alts) {
static const char* args_to_remove[] = {
GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER,
@@ -119,13 +111,9 @@ end:
*new_args = grpc_channel_args_copy_and_add_and_remove(
args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0);
}
- return status;
+ return sc;
}
-static grpc_channel_credentials_vtable google_default_credentials_vtable = {
- google_default_credentials_destruct,
- google_default_create_security_connector, nullptr};
-
static void on_metadata_server_detection_http_response(void* user_data,
grpc_error* error) {
metadata_server_detector* detector =
@@ -215,11 +203,11 @@ static int is_metadata_server_reachable() {
/* Takes ownership of creds_path if not NULL. */
static grpc_error* create_default_creds_from_path(
- char* creds_path, grpc_call_credentials** creds) {
+ char* creds_path, grpc_core::RefCountedPtr<grpc_call_credentials>* creds) {
grpc_json* json = nullptr;
grpc_auth_json_key key;
grpc_auth_refresh_token token;
- grpc_call_credentials* result = nullptr;
+ grpc_core::RefCountedPtr<grpc_call_credentials> result;
grpc_slice creds_data = grpc_empty_slice();
grpc_error* error = GRPC_ERROR_NONE;
if (creds_path == nullptr) {
@@ -276,9 +264,9 @@ end:
return error;
}
-grpc_channel_credentials* grpc_google_default_credentials_create(void) {
+grpc_channel_credentials* grpc_google_default_credentials_create() {
grpc_channel_credentials* result = nullptr;
- grpc_call_credentials* call_creds = nullptr;
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds;
grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Failed to create Google credentials");
grpc_error* err;
@@ -316,7 +304,8 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) {
gpr_mu_unlock(&g_state_mu);
if (g_metadata_server_available) {
- call_creds = grpc_google_compute_engine_credentials_create(nullptr);
+ call_creds = grpc_core::RefCountedPtr<grpc_call_credentials>(
+ grpc_google_compute_engine_credentials_create(nullptr));
if (call_creds == nullptr) {
error = grpc_error_add_child(
error, GRPC_ERROR_CREATE_FROM_STATIC_STRING(
@@ -327,23 +316,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) {
end:
if (call_creds != nullptr) {
/* Create google default credentials. */
- auto creds = static_cast<grpc_google_default_channel_credentials*>(
- gpr_zalloc(sizeof(grpc_google_default_channel_credentials)));
- creds->base.vtable = &google_default_credentials_vtable;
- creds->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT;
- gpr_ref_init(&creds->base.refcount, 1);
- creds->ssl_creds =
+ grpc_channel_credentials* ssl_creds =
grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr);
- GPR_ASSERT(creds->ssl_creds != nullptr);
+ GPR_ASSERT(ssl_creds != nullptr);
grpc_alts_credentials_options* options =
grpc_alts_credentials_client_options_create();
- creds->alts_creds = grpc_alts_credentials_create(options);
+ grpc_channel_credentials* alts_creds =
+ grpc_alts_credentials_create(options);
grpc_alts_credentials_options_destroy(options);
- result = grpc_composite_channel_credentials_create(&creds->base, call_creds,
- nullptr);
+ auto creds =
+ grpc_core::MakeRefCounted<grpc_google_default_channel_credentials>(
+ alts_creds != nullptr ? alts_creds->Ref() : nullptr,
+ ssl_creds != nullptr ? ssl_creds->Ref() : nullptr);
+ if (ssl_creds) ssl_creds->Unref();
+ if (alts_creds) alts_creds->Unref();
+ result = grpc_composite_channel_credentials_create(
+ creds.get(), call_creds.get(), nullptr);
GPR_ASSERT(result != nullptr);
- grpc_channel_credentials_unref(&creds->base);
- grpc_call_credentials_unref(call_creds);
} else {
gpr_log(GPR_ERROR, "Could not create google default credentials: %s",
grpc_error_string(error));
diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.h b/src/core/lib/security/credentials/google_default/google_default_credentials.h
index b9e2efb04f..bf00f7285a 100644
--- a/src/core/lib/security/credentials/google_default/google_default_credentials.h
+++ b/src/core/lib/security/credentials/google_default/google_default_credentials.h
@@ -21,6 +21,7 @@
#include <grpc/support/port_platform.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/credentials.h"
#define GRPC_GOOGLE_CLOUD_SDK_CONFIG_DIRECTORY "gcloud"
@@ -39,11 +40,33 @@
"/" GRPC_GOOGLE_WELL_KNOWN_CREDENTIALS_FILE
#endif
-typedef struct {
- grpc_channel_credentials base;
- grpc_channel_credentials* alts_creds;
- grpc_channel_credentials* ssl_creds;
-} grpc_google_default_channel_credentials;
+class grpc_google_default_channel_credentials
+ : public grpc_channel_credentials {
+ public:
+ grpc_google_default_channel_credentials(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds,
+ grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds)
+ : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT),
+ alts_creds_(std::move(alts_creds)),
+ ssl_creds_(std::move(ssl_creds)) {}
+
+ ~grpc_google_default_channel_credentials() override = default;
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override;
+
+ const grpc_channel_credentials* alts_creds() const {
+ return alts_creds_.get();
+ }
+ const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); }
+
+ private:
+ grpc_core::RefCountedPtr<grpc_channel_credentials> alts_creds_;
+ grpc_core::RefCountedPtr<grpc_channel_credentials> ssl_creds_;
+};
namespace grpc_core {
namespace internal {
diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc
index 5d92fa88c4..5cd561f676 100644
--- a/src/core/lib/security/credentials/iam/iam_credentials.cc
+++ b/src/core/lib/security/credentials/iam/iam_credentials.cc
@@ -22,6 +22,7 @@
#include <string.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/surface/api_trace.h"
#include <grpc/support/alloc.h>
@@ -29,32 +30,37 @@
#include <grpc/support/string_util.h>
#include <grpc/support/sync.h>
-static void iam_destruct(grpc_call_credentials* creds) {
- grpc_google_iam_credentials* c =
- reinterpret_cast<grpc_google_iam_credentials*>(creds);
- grpc_credentials_mdelem_array_destroy(&c->md_array);
+grpc_google_iam_credentials::~grpc_google_iam_credentials() {
+ grpc_credentials_mdelem_array_destroy(&md_array_);
}
-static bool iam_get_request_metadata(grpc_call_credentials* creds,
- grpc_polling_entity* pollent,
- grpc_auth_metadata_context context,
- grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata,
- grpc_error** error) {
- grpc_google_iam_credentials* c =
- reinterpret_cast<grpc_google_iam_credentials*>(creds);
- grpc_credentials_mdelem_array_append(md_array, &c->md_array);
+bool grpc_google_iam_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
+ grpc_credentials_mdelem_array_append(md_array, &md_array_);
return true;
}
-static void iam_cancel_get_request_metadata(
- grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
+void grpc_google_iam_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable iam_vtable = {
- iam_destruct, iam_get_request_metadata, iam_cancel_get_request_metadata};
+grpc_google_iam_credentials::grpc_google_iam_credentials(
+ const char* token, const char* authority_selector)
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM) {
+ grpc_mdelem md = grpc_mdelem_from_slices(
+ grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY),
+ grpc_slice_from_copied_string(token));
+ grpc_credentials_mdelem_array_add(&md_array_, md);
+ GRPC_MDELEM_UNREF(md);
+ md = grpc_mdelem_from_slices(
+ grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY),
+ grpc_slice_from_copied_string(authority_selector));
+ grpc_credentials_mdelem_array_add(&md_array_, md);
+ GRPC_MDELEM_UNREF(md);
+}
grpc_call_credentials* grpc_google_iam_credentials_create(
const char* token, const char* authority_selector, void* reserved) {
@@ -66,21 +72,7 @@ grpc_call_credentials* grpc_google_iam_credentials_create(
GPR_ASSERT(reserved == nullptr);
GPR_ASSERT(token != nullptr);
GPR_ASSERT(authority_selector != nullptr);
- grpc_google_iam_credentials* c =
- static_cast<grpc_google_iam_credentials*>(gpr_zalloc(sizeof(*c)));
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_IAM;
- c->base.vtable = &iam_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- grpc_mdelem md = grpc_mdelem_from_slices(
- grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY),
- grpc_slice_from_copied_string(token));
- grpc_credentials_mdelem_array_add(&c->md_array, md);
- GRPC_MDELEM_UNREF(md);
- md = grpc_mdelem_from_slices(
- grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY),
- grpc_slice_from_copied_string(authority_selector));
- grpc_credentials_mdelem_array_add(&c->md_array, md);
- GRPC_MDELEM_UNREF(md);
-
- return &c->base;
+ return grpc_core::MakeRefCounted<grpc_google_iam_credentials>(
+ token, authority_selector)
+ .release();
}
diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h
index a45710fe0f..36f5ee8930 100644
--- a/src/core/lib/security/credentials/iam/iam_credentials.h
+++ b/src/core/lib/security/credentials/iam/iam_credentials.h
@@ -23,9 +23,23 @@
#include "src/core/lib/security/credentials/credentials.h"
-typedef struct {
- grpc_call_credentials base;
- grpc_credentials_mdelem_array md_array;
-} grpc_google_iam_credentials;
+class grpc_google_iam_credentials : public grpc_call_credentials {
+ public:
+ grpc_google_iam_credentials(const char* token,
+ const char* authority_selector);
+ ~grpc_google_iam_credentials() override;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ private:
+ grpc_credentials_mdelem_array md_array_;
+};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_IAM_IAM_CREDENTIALS_H */
diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc
index 05c08a68b0..f2591a1ea5 100644
--- a/src/core/lib/security/credentials/jwt/jwt_credentials.cc
+++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc
@@ -23,6 +23,8 @@
#include <inttypes.h>
#include <string.h>
+#include "src/core/lib/gprpp/ref_counted.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/surface/api_trace.h"
#include <grpc/support/alloc.h>
@@ -30,71 +32,66 @@
#include <grpc/support/string_util.h>
#include <grpc/support/sync.h>
-static void jwt_reset_cache(grpc_service_account_jwt_access_credentials* c) {
- GRPC_MDELEM_UNREF(c->cached.jwt_md);
- c->cached.jwt_md = GRPC_MDNULL;
- if (c->cached.service_url != nullptr) {
- gpr_free(c->cached.service_url);
- c->cached.service_url = nullptr;
+void grpc_service_account_jwt_access_credentials::reset_cache() {
+ GRPC_MDELEM_UNREF(cached_.jwt_md);
+ cached_.jwt_md = GRPC_MDNULL;
+ if (cached_.service_url != nullptr) {
+ gpr_free(cached_.service_url);
+ cached_.service_url = nullptr;
}
- c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME);
+ cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME);
}
-static void jwt_destruct(grpc_call_credentials* creds) {
- grpc_service_account_jwt_access_credentials* c =
- reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds);
- grpc_auth_json_key_destruct(&c->key);
- jwt_reset_cache(c);
- gpr_mu_destroy(&c->cache_mu);
+grpc_service_account_jwt_access_credentials::
+ ~grpc_service_account_jwt_access_credentials() {
+ grpc_auth_json_key_destruct(&key_);
+ reset_cache();
+ gpr_mu_destroy(&cache_mu_);
}
-static bool jwt_get_request_metadata(grpc_call_credentials* creds,
- grpc_polling_entity* pollent,
- grpc_auth_metadata_context context,
- grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata,
- grpc_error** error) {
- grpc_service_account_jwt_access_credentials* c =
- reinterpret_cast<grpc_service_account_jwt_access_credentials*>(creds);
+bool grpc_service_account_jwt_access_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
gpr_timespec refresh_threshold = gpr_time_from_seconds(
GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN);
/* See if we can return a cached jwt. */
grpc_mdelem jwt_md = GRPC_MDNULL;
{
- gpr_mu_lock(&c->cache_mu);
- if (c->cached.service_url != nullptr &&
- strcmp(c->cached.service_url, context.service_url) == 0 &&
- !GRPC_MDISNULL(c->cached.jwt_md) &&
- (gpr_time_cmp(gpr_time_sub(c->cached.jwt_expiration,
- gpr_now(GPR_CLOCK_REALTIME)),
- refresh_threshold) > 0)) {
- jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md);
+ gpr_mu_lock(&cache_mu_);
+ if (cached_.service_url != nullptr &&
+ strcmp(cached_.service_url, context.service_url) == 0 &&
+ !GRPC_MDISNULL(cached_.jwt_md) &&
+ (gpr_time_cmp(
+ gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)),
+ refresh_threshold) > 0)) {
+ jwt_md = GRPC_MDELEM_REF(cached_.jwt_md);
}
- gpr_mu_unlock(&c->cache_mu);
+ gpr_mu_unlock(&cache_mu_);
}
if (GRPC_MDISNULL(jwt_md)) {
char* jwt = nullptr;
/* Generate a new jwt. */
- gpr_mu_lock(&c->cache_mu);
- jwt_reset_cache(c);
- jwt = grpc_jwt_encode_and_sign(&c->key, context.service_url,
- c->jwt_lifetime, nullptr);
+ gpr_mu_lock(&cache_mu_);
+ reset_cache();
+ jwt = grpc_jwt_encode_and_sign(&key_, context.service_url, jwt_lifetime_,
+ nullptr);
if (jwt != nullptr) {
char* md_value;
gpr_asprintf(&md_value, "Bearer %s", jwt);
gpr_free(jwt);
- c->cached.jwt_expiration =
- gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c->jwt_lifetime);
- c->cached.service_url = gpr_strdup(context.service_url);
- c->cached.jwt_md = grpc_mdelem_from_slices(
+ cached_.jwt_expiration =
+ gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_);
+ cached_.service_url = gpr_strdup(context.service_url);
+ cached_.jwt_md = grpc_mdelem_from_slices(
grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
grpc_slice_from_copied_string(md_value));
gpr_free(md_value);
- jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md);
+ jwt_md = GRPC_MDELEM_REF(cached_.jwt_md);
}
- gpr_mu_unlock(&c->cache_mu);
+ gpr_mu_unlock(&cache_mu_);
}
if (!GRPC_MDISNULL(jwt_md)) {
@@ -106,29 +103,15 @@ static bool jwt_get_request_metadata(grpc_call_credentials* creds,
return true;
}
-static void jwt_cancel_get_request_metadata(
- grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
+void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable jwt_vtable = {
- jwt_destruct, jwt_get_request_metadata, jwt_cancel_get_request_metadata};
-
-grpc_call_credentials*
-grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
- grpc_auth_json_key key, gpr_timespec token_lifetime) {
- grpc_service_account_jwt_access_credentials* c;
- if (!grpc_auth_json_key_is_valid(&key)) {
- gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation");
- return nullptr;
- }
- c = static_cast<grpc_service_account_jwt_access_credentials*>(
- gpr_zalloc(sizeof(grpc_service_account_jwt_access_credentials)));
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_JWT;
- gpr_ref_init(&c->base.refcount, 1);
- c->base.vtable = &jwt_vtable;
- c->key = key;
+grpc_service_account_jwt_access_credentials::
+ grpc_service_account_jwt_access_credentials(grpc_auth_json_key key,
+ gpr_timespec token_lifetime)
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) {
gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime();
if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) {
gpr_log(GPR_INFO,
@@ -136,10 +119,20 @@ grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
static_cast<int>(max_token_lifetime.tv_sec));
token_lifetime = grpc_max_auth_token_lifetime();
}
- c->jwt_lifetime = token_lifetime;
- gpr_mu_init(&c->cache_mu);
- jwt_reset_cache(c);
- return &c->base;
+ jwt_lifetime_ = token_lifetime;
+ gpr_mu_init(&cache_mu_);
+ reset_cache();
+}
+
+grpc_core::RefCountedPtr<grpc_call_credentials>
+grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
+ grpc_auth_json_key key, gpr_timespec token_lifetime) {
+ if (!grpc_auth_json_key_is_valid(&key)) {
+ gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation");
+ return nullptr;
+ }
+ return grpc_core::MakeRefCounted<grpc_service_account_jwt_access_credentials>(
+ key, token_lifetime);
}
static char* redact_private_key(const char* json_key) {
@@ -182,9 +175,7 @@ grpc_call_credentials* grpc_service_account_jwt_access_credentials_create(
}
GPR_ASSERT(reserved == nullptr);
grpc_core::ExecCtx exec_ctx;
- grpc_call_credentials* creds =
- grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
- grpc_auth_json_key_create_from_string(json_key), token_lifetime);
-
- return creds;
+ return grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
+ grpc_auth_json_key_create_from_string(json_key), token_lifetime)
+ .release();
}
diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h
index 5c3d34aa56..5af909f44d 100644
--- a/src/core/lib/security/credentials/jwt/jwt_credentials.h
+++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h
@@ -24,25 +24,44 @@
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/jwt/json_token.h"
-typedef struct {
- grpc_call_credentials base;
+class grpc_service_account_jwt_access_credentials
+ : public grpc_call_credentials {
+ public:
+ grpc_service_account_jwt_access_credentials(grpc_auth_json_key key,
+ gpr_timespec token_lifetime);
+ ~grpc_service_account_jwt_access_credentials() override;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; }
+ const grpc_auth_json_key& key() const { return key_; }
+
+ private:
+ void reset_cache();
// Have a simple cache for now with just 1 entry. We could have a map based on
// the service_url for a more sophisticated one.
- gpr_mu cache_mu;
+ gpr_mu cache_mu_;
struct {
- grpc_mdelem jwt_md;
- char* service_url;
+ grpc_mdelem jwt_md = GRPC_MDNULL;
+ char* service_url = nullptr;
gpr_timespec jwt_expiration;
- } cached;
+ } cached_;
- grpc_auth_json_key key;
- gpr_timespec jwt_lifetime;
-} grpc_service_account_jwt_access_credentials;
+ grpc_auth_json_key key_;
+ gpr_timespec jwt_lifetime_;
+};
// Private constructor for jwt credentials from an already parsed json key.
// Takes ownership of the key.
-grpc_call_credentials*
+grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_service_account_jwt_access_credentials_create_from_auth_json_key(
grpc_auth_json_key key, gpr_timespec token_lifetime);
diff --git a/src/core/lib/security/credentials/jwt/jwt_verifier.cc b/src/core/lib/security/credentials/jwt/jwt_verifier.cc
index c7d1b36ff0..cdef0f322a 100644
--- a/src/core/lib/security/credentials/jwt/jwt_verifier.cc
+++ b/src/core/lib/security/credentials/jwt/jwt_verifier.cc
@@ -31,7 +31,9 @@
#include <grpc/support/sync.h>
extern "C" {
+#include <openssl/bn.h>
#include <openssl/pem.h>
+#include <openssl/rsa.h>
}
#include "src/core/lib/gpr/string.h"
diff --git a/src/core/lib/security/credentials/local/local_credentials.cc b/src/core/lib/security/credentials/local/local_credentials.cc
index 3ccfa2b908..6f6f95a34a 100644
--- a/src/core/lib/security/credentials/local/local_credentials.cc
+++ b/src/core/lib/security/credentials/local/local_credentials.cc
@@ -29,49 +29,36 @@
#define GRPC_CREDENTIALS_TYPE_LOCAL "Local"
-static void local_credentials_destruct(grpc_channel_credentials* creds) {}
-
-static void local_server_credentials_destruct(grpc_server_credentials* creds) {}
-
-static grpc_security_status local_create_security_connector(
- grpc_channel_credentials* creds,
- grpc_call_credentials* request_metadata_creds, const char* target_name,
- const grpc_channel_args* args, grpc_channel_security_connector** sc,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_local_credentials::create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name, const grpc_channel_args* args,
grpc_channel_args** new_args) {
return grpc_local_channel_security_connector_create(
- creds, request_metadata_creds, args, target_name, sc);
+ this->Ref(), std::move(request_metadata_creds), args, target_name);
}
-static grpc_security_status local_server_create_security_connector(
- grpc_server_credentials* creds, grpc_server_security_connector** sc) {
- return grpc_local_server_security_connector_create(creds, sc);
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_local_server_credentials::create_security_connector() {
+ return grpc_local_server_security_connector_create(this->Ref());
}
-static const grpc_channel_credentials_vtable local_credentials_vtable = {
- local_credentials_destruct, local_create_security_connector,
- /*duplicate_without_call_credentials=*/nullptr};
-
-static const grpc_server_credentials_vtable local_server_credentials_vtable = {
- local_server_credentials_destruct, local_server_create_security_connector};
+grpc_local_credentials::grpc_local_credentials(
+ grpc_local_connect_type connect_type)
+ : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
+ connect_type_(connect_type) {}
grpc_channel_credentials* grpc_local_credentials_create(
grpc_local_connect_type connect_type) {
- auto creds = static_cast<grpc_local_credentials*>(
- gpr_zalloc(sizeof(grpc_local_credentials)));
- creds->connect_type = connect_type;
- creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL;
- creds->base.vtable = &local_credentials_vtable;
- gpr_ref_init(&creds->base.refcount, 1);
- return &creds->base;
+ return grpc_core::New<grpc_local_credentials>(connect_type);
}
+grpc_local_server_credentials::grpc_local_server_credentials(
+ grpc_local_connect_type connect_type)
+ : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL),
+ connect_type_(connect_type) {}
+
grpc_server_credentials* grpc_local_server_credentials_create(
grpc_local_connect_type connect_type) {
- auto creds = static_cast<grpc_local_server_credentials*>(
- gpr_zalloc(sizeof(grpc_local_server_credentials)));
- creds->connect_type = connect_type;
- creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL;
- creds->base.vtable = &local_server_credentials_vtable;
- gpr_ref_init(&creds->base.refcount, 1);
- return &creds->base;
+ return grpc_core::New<grpc_local_server_credentials>(connect_type);
}
diff --git a/src/core/lib/security/credentials/local/local_credentials.h b/src/core/lib/security/credentials/local/local_credentials.h
index 47358b04bc..60a8a4f64c 100644
--- a/src/core/lib/security/credentials/local/local_credentials.h
+++ b/src/core/lib/security/credentials/local/local_credentials.h
@@ -25,16 +25,37 @@
#include "src/core/lib/security/credentials/credentials.h"
-/* Main struct for grpc local channel credential. */
-typedef struct grpc_local_credentials {
- grpc_channel_credentials base;
- grpc_local_connect_type connect_type;
-} grpc_local_credentials;
-
-/* Main struct for grpc local server credential. */
-typedef struct grpc_local_server_credentials {
- grpc_server_credentials base;
- grpc_local_connect_type connect_type;
-} grpc_local_server_credentials;
+/* Main class for grpc local channel credential. */
+class grpc_local_credentials final : public grpc_channel_credentials {
+ public:
+ explicit grpc_local_credentials(grpc_local_connect_type connect_type);
+ ~grpc_local_credentials() override = default;
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override;
+
+ grpc_local_connect_type connect_type() const { return connect_type_; }
+
+ private:
+ grpc_local_connect_type connect_type_;
+};
+
+/* Main class for grpc local server credential. */
+class grpc_local_server_credentials final : public grpc_server_credentials {
+ public:
+ explicit grpc_local_server_credentials(grpc_local_connect_type connect_type);
+ ~grpc_local_server_credentials() override = default;
+
+ grpc_core::RefCountedPtr<grpc_server_security_connector>
+ create_security_connector() override;
+
+ grpc_local_connect_type connect_type() const { return connect_type_; }
+
+ private:
+ grpc_local_connect_type connect_type_;
+};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_LOCAL_LOCAL_CREDENTIALS_H */
diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
index 44b093557f..ad63b01e75 100644
--- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
+++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc
@@ -22,6 +22,7 @@
#include <string.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/util/json_util.h"
#include "src/core/lib/surface/api_trace.h"
@@ -105,13 +106,12 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) {
// Oauth2 Token Fetcher credentials.
//
-static void oauth2_token_fetcher_destruct(grpc_call_credentials* creds) {
- grpc_oauth2_token_fetcher_credentials* c =
- reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
- GRPC_MDELEM_UNREF(c->access_token_md);
- gpr_mu_destroy(&c->mu);
- grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&c->pollent));
- grpc_httpcli_context_destroy(&c->httpcli_context);
+grpc_oauth2_token_fetcher_credentials::
+ ~grpc_oauth2_token_fetcher_credentials() {
+ GRPC_MDELEM_UNREF(access_token_md_);
+ gpr_mu_destroy(&mu_);
+ grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_));
+ grpc_httpcli_context_destroy(&httpcli_context_);
}
grpc_credentials_status
@@ -209,25 +209,29 @@ static void on_oauth2_token_fetcher_http_response(void* user_data,
grpc_credentials_metadata_request* r =
static_cast<grpc_credentials_metadata_request*>(user_data);
grpc_oauth2_token_fetcher_credentials* c =
- reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds);
+ reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(r->creds.get());
+ c->on_http_response(r, error);
+}
+
+void grpc_oauth2_token_fetcher_credentials::on_http_response(
+ grpc_credentials_metadata_request* r, grpc_error* error) {
grpc_mdelem access_token_md = GRPC_MDNULL;
grpc_millis token_lifetime;
grpc_credentials_status status =
grpc_oauth2_token_fetcher_credentials_parse_server_response(
&r->response, &access_token_md, &token_lifetime);
// Update cache and grab list of pending requests.
- gpr_mu_lock(&c->mu);
- c->token_fetch_pending = false;
- c->access_token_md = GRPC_MDELEM_REF(access_token_md);
- c->token_expiration =
+ gpr_mu_lock(&mu_);
+ token_fetch_pending_ = false;
+ access_token_md_ = GRPC_MDELEM_REF(access_token_md);
+ token_expiration_ =
status == GRPC_CREDENTIALS_OK
? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
gpr_time_from_millis(token_lifetime, GPR_TIMESPAN))
: gpr_inf_past(GPR_CLOCK_MONOTONIC);
- grpc_oauth2_pending_get_request_metadata* pending_request =
- c->pending_requests;
- c->pending_requests = nullptr;
- gpr_mu_unlock(&c->mu);
+ grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_;
+ pending_requests_ = nullptr;
+ gpr_mu_unlock(&mu_);
// Invoke callbacks for all pending requests.
while (pending_request != nullptr) {
if (status == GRPC_CREDENTIALS_OK) {
@@ -239,42 +243,40 @@ static void on_oauth2_token_fetcher_http_response(void* user_data,
}
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, error);
grpc_polling_entity_del_from_pollset_set(
- pending_request->pollent, grpc_polling_entity_pollset_set(&c->pollent));
+ pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_));
grpc_oauth2_pending_get_request_metadata* prev = pending_request;
pending_request = pending_request->next;
gpr_free(prev);
}
GRPC_MDELEM_UNREF(access_token_md);
- grpc_call_credentials_unref(r->creds);
+ Unref();
grpc_credentials_metadata_request_destroy(r);
}
-static bool oauth2_token_fetcher_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata, grpc_error** error) {
- grpc_oauth2_token_fetcher_credentials* c =
- reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
+bool grpc_oauth2_token_fetcher_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
// Check if we can use the cached token.
grpc_millis refresh_threshold =
GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC;
grpc_mdelem cached_access_token_md = GRPC_MDNULL;
- gpr_mu_lock(&c->mu);
- if (!GRPC_MDISNULL(c->access_token_md) &&
+ gpr_mu_lock(&mu_);
+ if (!GRPC_MDISNULL(access_token_md_) &&
gpr_time_cmp(
- gpr_time_sub(c->token_expiration, gpr_now(GPR_CLOCK_MONOTONIC)),
+ gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)),
gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS,
GPR_TIMESPAN)) > 0) {
- cached_access_token_md = GRPC_MDELEM_REF(c->access_token_md);
+ cached_access_token_md = GRPC_MDELEM_REF(access_token_md_);
}
if (!GRPC_MDISNULL(cached_access_token_md)) {
- gpr_mu_unlock(&c->mu);
+ gpr_mu_unlock(&mu_);
grpc_credentials_mdelem_array_add(md_array, cached_access_token_md);
GRPC_MDELEM_UNREF(cached_access_token_md);
return true;
}
// Couldn't get the token from the cache.
- // Add request to c->pending_requests and start a new fetch if needed.
+ // Add request to pending_requests_ and start a new fetch if needed.
grpc_oauth2_pending_get_request_metadata* pending_request =
static_cast<grpc_oauth2_pending_get_request_metadata*>(
gpr_malloc(sizeof(*pending_request)));
@@ -282,41 +284,37 @@ static bool oauth2_token_fetcher_get_request_metadata(
pending_request->on_request_metadata = on_request_metadata;
pending_request->pollent = pollent;
grpc_polling_entity_add_to_pollset_set(
- pollent, grpc_polling_entity_pollset_set(&c->pollent));
- pending_request->next = c->pending_requests;
- c->pending_requests = pending_request;
+ pollent, grpc_polling_entity_pollset_set(&pollent_));
+ pending_request->next = pending_requests_;
+ pending_requests_ = pending_request;
bool start_fetch = false;
- if (!c->token_fetch_pending) {
- c->token_fetch_pending = true;
+ if (!token_fetch_pending_) {
+ token_fetch_pending_ = true;
start_fetch = true;
}
- gpr_mu_unlock(&c->mu);
+ gpr_mu_unlock(&mu_);
if (start_fetch) {
- grpc_call_credentials_ref(creds);
- c->fetch_func(grpc_credentials_metadata_request_create(creds),
- &c->httpcli_context, &c->pollent,
- on_oauth2_token_fetcher_http_response,
- grpc_core::ExecCtx::Get()->Now() + refresh_threshold);
+ Ref().release();
+ fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()),
+ &httpcli_context_, &pollent_,
+ on_oauth2_token_fetcher_http_response,
+ grpc_core::ExecCtx::Get()->Now() + refresh_threshold);
}
return false;
}
-static void oauth2_token_fetcher_cancel_get_request_metadata(
- grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
- grpc_oauth2_token_fetcher_credentials* c =
- reinterpret_cast<grpc_oauth2_token_fetcher_credentials*>(creds);
- gpr_mu_lock(&c->mu);
+void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
+ gpr_mu_lock(&mu_);
grpc_oauth2_pending_get_request_metadata* prev = nullptr;
- grpc_oauth2_pending_get_request_metadata* pending_request =
- c->pending_requests;
+ grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_;
while (pending_request != nullptr) {
if (pending_request->md_array == md_array) {
// Remove matching pending request from the list.
if (prev != nullptr) {
prev->next = pending_request->next;
} else {
- c->pending_requests = pending_request->next;
+ pending_requests_ = pending_request->next;
}
// Invoke the callback immediately with an error.
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata,
@@ -327,96 +325,89 @@ static void oauth2_token_fetcher_cancel_get_request_metadata(
prev = pending_request;
pending_request = pending_request->next;
}
- gpr_mu_unlock(&c->mu);
+ gpr_mu_unlock(&mu_);
GRPC_ERROR_UNREF(error);
}
-static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials* c,
- grpc_fetch_oauth2_func fetch_func) {
- memset(c, 0, sizeof(grpc_oauth2_token_fetcher_credentials));
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
- gpr_ref_init(&c->base.refcount, 1);
- gpr_mu_init(&c->mu);
- c->token_expiration = gpr_inf_past(GPR_CLOCK_MONOTONIC);
- c->fetch_func = fetch_func;
- c->pollent =
- grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create());
- grpc_httpcli_context_init(&c->httpcli_context);
+grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials()
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2),
+ token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)),
+ pollent_(grpc_polling_entity_create_from_pollset_set(
+ grpc_pollset_set_create())) {
+ gpr_mu_init(&mu_);
+ grpc_httpcli_context_init(&httpcli_context_);
}
//
// Google Compute Engine credentials.
//
-static grpc_call_credentials_vtable compute_engine_vtable = {
- oauth2_token_fetcher_destruct, oauth2_token_fetcher_get_request_metadata,
- oauth2_token_fetcher_cancel_get_request_metadata};
+namespace {
+
+class grpc_compute_engine_token_fetcher_credentials
+ : public grpc_oauth2_token_fetcher_credentials {
+ public:
+ grpc_compute_engine_token_fetcher_credentials() = default;
+ ~grpc_compute_engine_token_fetcher_credentials() override = default;
+
+ protected:
+ void fetch_oauth2(grpc_credentials_metadata_request* metadata_req,
+ grpc_httpcli_context* http_context,
+ grpc_polling_entity* pollent,
+ grpc_iomgr_cb_func response_cb,
+ grpc_millis deadline) override {
+ grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"};
+ grpc_httpcli_request request;
+ memset(&request, 0, sizeof(grpc_httpcli_request));
+ request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST;
+ request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH;
+ request.http.hdr_count = 1;
+ request.http.hdrs = &header;
+ /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host
+ channel. This would allow us to cancel an authentication query when under
+ extreme memory pressure. */
+ grpc_resource_quota* resource_quota =
+ grpc_resource_quota_create("oauth2_credentials");
+ grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline,
+ GRPC_CLOSURE_CREATE(response_cb, metadata_req,
+ grpc_schedule_on_exec_ctx),
+ &metadata_req->response);
+ grpc_resource_quota_unref_internal(resource_quota);
+ }
+};
-static void compute_engine_fetch_oauth2(
- grpc_credentials_metadata_request* metadata_req,
- grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent,
- grpc_iomgr_cb_func response_cb, grpc_millis deadline) {
- grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"};
- grpc_httpcli_request request;
- memset(&request, 0, sizeof(grpc_httpcli_request));
- request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST;
- request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH;
- request.http.hdr_count = 1;
- request.http.hdrs = &header;
- /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host
- channel. This would allow us to cancel an authentication query when under
- extreme memory pressure. */
- grpc_resource_quota* resource_quota =
- grpc_resource_quota_create("oauth2_credentials");
- grpc_httpcli_get(
- httpcli_context, pollent, resource_quota, &request, deadline,
- GRPC_CLOSURE_CREATE(response_cb, metadata_req, grpc_schedule_on_exec_ctx),
- &metadata_req->response);
- grpc_resource_quota_unref_internal(resource_quota);
-}
+} // namespace
grpc_call_credentials* grpc_google_compute_engine_credentials_create(
void* reserved) {
- grpc_oauth2_token_fetcher_credentials* c =
- static_cast<grpc_oauth2_token_fetcher_credentials*>(
- gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials)));
GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1,
(reserved));
GPR_ASSERT(reserved == nullptr);
- init_oauth2_token_fetcher(c, compute_engine_fetch_oauth2);
- c->base.vtable = &compute_engine_vtable;
- return &c->base;
+ return grpc_core::MakeRefCounted<
+ grpc_compute_engine_token_fetcher_credentials>()
+ .release();
}
//
// Google Refresh Token credentials.
//
-static void refresh_token_destruct(grpc_call_credentials* creds) {
- grpc_google_refresh_token_credentials* c =
- reinterpret_cast<grpc_google_refresh_token_credentials*>(creds);
- grpc_auth_refresh_token_destruct(&c->refresh_token);
- oauth2_token_fetcher_destruct(&c->base.base);
+grpc_google_refresh_token_credentials::
+ ~grpc_google_refresh_token_credentials() {
+ grpc_auth_refresh_token_destruct(&refresh_token_);
}
-static grpc_call_credentials_vtable refresh_token_vtable = {
- refresh_token_destruct, oauth2_token_fetcher_get_request_metadata,
- oauth2_token_fetcher_cancel_get_request_metadata};
-
-static void refresh_token_fetch_oauth2(
+void grpc_google_refresh_token_credentials::fetch_oauth2(
grpc_credentials_metadata_request* metadata_req,
grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent,
grpc_iomgr_cb_func response_cb, grpc_millis deadline) {
- grpc_google_refresh_token_credentials* c =
- reinterpret_cast<grpc_google_refresh_token_credentials*>(
- metadata_req->creds);
grpc_http_header header = {(char*)"Content-Type",
(char*)"application/x-www-form-urlencoded"};
grpc_httpcli_request request;
char* body = nullptr;
gpr_asprintf(&body, GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING,
- c->refresh_token.client_id, c->refresh_token.client_secret,
- c->refresh_token.refresh_token);
+ refresh_token_.client_id, refresh_token_.client_secret,
+ refresh_token_.refresh_token);
memset(&request, 0, sizeof(grpc_httpcli_request));
request.host = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_HOST;
request.http.path = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH;
@@ -437,20 +428,19 @@ static void refresh_token_fetch_oauth2(
gpr_free(body);
}
-grpc_call_credentials*
+grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials(
+ grpc_auth_refresh_token refresh_token)
+ : refresh_token_(refresh_token) {}
+
+grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_refresh_token_credentials_create_from_auth_refresh_token(
grpc_auth_refresh_token refresh_token) {
- grpc_google_refresh_token_credentials* c;
if (!grpc_auth_refresh_token_is_valid(&refresh_token)) {
gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation");
return nullptr;
}
- c = static_cast<grpc_google_refresh_token_credentials*>(
- gpr_zalloc(sizeof(grpc_google_refresh_token_credentials)));
- init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2);
- c->base.base.vtable = &refresh_token_vtable;
- c->refresh_token = refresh_token;
- return &c->base.base;
+ return grpc_core::MakeRefCounted<grpc_google_refresh_token_credentials>(
+ refresh_token);
}
static char* create_loggable_refresh_token(grpc_auth_refresh_token* token) {
@@ -478,59 +468,50 @@ grpc_call_credentials* grpc_google_refresh_token_credentials_create(
gpr_free(loggable_token);
}
GPR_ASSERT(reserved == nullptr);
- return grpc_refresh_token_credentials_create_from_auth_refresh_token(token);
+ return grpc_refresh_token_credentials_create_from_auth_refresh_token(token)
+ .release();
}
//
// Oauth2 Access Token credentials.
//
-static void access_token_destruct(grpc_call_credentials* creds) {
- grpc_access_token_credentials* c =
- reinterpret_cast<grpc_access_token_credentials*>(creds);
- GRPC_MDELEM_UNREF(c->access_token_md);
+grpc_access_token_credentials::~grpc_access_token_credentials() {
+ GRPC_MDELEM_UNREF(access_token_md_);
}
-static bool access_token_get_request_metadata(
- grpc_call_credentials* creds, grpc_polling_entity* pollent,
- grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata, grpc_error** error) {
- grpc_access_token_credentials* c =
- reinterpret_cast<grpc_access_token_credentials*>(creds);
- grpc_credentials_mdelem_array_add(md_array, c->access_token_md);
+bool grpc_access_token_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
+ grpc_credentials_mdelem_array_add(md_array, access_token_md_);
return true;
}
-static void access_token_cancel_get_request_metadata(
- grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
+void grpc_access_token_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable access_token_vtable = {
- access_token_destruct, access_token_get_request_metadata,
- access_token_cancel_get_request_metadata};
+grpc_access_token_credentials::grpc_access_token_credentials(
+ const char* access_token)
+ : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) {
+ char* token_md_value;
+ gpr_asprintf(&token_md_value, "Bearer %s", access_token);
+ grpc_core::ExecCtx exec_ctx;
+ access_token_md_ = grpc_mdelem_from_slices(
+ grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
+ grpc_slice_from_copied_string(token_md_value));
+ gpr_free(token_md_value);
+}
grpc_call_credentials* grpc_access_token_credentials_create(
const char* access_token, void* reserved) {
- grpc_access_token_credentials* c =
- static_cast<grpc_access_token_credentials*>(
- gpr_zalloc(sizeof(grpc_access_token_credentials)));
GRPC_API_TRACE(
"grpc_access_token_credentials_create(access_token=<redacted>, "
"reserved=%p)",
1, (reserved));
GPR_ASSERT(reserved == nullptr);
- c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2;
- c->base.vtable = &access_token_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- char* token_md_value;
- gpr_asprintf(&token_md_value, "Bearer %s", access_token);
- grpc_core::ExecCtx exec_ctx;
- c->access_token_md = grpc_mdelem_from_slices(
- grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY),
- grpc_slice_from_copied_string(token_md_value));
-
- gpr_free(token_md_value);
- return &c->base;
+ return grpc_core::MakeRefCounted<grpc_access_token_credentials>(access_token)
+ .release();
}
diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h
index 12a1d4484f..510a78b484 100644
--- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h
+++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h
@@ -54,46 +54,91 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token);
// This object is a base for credentials that need to acquire an oauth2 token
// from an http service.
-typedef void (*grpc_fetch_oauth2_func)(grpc_credentials_metadata_request* req,
- grpc_httpcli_context* http_context,
- grpc_polling_entity* pollent,
- grpc_iomgr_cb_func cb,
- grpc_millis deadline);
-
-typedef struct grpc_oauth2_pending_get_request_metadata {
+struct grpc_oauth2_pending_get_request_metadata {
grpc_credentials_mdelem_array* md_array;
grpc_closure* on_request_metadata;
grpc_polling_entity* pollent;
struct grpc_oauth2_pending_get_request_metadata* next;
-} grpc_oauth2_pending_get_request_metadata;
-
-typedef struct {
- grpc_call_credentials base;
- gpr_mu mu;
- grpc_mdelem access_token_md;
- gpr_timespec token_expiration;
- bool token_fetch_pending;
- grpc_oauth2_pending_get_request_metadata* pending_requests;
- grpc_httpcli_context httpcli_context;
- grpc_fetch_oauth2_func fetch_func;
- grpc_polling_entity pollent;
-} grpc_oauth2_token_fetcher_credentials;
+};
+
+class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials {
+ public:
+ grpc_oauth2_token_fetcher_credentials();
+ ~grpc_oauth2_token_fetcher_credentials() override;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ void on_http_response(grpc_credentials_metadata_request* r,
+ grpc_error* error);
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ protected:
+ virtual void fetch_oauth2(grpc_credentials_metadata_request* req,
+ grpc_httpcli_context* httpcli_context,
+ grpc_polling_entity* pollent, grpc_iomgr_cb_func cb,
+ grpc_millis deadline) GRPC_ABSTRACT;
+
+ private:
+ gpr_mu mu_;
+ grpc_mdelem access_token_md_ = GRPC_MDNULL;
+ gpr_timespec token_expiration_;
+ bool token_fetch_pending_ = false;
+ grpc_oauth2_pending_get_request_metadata* pending_requests_ = nullptr;
+ grpc_httpcli_context httpcli_context_;
+ grpc_polling_entity pollent_;
+};
// Google refresh token credentials.
-typedef struct {
- grpc_oauth2_token_fetcher_credentials base;
- grpc_auth_refresh_token refresh_token;
-} grpc_google_refresh_token_credentials;
+class grpc_google_refresh_token_credentials final
+ : public grpc_oauth2_token_fetcher_credentials {
+ public:
+ grpc_google_refresh_token_credentials(grpc_auth_refresh_token refresh_token);
+ ~grpc_google_refresh_token_credentials() override;
+
+ const grpc_auth_refresh_token& refresh_token() const {
+ return refresh_token_;
+ }
+
+ protected:
+ void fetch_oauth2(grpc_credentials_metadata_request* req,
+ grpc_httpcli_context* httpcli_context,
+ grpc_polling_entity* pollent, grpc_iomgr_cb_func cb,
+ grpc_millis deadline) override;
+
+ private:
+ grpc_auth_refresh_token refresh_token_;
+};
// Access token credentials.
-typedef struct {
- grpc_call_credentials base;
- grpc_mdelem access_token_md;
-} grpc_access_token_credentials;
+class grpc_access_token_credentials final : public grpc_call_credentials {
+ public:
+ grpc_access_token_credentials(const char* access_token);
+ ~grpc_access_token_credentials() override;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ private:
+ grpc_mdelem access_token_md_;
+};
// Private constructor for refresh token credentials from an already parsed
// refresh token. Takes ownership of the refresh token.
-grpc_call_credentials*
+grpc_core::RefCountedPtr<grpc_call_credentials>
grpc_refresh_token_credentials_create_from_auth_refresh_token(
grpc_auth_refresh_token token);
diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc
index 4015124298..52982fdb8f 100644
--- a/src/core/lib/security/credentials/plugin/plugin_credentials.cc
+++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc
@@ -35,20 +35,17 @@
grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials");
-static void plugin_destruct(grpc_call_credentials* creds) {
- grpc_plugin_credentials* c =
- reinterpret_cast<grpc_plugin_credentials*>(creds);
- gpr_mu_destroy(&c->mu);
- if (c->plugin.state != nullptr && c->plugin.destroy != nullptr) {
- c->plugin.destroy(c->plugin.state);
+grpc_plugin_credentials::~grpc_plugin_credentials() {
+ gpr_mu_destroy(&mu_);
+ if (plugin_.state != nullptr && plugin_.destroy != nullptr) {
+ plugin_.destroy(plugin_.state);
}
}
-static void pending_request_remove_locked(
- grpc_plugin_credentials* c,
- grpc_plugin_credentials_pending_request* pending_request) {
+void grpc_plugin_credentials::pending_request_remove_locked(
+ pending_request* pending_request) {
if (pending_request->prev == nullptr) {
- c->pending_requests = pending_request->next;
+ pending_requests_ = pending_request->next;
} else {
pending_request->prev->next = pending_request->next;
}
@@ -62,17 +59,17 @@ static void pending_request_remove_locked(
// cancelled out from under us.
// When this returns, r->cancelled indicates whether the request was
// cancelled before completion.
-static void pending_request_complete(
- grpc_plugin_credentials_pending_request* r) {
- gpr_mu_lock(&r->creds->mu);
- if (!r->cancelled) pending_request_remove_locked(r->creds, r);
- gpr_mu_unlock(&r->creds->mu);
+void grpc_plugin_credentials::pending_request_complete(pending_request* r) {
+ GPR_DEBUG_ASSERT(r->creds == this);
+ gpr_mu_lock(&mu_);
+ if (!r->cancelled) pending_request_remove_locked(r);
+ gpr_mu_unlock(&mu_);
// Ref to credentials not needed anymore.
- grpc_call_credentials_unref(&r->creds->base);
+ Unref();
}
static grpc_error* process_plugin_result(
- grpc_plugin_credentials_pending_request* r, const grpc_metadata* md,
+ grpc_plugin_credentials::pending_request* r, const grpc_metadata* md,
size_t num_md, grpc_status_code status, const char* error_details) {
grpc_error* error = GRPC_ERROR_NONE;
if (status != GRPC_STATUS_OK) {
@@ -119,8 +116,8 @@ static void plugin_md_request_metadata_ready(void* request,
/* called from application code */
grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED |
GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP);
- grpc_plugin_credentials_pending_request* r =
- static_cast<grpc_plugin_credentials_pending_request*>(request);
+ grpc_plugin_credentials::pending_request* r =
+ static_cast<grpc_plugin_credentials::pending_request*>(request);
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin returned "
@@ -128,7 +125,7 @@ static void plugin_md_request_metadata_ready(void* request,
r->creds, r);
}
// Remove request from pending list if not previously cancelled.
- pending_request_complete(r);
+ r->creds->pending_request_complete(r);
// If it has not been cancelled, process it.
if (!r->cancelled) {
grpc_error* error =
@@ -143,65 +140,59 @@ static void plugin_md_request_metadata_ready(void* request,
gpr_free(r);
}
-static bool plugin_get_request_metadata(grpc_call_credentials* creds,
- grpc_polling_entity* pollent,
- grpc_auth_metadata_context context,
- grpc_credentials_mdelem_array* md_array,
- grpc_closure* on_request_metadata,
- grpc_error** error) {
- grpc_plugin_credentials* c =
- reinterpret_cast<grpc_plugin_credentials*>(creds);
+bool grpc_plugin_credentials::get_request_metadata(
+ grpc_polling_entity* pollent, grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
+ grpc_error** error) {
bool retval = true; // Synchronous return.
- if (c->plugin.get_metadata != nullptr) {
+ if (plugin_.get_metadata != nullptr) {
// Create pending_request object.
- grpc_plugin_credentials_pending_request* pending_request =
- static_cast<grpc_plugin_credentials_pending_request*>(
- gpr_zalloc(sizeof(*pending_request)));
- pending_request->creds = c;
- pending_request->md_array = md_array;
- pending_request->on_request_metadata = on_request_metadata;
+ pending_request* request =
+ static_cast<pending_request*>(gpr_zalloc(sizeof(*request)));
+ request->creds = this;
+ request->md_array = md_array;
+ request->on_request_metadata = on_request_metadata;
// Add it to the pending list.
- gpr_mu_lock(&c->mu);
- if (c->pending_requests != nullptr) {
- c->pending_requests->prev = pending_request;
+ gpr_mu_lock(&mu_);
+ if (pending_requests_ != nullptr) {
+ pending_requests_->prev = request;
}
- pending_request->next = c->pending_requests;
- c->pending_requests = pending_request;
- gpr_mu_unlock(&c->mu);
+ request->next = pending_requests_;
+ pending_requests_ = request;
+ gpr_mu_unlock(&mu_);
// Invoke the plugin. The callback holds a ref to us.
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin",
- c, pending_request);
+ this, request);
}
- grpc_call_credentials_ref(creds);
+ Ref().release();
grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX];
size_t num_creds_md = 0;
grpc_status_code status = GRPC_STATUS_OK;
const char* error_details = nullptr;
- if (!c->plugin.get_metadata(c->plugin.state, context,
- plugin_md_request_metadata_ready,
- pending_request, creds_md, &num_creds_md,
- &status, &error_details)) {
+ if (!plugin_.get_metadata(
+ plugin_.state, context, plugin_md_request_metadata_ready, request,
+ creds_md, &num_creds_md, &status, &error_details)) {
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin will return "
"asynchronously",
- c, pending_request);
+ this, request);
}
return false; // Asynchronous return.
}
// Returned synchronously.
// Remove request from pending list if not previously cancelled.
- pending_request_complete(pending_request);
+ request->creds->pending_request_complete(request);
// If the request was cancelled, the error will have been returned
// asynchronously by plugin_cancel_get_request_metadata(), so return
// false. Otherwise, process the result.
- if (pending_request->cancelled) {
+ if (request->cancelled) {
if (grpc_plugin_credentials_trace.enabled()) {
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p was cancelled, error "
"will be returned asynchronously",
- c, pending_request);
+ this, request);
}
retval = false;
} else {
@@ -209,10 +200,10 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds,
gpr_log(GPR_INFO,
"plugin_credentials[%p]: request %p: plugin returned "
"synchronously",
- c, pending_request);
+ this, request);
}
- *error = process_plugin_result(pending_request, creds_md, num_creds_md,
- status, error_details);
+ *error = process_plugin_result(request, creds_md, num_creds_md, status,
+ error_details);
}
// Clean up.
for (size_t i = 0; i < num_creds_md; ++i) {
@@ -220,51 +211,42 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds,
grpc_slice_unref_internal(creds_md[i].value);
}
gpr_free((void*)error_details);
- gpr_free(pending_request);
+ gpr_free(request);
}
return retval;
}
-static void plugin_cancel_get_request_metadata(
- grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
- grpc_error* error) {
- grpc_plugin_credentials* c =
- reinterpret_cast<grpc_plugin_credentials*>(creds);
- gpr_mu_lock(&c->mu);
- for (grpc_plugin_credentials_pending_request* pending_request =
- c->pending_requests;
+void grpc_plugin_credentials::cancel_get_request_metadata(
+ grpc_credentials_mdelem_array* md_array, grpc_error* error) {
+ gpr_mu_lock(&mu_);
+ for (pending_request* pending_request = pending_requests_;
pending_request != nullptr; pending_request = pending_request->next) {
if (pending_request->md_array == md_array) {
if (grpc_plugin_credentials_trace.enabled()) {
- gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", c,
+ gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this,
pending_request);
}
pending_request->cancelled = true;
GRPC_CLOSURE_SCHED(pending_request->on_request_metadata,
GRPC_ERROR_REF(error));
- pending_request_remove_locked(c, pending_request);
+ pending_request_remove_locked(pending_request);
break;
}
}
- gpr_mu_unlock(&c->mu);
+ gpr_mu_unlock(&mu_);
GRPC_ERROR_UNREF(error);
}
-static grpc_call_credentials_vtable plugin_vtable = {
- plugin_destruct, plugin_get_request_metadata,
- plugin_cancel_get_request_metadata};
+grpc_plugin_credentials::grpc_plugin_credentials(
+ grpc_metadata_credentials_plugin plugin)
+ : grpc_call_credentials(plugin.type), plugin_(plugin) {
+ gpr_mu_init(&mu_);
+}
grpc_call_credentials* grpc_metadata_credentials_create_from_plugin(
grpc_metadata_credentials_plugin plugin, void* reserved) {
- grpc_plugin_credentials* c =
- static_cast<grpc_plugin_credentials*>(gpr_zalloc(sizeof(*c)));
GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1,
(reserved));
GPR_ASSERT(reserved == nullptr);
- c->base.type = plugin.type;
- c->base.vtable = &plugin_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- c->plugin = plugin;
- gpr_mu_init(&c->mu);
- return &c->base;
+ return grpc_core::New<grpc_plugin_credentials>(plugin);
}
diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h
index caf990efa1..77a957e513 100644
--- a/src/core/lib/security/credentials/plugin/plugin_credentials.h
+++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h
@@ -25,22 +25,45 @@
extern grpc_core::TraceFlag grpc_plugin_credentials_trace;
-struct grpc_plugin_credentials;
-
-typedef struct grpc_plugin_credentials_pending_request {
- bool cancelled;
- struct grpc_plugin_credentials* creds;
- grpc_credentials_mdelem_array* md_array;
- grpc_closure* on_request_metadata;
- struct grpc_plugin_credentials_pending_request* prev;
- struct grpc_plugin_credentials_pending_request* next;
-} grpc_plugin_credentials_pending_request;
-
-typedef struct grpc_plugin_credentials {
- grpc_call_credentials base;
- grpc_metadata_credentials_plugin plugin;
- gpr_mu mu;
- grpc_plugin_credentials_pending_request* pending_requests;
-} grpc_plugin_credentials;
+// This type is forward declared as a C struct and we cannot define it as a
+// class. Otherwise, compiler will complain about type mismatch due to
+// -Wmismatched-tags.
+struct grpc_plugin_credentials final : public grpc_call_credentials {
+ public:
+ struct pending_request {
+ bool cancelled;
+ struct grpc_plugin_credentials* creds;
+ grpc_credentials_mdelem_array* md_array;
+ grpc_closure* on_request_metadata;
+ struct pending_request* prev;
+ struct pending_request* next;
+ };
+
+ explicit grpc_plugin_credentials(grpc_metadata_credentials_plugin plugin);
+ ~grpc_plugin_credentials() override;
+
+ bool get_request_metadata(grpc_polling_entity* pollent,
+ grpc_auth_metadata_context context,
+ grpc_credentials_mdelem_array* md_array,
+ grpc_closure* on_request_metadata,
+ grpc_error** error) override;
+
+ void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array,
+ grpc_error* error) override;
+
+ // Checks if the request has been cancelled.
+ // If not, removes it from the pending list, so that it cannot be
+ // cancelled out from under us.
+ // When this returns, r->cancelled indicates whether the request was
+ // cancelled before completion.
+ void pending_request_complete(pending_request* r);
+
+ private:
+ void pending_request_remove_locked(pending_request* pending_request);
+
+ grpc_metadata_credentials_plugin plugin_;
+ gpr_mu mu_;
+ pending_request* pending_requests_ = nullptr;
+};
#endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_PLUGIN_PLUGIN_CREDENTIALS_H */
diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc
index 3d6f2f200a..83db86f1ea 100644
--- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc
+++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc
@@ -44,22 +44,27 @@ void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp,
gpr_free(kp);
}
-static void ssl_destruct(grpc_channel_credentials* creds) {
- grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds);
- gpr_free(c->config.pem_root_certs);
- grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pair, 1);
- if (c->config.verify_options.verify_peer_destruct != nullptr) {
- c->config.verify_options.verify_peer_destruct(
- c->config.verify_options.verify_peer_callback_userdata);
+grpc_ssl_credentials::grpc_ssl_credentials(
+ const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+ const verify_peer_options* verify_options)
+ : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) {
+ build_config(pem_root_certs, pem_key_cert_pair, verify_options);
+}
+
+grpc_ssl_credentials::~grpc_ssl_credentials() {
+ gpr_free(config_.pem_root_certs);
+ grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1);
+ if (config_.verify_options.verify_peer_destruct != nullptr) {
+ config_.verify_options.verify_peer_destruct(
+ config_.verify_options.verify_peer_callback_userdata);
}
}
-static grpc_security_status ssl_create_security_connector(
- grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_ssl_credentials::create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
const char* target, const grpc_channel_args* args,
- grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
- grpc_ssl_credentials* c = reinterpret_cast<grpc_ssl_credentials*>(creds);
- grpc_security_status status = GRPC_SECURITY_OK;
+ grpc_channel_args** new_args) {
const char* overridden_target_name = nullptr;
tsi_ssl_session_cache* ssl_session_cache = nullptr;
for (size_t i = 0; args && i < args->num_args; i++) {
@@ -74,52 +79,47 @@ static grpc_security_status ssl_create_security_connector(
static_cast<tsi_ssl_session_cache*>(arg->value.pointer.p);
}
}
- status = grpc_ssl_channel_security_connector_create(
- creds, call_creds, &c->config, target, overridden_target_name,
- ssl_session_cache, sc);
- if (status != GRPC_SECURITY_OK) {
- return status;
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> sc =
+ grpc_ssl_channel_security_connector_create(
+ this->Ref(), std::move(call_creds), &config_, target,
+ overridden_target_name, ssl_session_cache);
+ if (sc == nullptr) {
+ return sc;
}
grpc_arg new_arg = grpc_channel_arg_string_create(
(char*)GRPC_ARG_HTTP2_SCHEME, (char*)"https");
*new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1);
- return status;
+ return sc;
}
-static grpc_channel_credentials_vtable ssl_vtable = {
- ssl_destruct, ssl_create_security_connector, nullptr};
-
-static void ssl_build_config(const char* pem_root_certs,
- grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
- const verify_peer_options* verify_options,
- grpc_ssl_config* config) {
- if (pem_root_certs != nullptr) {
- config->pem_root_certs = gpr_strdup(pem_root_certs);
- }
+void grpc_ssl_credentials::build_config(
+ const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+ const verify_peer_options* verify_options) {
+ config_.pem_root_certs = gpr_strdup(pem_root_certs);
if (pem_key_cert_pair != nullptr) {
GPR_ASSERT(pem_key_cert_pair->private_key != nullptr);
GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr);
- config->pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>(
+ config_.pem_key_cert_pair = static_cast<tsi_ssl_pem_key_cert_pair*>(
gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair)));
- config->pem_key_cert_pair->cert_chain =
+ config_.pem_key_cert_pair->cert_chain =
gpr_strdup(pem_key_cert_pair->cert_chain);
- config->pem_key_cert_pair->private_key =
+ config_.pem_key_cert_pair->private_key =
gpr_strdup(pem_key_cert_pair->private_key);
+ } else {
+ config_.pem_key_cert_pair = nullptr;
}
if (verify_options != nullptr) {
- memcpy(&config->verify_options, verify_options,
+ memcpy(&config_.verify_options, verify_options,
sizeof(verify_peer_options));
} else {
// Otherwise set all options to default values
- memset(&config->verify_options, 0, sizeof(verify_peer_options));
+ memset(&config_.verify_options, 0, sizeof(verify_peer_options));
}
}
grpc_channel_credentials* grpc_ssl_credentials_create(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
const verify_peer_options* verify_options, void* reserved) {
- grpc_ssl_credentials* c = static_cast<grpc_ssl_credentials*>(
- gpr_zalloc(sizeof(grpc_ssl_credentials)));
GRPC_API_TRACE(
"grpc_ssl_credentials_create(pem_root_certs=%s, "
"pem_key_cert_pair=%p, "
@@ -127,12 +127,9 @@ grpc_channel_credentials* grpc_ssl_credentials_create(
"reserved=%p)",
4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved));
GPR_ASSERT(reserved == nullptr);
- c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL;
- c->base.vtable = &ssl_vtable;
- gpr_ref_init(&c->base.refcount, 1);
- ssl_build_config(pem_root_certs, pem_key_cert_pair, verify_options,
- &c->config);
- return &c->base;
+
+ return grpc_core::New<grpc_ssl_credentials>(pem_root_certs, pem_key_cert_pair,
+ verify_options);
}
//
@@ -145,21 +142,29 @@ struct grpc_ssl_server_credentials_options {
grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher;
};
-static void ssl_server_destruct(grpc_server_credentials* creds) {
- grpc_ssl_server_credentials* c =
- reinterpret_cast<grpc_ssl_server_credentials*>(creds);
- grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pairs,
- c->config.num_key_cert_pairs);
- gpr_free(c->config.pem_root_certs);
+grpc_ssl_server_credentials::grpc_ssl_server_credentials(
+ const grpc_ssl_server_credentials_options& options)
+ : grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) {
+ if (options.certificate_config_fetcher != nullptr) {
+ config_.client_certificate_request = options.client_certificate_request;
+ certificate_config_fetcher_ = *options.certificate_config_fetcher;
+ } else {
+ build_config(options.certificate_config->pem_root_certs,
+ options.certificate_config->pem_key_cert_pairs,
+ options.certificate_config->num_key_cert_pairs,
+ options.client_certificate_request);
+ }
}
-static grpc_security_status ssl_server_create_security_connector(
- grpc_server_credentials* creds, grpc_server_security_connector** sc) {
- return grpc_ssl_server_security_connector_create(creds, sc);
+grpc_ssl_server_credentials::~grpc_ssl_server_credentials() {
+ grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs,
+ config_.num_key_cert_pairs);
+ gpr_free(config_.pem_root_certs);
+}
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_ssl_server_credentials::create_security_connector() {
+ return grpc_ssl_server_security_connector_create(this->Ref());
}
-
-static grpc_server_credentials_vtable ssl_server_vtable = {
- ssl_server_destruct, ssl_server_create_security_connector};
tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,
@@ -179,18 +184,15 @@ tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
return tsi_pairs;
}
-static void ssl_build_server_config(
+void grpc_ssl_server_credentials::build_config(
const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,
size_t num_key_cert_pairs,
- grpc_ssl_client_certificate_request_type client_certificate_request,
- grpc_ssl_server_config* config) {
- config->client_certificate_request = client_certificate_request;
- if (pem_root_certs != nullptr) {
- config->pem_root_certs = gpr_strdup(pem_root_certs);
- }
- config->pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
+ grpc_ssl_client_certificate_request_type client_certificate_request) {
+ config_.client_certificate_request = client_certificate_request;
+ config_.pem_root_certs = gpr_strdup(pem_root_certs);
+ config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
pem_key_cert_pairs, num_key_cert_pairs);
- config->num_key_cert_pairs = num_key_cert_pairs;
+ config_.num_key_cert_pairs = num_key_cert_pairs;
}
grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create(
@@ -200,9 +202,7 @@ grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create(
grpc_ssl_server_certificate_config* config =
static_cast<grpc_ssl_server_certificate_config*>(
gpr_zalloc(sizeof(grpc_ssl_server_certificate_config)));
- if (pem_root_certs != nullptr) {
- config->pem_root_certs = gpr_strdup(pem_root_certs);
- }
+ config->pem_root_certs = gpr_strdup(pem_root_certs);
if (num_key_cert_pairs > 0) {
GPR_ASSERT(pem_key_cert_pairs != nullptr);
config->pem_key_cert_pairs = static_cast<grpc_ssl_pem_key_cert_pair*>(
@@ -311,7 +311,6 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_ex(
grpc_server_credentials* grpc_ssl_server_credentials_create_with_options(
grpc_ssl_server_credentials_options* options) {
grpc_server_credentials* retval = nullptr;
- grpc_ssl_server_credentials* c = nullptr;
if (options == nullptr) {
gpr_log(GPR_ERROR,
@@ -331,23 +330,7 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_with_options(
goto done;
}
- c = static_cast<grpc_ssl_server_credentials*>(
- gpr_zalloc(sizeof(grpc_ssl_server_credentials)));
- c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL;
- gpr_ref_init(&c->base.refcount, 1);
- c->base.vtable = &ssl_server_vtable;
-
- if (options->certificate_config_fetcher != nullptr) {
- c->config.client_certificate_request = options->client_certificate_request;
- c->certificate_config_fetcher = *options->certificate_config_fetcher;
- } else {
- ssl_build_server_config(options->certificate_config->pem_root_certs,
- options->certificate_config->pem_key_cert_pairs,
- options->certificate_config->num_key_cert_pairs,
- options->client_certificate_request, &c->config);
- }
-
- retval = &c->base;
+ retval = grpc_core::New<grpc_ssl_server_credentials>(*options);
done:
grpc_ssl_server_credentials_options_destroy(options);
diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h
index 0fba413876..e1174327b3 100644
--- a/src/core/lib/security/credentials/ssl/ssl_credentials.h
+++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h
@@ -24,27 +24,70 @@
#include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h"
-typedef struct {
- grpc_channel_credentials base;
- grpc_ssl_config config;
-} grpc_ssl_credentials;
+class grpc_ssl_credentials : public grpc_channel_credentials {
+ public:
+ grpc_ssl_credentials(const char* pem_root_certs,
+ grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+ const verify_peer_options* verify_options);
+
+ ~grpc_ssl_credentials() override;
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector>
+ create_security_connector(
+ grpc_core::RefCountedPtr<grpc_call_credentials> call_creds,
+ const char* target, const grpc_channel_args* args,
+ grpc_channel_args** new_args) override;
+
+ private:
+ void build_config(const char* pem_root_certs,
+ grpc_ssl_pem_key_cert_pair* pem_key_cert_pair,
+ const verify_peer_options* verify_options);
+
+ grpc_ssl_config config_;
+};
struct grpc_ssl_server_certificate_config {
- grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs;
- size_t num_key_cert_pairs;
- char* pem_root_certs;
+ grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr;
+ size_t num_key_cert_pairs = 0;
+ char* pem_root_certs = nullptr;
};
-typedef struct {
- grpc_ssl_server_certificate_config_callback cb;
+struct grpc_ssl_server_certificate_config_fetcher {
+ grpc_ssl_server_certificate_config_callback cb = nullptr;
void* user_data;
-} grpc_ssl_server_certificate_config_fetcher;
+};
+
+class grpc_ssl_server_credentials final : public grpc_server_credentials {
+ public:
+ grpc_ssl_server_credentials(
+ const grpc_ssl_server_credentials_options& options);
+ ~grpc_ssl_server_credentials() override;
-typedef struct {
- grpc_server_credentials base;
- grpc_ssl_server_config config;
- grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher;
-} grpc_ssl_server_credentials;
+ grpc_core::RefCountedPtr<grpc_server_security_connector>
+ create_security_connector() override;
+
+ bool has_cert_config_fetcher() const {
+ return certificate_config_fetcher_.cb != nullptr;
+ }
+
+ grpc_ssl_certificate_config_reload_status FetchCertConfig(
+ grpc_ssl_server_certificate_config** config) {
+ GPR_DEBUG_ASSERT(has_cert_config_fetcher());
+ return certificate_config_fetcher_.cb(certificate_config_fetcher_.user_data,
+ config);
+ }
+
+ const grpc_ssl_server_config& config() const { return config_; }
+
+ private:
+ void build_config(
+ const char* pem_root_certs,
+ grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs,
+ grpc_ssl_client_certificate_request_type client_certificate_request);
+
+ grpc_ssl_server_config config_;
+ grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher_;
+};
tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs(
const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs,
diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc
index dd71c8bc60..3ad0cc353c 100644
--- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc
+++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc
@@ -28,6 +28,7 @@
#include <grpc/support/log.h>
#include <grpc/support/string_util.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/credentials/alts/alts_credentials.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/lib/slice/slice_internal.h"
@@ -35,64 +36,9 @@
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
#include "src/core/tsi/transport_security.h"
-typedef struct {
- grpc_channel_security_connector base;
- char* target_name;
-} grpc_alts_channel_security_connector;
+namespace {
-typedef struct {
- grpc_server_security_connector base;
-} grpc_alts_server_security_connector;
-
-static void alts_channel_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- grpc_call_credentials_unref(c->base.request_metadata_creds);
- grpc_channel_credentials_unref(c->base.channel_creds);
- gpr_free(c->target_name);
- gpr_free(sc);
-}
-
-static void alts_server_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
- grpc_server_credentials_unref(c->base.server_creds);
- gpr_free(sc);
-}
-
-static void alts_channel_add_handshakers(
- grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- auto c = reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- grpc_alts_credentials* creds =
- reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
- GPR_ASSERT(alts_tsi_handshaker_create(
- creds->options, c->target_name, creds->handshaker_service_url,
- true, interested_parties, &handshaker) == TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static void alts_server_add_handshakers(
- grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- auto c = reinterpret_cast<grpc_alts_server_security_connector*>(sc);
- grpc_alts_server_credentials* creds =
- reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
- GPR_ASSERT(alts_tsi_handshaker_create(
- creds->options, nullptr, creds->handshaker_service_url, false,
- interested_parties, &handshaker) == TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static void alts_set_rpc_protocol_versions(
+void alts_set_rpc_protocol_versions(
grpc_gcp_rpc_protocol_versions* rpc_versions) {
grpc_gcp_rpc_protocol_versions_set_max(rpc_versions,
GRPC_PROTOCOL_VERSION_MAX_MAJOR,
@@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions(
GRPC_PROTOCOL_VERSION_MIN_MINOR);
}
+void alts_check_peer(tsi_peer peer,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) {
+ *auth_context =
+ grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer);
+ tsi_peer_destruct(&peer);
+ grpc_error* error =
+ *auth_context != nullptr
+ ? GRPC_ERROR_NONE
+ : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "Could not get ALTS auth context from TSI peer");
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
+}
+
+class grpc_alts_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ grpc_alts_channel_security_connector(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name)
+ : grpc_channel_security_connector(/*url_scheme=*/nullptr,
+ std::move(channel_creds),
+ std::move(request_metadata_creds)),
+ target_name_(gpr_strdup(target_name)) {
+ grpc_alts_credentials* creds =
+ static_cast<grpc_alts_credentials*>(mutable_channel_creds());
+ alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
+ }
+
+ ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); }
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ const grpc_alts_credentials* creds =
+ static_cast<const grpc_alts_credentials*>(channel_creds());
+ GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_,
+ creds->handshaker_service_url(), true,
+ interested_parties,
+ &handshaker) == TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ alts_check_peer(peer, auth_context, on_peer_checked);
+ }
+
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_alts_channel_security_connector*>(other_sc);
+ int c = channel_security_connector_cmp(other);
+ if (c != 0) return c;
+ return strcmp(target_name_, other->target_name_);
+ }
+
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ if (host == nullptr || strcmp(host, target_name_) != 0) {
+ *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "ALTS call host does not match target name");
+ }
+ return true;
+ }
+
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
+ }
+
+ private:
+ char* target_name_;
+};
+
+class grpc_alts_server_security_connector final
+ : public grpc_server_security_connector {
+ public:
+ grpc_alts_server_security_connector(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_server_security_connector(/*url_scheme=*/nullptr,
+ std::move(server_creds)) {
+ grpc_alts_server_credentials* creds =
+ reinterpret_cast<grpc_alts_server_credentials*>(mutable_server_creds());
+ alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions);
+ }
+ ~grpc_alts_server_security_connector() override = default;
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ const grpc_alts_server_credentials* creds =
+ static_cast<const grpc_alts_server_credentials*>(server_creds());
+ GPR_ASSERT(alts_tsi_handshaker_create(
+ creds->options(), nullptr, creds->handshaker_service_url(),
+ false, interested_parties, &handshaker) == TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ alts_check_peer(peer, auth_context, on_peer_checked);
+ }
+
+ int cmp(const grpc_security_connector* other) const override {
+ return server_security_connector_cmp(
+ static_cast<const grpc_server_security_connector*>(other));
+ }
+};
+} // namespace
+
namespace grpc_core {
namespace internal {
-
-grpc_security_status grpc_alts_auth_context_from_tsi_peer(
- const tsi_peer* peer, grpc_auth_context** ctx) {
- if (peer == nullptr || ctx == nullptr) {
+grpc_core::RefCountedPtr<grpc_auth_context>
+grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) {
+ if (peer == nullptr) {
gpr_log(GPR_ERROR,
"Invalid arguments to grpc_alts_auth_context_from_tsi_peer()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- *ctx = nullptr;
/* Validate certificate type. */
const tsi_peer_property* cert_type_prop =
tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY);
@@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE,
cert_type_prop->value.length) != 0) {
gpr_log(GPR_ERROR, "Invalid or missing certificate type property.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* Validate RPC protocol versions. */
const tsi_peer_property* rpc_versions_prop =
tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS);
if (rpc_versions_prop == nullptr) {
gpr_log(GPR_ERROR, "Missing rpc protocol versions property.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
grpc_gcp_rpc_protocol_versions local_versions, peer_versions;
alts_set_rpc_protocol_versions(&local_versions);
@@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
grpc_slice_unref_internal(slice);
if (!decode_result) {
gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* TODO: Pass highest common rpc protocol version to grpc caller. */
bool check_result = grpc_gcp_rpc_protocol_versions_check(
&local_versions, &peer_versions, nullptr);
if (!check_result) {
gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions.");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
/* Create auth context. */
- *ctx = grpc_auth_context_create(nullptr);
+ auto ctx = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
- *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
+ ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_ALTS_TRANSPORT_SECURITY_TYPE);
size_t i = 0;
for (i = 0; i < peer->property_count; i++) {
@@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer(
/* Add service account to auth context. */
if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) {
grpc_auth_context_add_property(
- *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data,
- tsi_prop->value.length);
+ ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY,
+ tsi_prop->value.data, tsi_prop->value.length);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
- *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
+ ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1);
}
}
- if (!grpc_auth_context_peer_is_authenticated(*ctx)) {
+ if (!grpc_auth_context_peer_is_authenticated(ctx.get())) {
gpr_log(GPR_ERROR, "Invalid unauthenticated peer.");
- GRPC_AUTH_CONTEXT_UNREF(*ctx, "test");
- *ctx = nullptr;
- return GRPC_SECURITY_ERROR;
+ ctx.reset(DEBUG_LOCATION, "test");
+ return nullptr;
}
- return GRPC_SECURITY_OK;
+ return ctx;
}
} // namespace internal
} // namespace grpc_core
-static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_security_status status;
- status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(
- &peer, auth_context);
- tsi_peer_destruct(&peer);
- grpc_error* error =
- status == GRPC_SECURITY_OK
- ? GRPC_ERROR_NONE
- : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "Could not get ALTS auth context from TSI peer");
- GRPC_CLOSURE_SCHED(on_peer_checked, error);
-}
-
-static int alts_channel_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_alts_channel_security_connector* c1 =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc1);
- grpc_alts_channel_security_connector* c2 =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc2);
- int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
- if (c != 0) return c;
- return strcmp(c1->target_name, c2->target_name);
-}
-
-static int alts_server_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_alts_server_security_connector* c1 =
- reinterpret_cast<grpc_alts_server_security_connector*>(sc1);
- grpc_alts_server_security_connector* c2 =
- reinterpret_cast<grpc_alts_server_security_connector*>(sc2);
- return grpc_server_security_connector_cmp(&c1->base, &c2->base);
-}
-
-static grpc_security_connector_vtable alts_channel_vtable = {
- alts_channel_destroy, alts_check_peer, alts_channel_cmp};
-
-static grpc_security_connector_vtable alts_server_vtable = {
- alts_server_destroy, alts_check_peer, alts_server_cmp};
-
-static bool alts_check_call_host(grpc_channel_security_connector* sc,
- const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error) {
- grpc_alts_channel_security_connector* alts_sc =
- reinterpret_cast<grpc_alts_channel_security_connector*>(sc);
- if (host == nullptr || alts_sc == nullptr ||
- strcmp(host, alts_sc->target_name) != 0) {
- *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "ALTS call host does not match target name");
- }
- return true;
-}
-
-static void alts_cancel_check_call_host(grpc_channel_security_connector* sc,
- grpc_closure* on_call_host_checked,
- grpc_error* error) {
- GRPC_ERROR_UNREF(error);
-}
-
-grpc_security_status grpc_alts_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds, const char* target_name,
- grpc_channel_security_connector** sc) {
- if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) {
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_alts_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name) {
+ if (channel_creds == nullptr || target_name == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_channel_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_alts_channel_security_connector*>(
- gpr_zalloc(sizeof(grpc_alts_channel_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &alts_channel_vtable;
- c->base.add_handshakers = alts_channel_add_handshakers;
- c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
- c->base.request_metadata_creds =
- grpc_call_credentials_ref(request_metadata_creds);
- c->base.check_call_host = alts_check_call_host;
- c->base.cancel_check_call_host = alts_cancel_check_call_host;
- grpc_alts_credentials* creds =
- reinterpret_cast<grpc_alts_credentials*>(c->base.channel_creds);
- alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
- c->target_name = gpr_strdup(target_name);
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_alts_channel_security_connector>(
+ std::move(channel_creds), std::move(request_metadata_creds), target_name);
}
-grpc_security_status grpc_alts_server_security_connector_create(
- grpc_server_credentials* server_creds,
- grpc_server_security_connector** sc) {
- if (server_creds == nullptr || sc == nullptr) {
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_alts_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
+ if (server_creds == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_alts_server_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_alts_server_security_connector*>(
- gpr_zalloc(sizeof(grpc_alts_server_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &alts_server_vtable;
- c->base.server_creds = grpc_server_credentials_ref(server_creds);
- c->base.add_handshakers = alts_server_add_handshakers;
- grpc_alts_server_credentials* creds =
- reinterpret_cast<grpc_alts_server_credentials*>(c->base.server_creds);
- alts_set_rpc_protocol_versions(&creds->options->rpc_versions);
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_alts_server_security_connector>(
+ std::move(server_creds));
}
diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.h b/src/core/lib/security/security_connector/alts/alts_security_connector.h
index d2e057a76a..b96dc36b30 100644
--- a/src/core/lib/security/security_connector/alts/alts_security_connector.h
+++ b/src/core/lib/security/security_connector/alts/alts_security_connector.h
@@ -36,12 +36,13 @@
* - sc: address of ALTS channel security connector instance to be returned from
* the method.
*
- * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure.
+ * It returns nullptr on failure.
*/
-grpc_security_status grpc_alts_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds, const char* target_name,
- grpc_channel_security_connector** sc);
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_alts_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name);
/**
* This method creates an ALTS server security connector.
@@ -50,17 +51,18 @@ grpc_security_status grpc_alts_channel_security_connector_create(
* - sc: address of ALTS server security connector instance to be returned from
* the method.
*
- * It returns GRPC_SECURITY_OK on success, and an error status code on failure.
+ * It returns nullptr on failure.
*/
-grpc_security_status grpc_alts_server_security_connector_create(
- grpc_server_credentials* server_creds, grpc_server_security_connector** sc);
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_alts_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
namespace grpc_core {
namespace internal {
/* Exposed only for testing. */
-grpc_security_status grpc_alts_auth_context_from_tsi_peer(
- const tsi_peer* peer, grpc_auth_context** ctx);
+grpc_core::RefCountedPtr<grpc_auth_context>
+grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer);
} // namespace internal
} // namespace grpc_core
diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc
index 5c0c89b88f..e3b8affb36 100644
--- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc
+++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc
@@ -31,6 +31,7 @@
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/fake/fake_credentials.h"
@@ -38,91 +39,183 @@
#include "src/core/lib/security/transport/target_authority_table.h"
#include "src/core/tsi/fake_transport_security.h"
-typedef struct {
- grpc_channel_security_connector base;
- char* target;
- char* expected_targets;
- bool is_lb_channel;
- char* target_name_override;
-} grpc_fake_channel_security_connector;
+namespace {
+class grpc_fake_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ grpc_fake_channel_security_connector(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target, const grpc_channel_args* args)
+ : grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME,
+ std::move(channel_creds),
+ std::move(request_metadata_creds)),
+ target_(gpr_strdup(target)),
+ expected_targets_(
+ gpr_strdup(grpc_fake_transport_get_expected_targets(args))),
+ is_lb_channel_(grpc_core::FindTargetAuthorityTableInArgs(args) !=
+ nullptr) {
+ const grpc_arg* target_name_override_arg =
+ grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
+ if (target_name_override_arg != nullptr) {
+ target_name_override_ =
+ gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg));
+ } else {
+ target_name_override_ = nullptr;
+ }
+ }
-static void fake_channel_destroy(grpc_security_connector* sc) {
- grpc_fake_channel_security_connector* c =
- reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
- grpc_call_credentials_unref(c->base.request_metadata_creds);
- gpr_free(c->target);
- gpr_free(c->expected_targets);
- gpr_free(c->target_name_override);
- gpr_free(c);
-}
+ ~grpc_fake_channel_security_connector() override {
+ gpr_free(target_);
+ gpr_free(expected_targets_);
+ if (target_name_override_ != nullptr) gpr_free(target_name_override_);
+ }
-static void fake_server_destroy(grpc_security_connector* sc) { gpr_free(sc); }
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override;
-static bool fake_check_target(const char* target_type, const char* target,
- const char* set_str) {
- GPR_ASSERT(target_type != nullptr);
- GPR_ASSERT(target != nullptr);
- char** set = nullptr;
- size_t set_size = 0;
- gpr_string_split(set_str, ",", &set, &set_size);
- bool found = false;
- for (size_t i = 0; i < set_size; ++i) {
- if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true;
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_fake_channel_security_connector*>(other_sc);
+ int c = channel_security_connector_cmp(other);
+ if (c != 0) return c;
+ c = strcmp(target_, other->target_);
+ if (c != 0) return c;
+ if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) {
+ c = GPR_ICMP(expected_targets_, other->expected_targets_);
+ } else {
+ c = strcmp(expected_targets_, other->expected_targets_);
+ }
+ if (c != 0) return c;
+ return GPR_ICMP(is_lb_channel_, other->is_lb_channel_);
}
- for (size_t i = 0; i < set_size; ++i) {
- gpr_free(set[i]);
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr) override {
+ grpc_handshake_manager_add(
+ handshake_mgr,
+ grpc_security_handshaker_create(
+ tsi_create_fake_handshaker(/*is_client=*/true), this));
}
- gpr_free(set);
- return found;
-}
-static void fake_secure_name_check(const char* target,
- const char* expected_targets,
- bool is_lb_channel) {
- if (expected_targets == nullptr) return;
- char** lbs_and_backends = nullptr;
- size_t lbs_and_backends_size = 0;
- bool success = false;
- gpr_string_split(expected_targets, ";", &lbs_and_backends,
- &lbs_and_backends_size);
- if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) {
- gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'",
- expected_targets);
- goto done;
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ char* authority_hostname = nullptr;
+ char* authority_ignored_port = nullptr;
+ char* target_hostname = nullptr;
+ char* target_ignored_port = nullptr;
+ gpr_split_host_port(host, &authority_hostname, &authority_ignored_port);
+ gpr_split_host_port(target_, &target_hostname, &target_ignored_port);
+ if (target_name_override_ != nullptr) {
+ char* fake_security_target_name_override_hostname = nullptr;
+ char* fake_security_target_name_override_ignored_port = nullptr;
+ gpr_split_host_port(target_name_override_,
+ &fake_security_target_name_override_hostname,
+ &fake_security_target_name_override_ignored_port);
+ if (strcmp(authority_hostname,
+ fake_security_target_name_override_hostname) != 0) {
+ gpr_log(GPR_ERROR,
+ "Authority (host) '%s' != Fake Security Target override '%s'",
+ host, fake_security_target_name_override_hostname);
+ abort();
+ }
+ gpr_free(fake_security_target_name_override_hostname);
+ gpr_free(fake_security_target_name_override_ignored_port);
+ } else if (strcmp(authority_hostname, target_hostname) != 0) {
+ gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'",
+ authority_hostname, target_hostname);
+ abort();
+ }
+ gpr_free(authority_hostname);
+ gpr_free(authority_ignored_port);
+ gpr_free(target_hostname);
+ gpr_free(target_ignored_port);
+ return true;
}
- if (is_lb_channel) {
- if (lbs_and_backends_size != 2) {
- gpr_log(GPR_ERROR,
- "Invalid expected targets arg value: '%s'. Expectations for LB "
- "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...",
- expected_targets);
- goto done;
+
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
+ }
+
+ char* target() const { return target_; }
+ char* expected_targets() const { return expected_targets_; }
+ bool is_lb_channel() const { return is_lb_channel_; }
+ char* target_name_override() const { return target_name_override_; }
+
+ private:
+ bool fake_check_target(const char* target_type, const char* target,
+ const char* set_str) const {
+ GPR_ASSERT(target_type != nullptr);
+ GPR_ASSERT(target != nullptr);
+ char** set = nullptr;
+ size_t set_size = 0;
+ gpr_string_split(set_str, ",", &set, &set_size);
+ bool found = false;
+ for (size_t i = 0; i < set_size; ++i) {
+ if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true;
}
- if (!fake_check_target("LB", target, lbs_and_backends[1])) {
- gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'",
- target, lbs_and_backends[1]);
- goto done;
+ for (size_t i = 0; i < set_size; ++i) {
+ gpr_free(set[i]);
}
- success = true;
- } else {
- if (!fake_check_target("Backend", target, lbs_and_backends[0])) {
- gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'",
- target, lbs_and_backends[0]);
+ gpr_free(set);
+ return found;
+ }
+
+ void fake_secure_name_check() const {
+ if (expected_targets_ == nullptr) return;
+ char** lbs_and_backends = nullptr;
+ size_t lbs_and_backends_size = 0;
+ bool success = false;
+ gpr_string_split(expected_targets_, ";", &lbs_and_backends,
+ &lbs_and_backends_size);
+ if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) {
+ gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'",
+ expected_targets_);
goto done;
}
- success = true;
- }
-done:
- for (size_t i = 0; i < lbs_and_backends_size; ++i) {
- gpr_free(lbs_and_backends[i]);
+ if (is_lb_channel_) {
+ if (lbs_and_backends_size != 2) {
+ gpr_log(GPR_ERROR,
+ "Invalid expected targets arg value: '%s'. Expectations for LB "
+ "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...",
+ expected_targets_);
+ goto done;
+ }
+ if (!fake_check_target("LB", target_, lbs_and_backends[1])) {
+ gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'",
+ target_, lbs_and_backends[1]);
+ goto done;
+ }
+ success = true;
+ } else {
+ if (!fake_check_target("Backend", target_, lbs_and_backends[0])) {
+ gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'",
+ target_, lbs_and_backends[0]);
+ goto done;
+ }
+ success = true;
+ }
+ done:
+ for (size_t i = 0; i < lbs_and_backends_size; ++i) {
+ gpr_free(lbs_and_backends[i]);
+ }
+ gpr_free(lbs_and_backends);
+ if (!success) abort();
}
- gpr_free(lbs_and_backends);
- if (!success) abort();
-}
-static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
+ char* target_;
+ char* expected_targets_;
+ bool is_lb_channel_;
+ char* target_name_override_;
+};
+
+static void fake_check_peer(
+ grpc_security_connector* sc, tsi_peer peer,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) {
const char* prop_name;
grpc_error* error = GRPC_ERROR_NONE;
*auth_context = nullptr;
@@ -147,164 +240,66 @@ static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer,
"Invalid value for cert type property.");
goto end;
}
- *auth_context = grpc_auth_context_create(nullptr);
+ *auth_context = grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
- *auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
+ auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_FAKE_TRANSPORT_SECURITY_TYPE);
end:
GRPC_CLOSURE_SCHED(on_peer_checked, error);
tsi_peer_destruct(&peer);
}
-static void fake_channel_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- fake_check_peer(sc, peer, auth_context, on_peer_checked);
- grpc_fake_channel_security_connector* c =
- reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
- fake_secure_name_check(c->target, c->expected_targets, c->is_lb_channel);
+void grpc_fake_channel_security_connector::check_peer(
+ tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) {
+ fake_check_peer(this, peer, auth_context, on_peer_checked);
+ fake_secure_name_check();
}
-static void fake_server_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- fake_check_peer(sc, peer, auth_context, on_peer_checked);
-}
+class grpc_fake_server_security_connector
+ : public grpc_server_security_connector {
+ public:
+ grpc_fake_server_security_connector(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME,
+ std::move(server_creds)) {}
+ ~grpc_fake_server_security_connector() override = default;
-static int fake_channel_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_fake_channel_security_connector* c1 =
- reinterpret_cast<grpc_fake_channel_security_connector*>(sc1);
- grpc_fake_channel_security_connector* c2 =
- reinterpret_cast<grpc_fake_channel_security_connector*>(sc2);
- int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
- if (c != 0) return c;
- c = strcmp(c1->target, c2->target);
- if (c != 0) return c;
- if (c1->expected_targets == nullptr || c2->expected_targets == nullptr) {
- c = GPR_ICMP(c1->expected_targets, c2->expected_targets);
- } else {
- c = strcmp(c1->expected_targets, c2->expected_targets);
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ fake_check_peer(this, peer, auth_context, on_peer_checked);
}
- if (c != 0) return c;
- return GPR_ICMP(c1->is_lb_channel, c2->is_lb_channel);
-}
-static int fake_server_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- return grpc_server_security_connector_cmp(
- reinterpret_cast<grpc_server_security_connector*>(sc1),
- reinterpret_cast<grpc_server_security_connector*>(sc2));
-}
-
-static bool fake_channel_check_call_host(grpc_channel_security_connector* sc,
- const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error) {
- grpc_fake_channel_security_connector* c =
- reinterpret_cast<grpc_fake_channel_security_connector*>(sc);
- char* authority_hostname = nullptr;
- char* authority_ignored_port = nullptr;
- char* target_hostname = nullptr;
- char* target_ignored_port = nullptr;
- gpr_split_host_port(host, &authority_hostname, &authority_ignored_port);
- gpr_split_host_port(c->target, &target_hostname, &target_ignored_port);
- if (c->target_name_override != nullptr) {
- char* fake_security_target_name_override_hostname = nullptr;
- char* fake_security_target_name_override_ignored_port = nullptr;
- gpr_split_host_port(c->target_name_override,
- &fake_security_target_name_override_hostname,
- &fake_security_target_name_override_ignored_port);
- if (strcmp(authority_hostname,
- fake_security_target_name_override_hostname) != 0) {
- gpr_log(GPR_ERROR,
- "Authority (host) '%s' != Fake Security Target override '%s'",
- host, fake_security_target_name_override_hostname);
- abort();
- }
- gpr_free(fake_security_target_name_override_hostname);
- gpr_free(fake_security_target_name_override_ignored_port);
- } else if (strcmp(authority_hostname, target_hostname) != 0) {
- gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'",
- authority_hostname, target_hostname);
- abort();
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr) override {
+ grpc_handshake_manager_add(
+ handshake_mgr,
+ grpc_security_handshaker_create(
+ tsi_create_fake_handshaker(/*=is_client*/ false), this));
}
- gpr_free(authority_hostname);
- gpr_free(authority_ignored_port);
- gpr_free(target_hostname);
- gpr_free(target_ignored_port);
- return true;
-}
-static void fake_channel_cancel_check_call_host(
- grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
- grpc_error* error) {
- GRPC_ERROR_UNREF(error);
-}
-
-static void fake_channel_add_handshakers(
- grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- grpc_handshake_manager_add(
- handshake_mgr,
- grpc_security_handshaker_create(
- tsi_create_fake_handshaker(true /* is_client */), &sc->base));
-}
-
-static void fake_server_add_handshakers(grpc_server_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- grpc_handshake_manager_add(
- handshake_mgr,
- grpc_security_handshaker_create(
- tsi_create_fake_handshaker(false /* is_client */), &sc->base));
-}
-
-static grpc_security_connector_vtable fake_channel_vtable = {
- fake_channel_destroy, fake_channel_check_peer, fake_channel_cmp};
-
-static grpc_security_connector_vtable fake_server_vtable = {
- fake_server_destroy, fake_server_check_peer, fake_server_cmp};
-
-grpc_channel_security_connector* grpc_fake_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds, const char* target,
- const grpc_channel_args* args) {
- grpc_fake_channel_security_connector* c =
- static_cast<grpc_fake_channel_security_connector*>(
- gpr_zalloc(sizeof(*c)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME;
- c->base.base.vtable = &fake_channel_vtable;
- c->base.channel_creds = channel_creds;
- c->base.request_metadata_creds =
- grpc_call_credentials_ref(request_metadata_creds);
- c->base.check_call_host = fake_channel_check_call_host;
- c->base.cancel_check_call_host = fake_channel_cancel_check_call_host;
- c->base.add_handshakers = fake_channel_add_handshakers;
- c->target = gpr_strdup(target);
- const char* expected_targets = grpc_fake_transport_get_expected_targets(args);
- c->expected_targets = gpr_strdup(expected_targets);
- c->is_lb_channel = grpc_core::FindTargetAuthorityTableInArgs(args) != nullptr;
- const grpc_arg* target_name_override_arg =
- grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG);
- if (target_name_override_arg != nullptr) {
- c->target_name_override =
- gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg));
+ int cmp(const grpc_security_connector* other) const override {
+ return server_security_connector_cmp(
+ static_cast<const grpc_server_security_connector*>(other));
}
- return &c->base;
+};
+} // namespace
+
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_fake_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target, const grpc_channel_args* args) {
+ return grpc_core::MakeRefCounted<grpc_fake_channel_security_connector>(
+ std::move(channel_creds), std::move(request_metadata_creds), target,
+ args);
}
-grpc_server_security_connector* grpc_fake_server_security_connector_create(
- grpc_server_credentials* server_creds) {
- grpc_server_security_connector* c =
- static_cast<grpc_server_security_connector*>(
- gpr_zalloc(sizeof(grpc_server_security_connector)));
- gpr_ref_init(&c->base.refcount, 1);
- c->base.vtable = &fake_server_vtable;
- c->base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME;
- c->server_creds = server_creds;
- c->add_handshakers = fake_server_add_handshakers;
- return c;
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_fake_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
+ return grpc_core::MakeRefCounted<grpc_fake_server_security_connector>(
+ std::move(server_creds));
}
diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.h b/src/core/lib/security/security_connector/fake/fake_security_connector.h
index fdfe048c6e..344a2349a4 100644
--- a/src/core/lib/security/security_connector/fake/fake_security_connector.h
+++ b/src/core/lib/security/security_connector/fake/fake_security_connector.h
@@ -24,19 +24,22 @@
#include <grpc/grpc_security.h>
#include "src/core/lib/channel/handshaker.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/security_connector/security_connector.h"
#define GRPC_FAKE_SECURITY_URL_SCHEME "http+fake_security"
/* Creates a fake connector that emulates real channel security. */
-grpc_channel_security_connector* grpc_fake_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds, const char* target,
- const grpc_channel_args* args);
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_fake_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target, const grpc_channel_args* args);
/* Creates a fake connector that emulates real server security. */
-grpc_server_security_connector* grpc_fake_server_security_connector_create(
- grpc_server_credentials* server_creds);
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_fake_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_FAKE_FAKE_SECURITY_CONNECTOR_H \
*/
diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc
index 008a98df28..7cc482c16c 100644
--- a/src/core/lib/security/security_connector/local/local_security_connector.cc
+++ b/src/core/lib/security/security_connector/local/local_security_connector.cc
@@ -30,217 +30,224 @@
#include "src/core/ext/filters/client_channel/client_channel.h"
#include "src/core/lib/channel/channel_args.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/pollset.h"
+#include "src/core/lib/iomgr/resolve_address.h"
+#include "src/core/lib/iomgr/sockaddr.h"
+#include "src/core/lib/iomgr/sockaddr_utils.h"
+#include "src/core/lib/iomgr/socket_utils.h"
+#include "src/core/lib/iomgr/unix_sockets_posix.h"
#include "src/core/lib/security/credentials/local/local_credentials.h"
#include "src/core/lib/security/transport/security_handshaker.h"
#include "src/core/tsi/local_transport_security.h"
#define GRPC_UDS_URI_PATTERN "unix:"
-#define GRPC_UDS_URL_SCHEME "unix"
#define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local"
-typedef struct {
- grpc_channel_security_connector base;
- char* target_name;
-} grpc_local_channel_security_connector;
+namespace {
-typedef struct {
- grpc_server_security_connector base;
-} grpc_local_server_security_connector;
-
-static void local_channel_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_local_channel_security_connector*>(sc);
- grpc_call_credentials_unref(c->base.request_metadata_creds);
- grpc_channel_credentials_unref(c->base.channel_creds);
- gpr_free(c->target_name);
- gpr_free(sc);
-}
-
-static void local_server_destroy(grpc_security_connector* sc) {
- if (sc == nullptr) {
- return;
- }
- auto c = reinterpret_cast<grpc_local_server_security_connector*>(sc);
- grpc_server_credentials_unref(c->base.server_creds);
- gpr_free(sc);
-}
-
-static void local_channel_add_handshakers(
- grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
- TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static void local_server_add_handshakers(
- grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_manager) {
- tsi_handshaker* handshaker = nullptr;
- GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) ==
- TSI_OK);
- grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create(
- handshaker, &sc->base));
-}
-
-static int local_channel_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_local_channel_security_connector* c1 =
- reinterpret_cast<grpc_local_channel_security_connector*>(sc1);
- grpc_local_channel_security_connector* c2 =
- reinterpret_cast<grpc_local_channel_security_connector*>(sc2);
- int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
- if (c != 0) return c;
- return strcmp(c1->target_name, c2->target_name);
-}
-
-static int local_server_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_local_server_security_connector* c1 =
- reinterpret_cast<grpc_local_server_security_connector*>(sc1);
- grpc_local_server_security_connector* c2 =
- reinterpret_cast<grpc_local_server_security_connector*>(sc2);
- return grpc_server_security_connector_cmp(&c1->base, &c2->base);
-}
-
-static grpc_security_status local_auth_context_create(grpc_auth_context** ctx) {
- if (ctx == nullptr) {
- gpr_log(GPR_ERROR, "Invalid arguments to local_auth_context_create()");
- return GRPC_SECURITY_ERROR;
- }
+grpc_core::RefCountedPtr<grpc_auth_context> local_auth_context_create() {
/* Create auth context. */
- *ctx = grpc_auth_context_create(nullptr);
+ grpc_core::RefCountedPtr<grpc_auth_context> ctx =
+ grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
- *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
+ ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_LOCAL_TRANSPORT_SECURITY_TYPE);
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
- *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1);
- return GRPC_SECURITY_OK;
+ ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1);
+ return ctx;
}
-static void local_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_security_status status;
+void local_check_peer(grpc_security_connector* sc, tsi_peer peer,
+ grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked,
+ grpc_local_connect_type type) {
+ int fd = grpc_endpoint_get_fd(ep);
+ grpc_resolved_address resolved_addr;
+ memset(&resolved_addr, 0, sizeof(resolved_addr));
+ resolved_addr.len = GRPC_MAX_SOCKADDR_SIZE;
+ bool is_endpoint_local = false;
+ if (getsockname(fd, reinterpret_cast<grpc_sockaddr*>(resolved_addr.addr),
+ &resolved_addr.len) == 0) {
+ grpc_resolved_address addr_normalized;
+ grpc_resolved_address* addr =
+ grpc_sockaddr_is_v4mapped(&resolved_addr, &addr_normalized)
+ ? &addr_normalized
+ : &resolved_addr;
+ grpc_sockaddr* sock_addr = reinterpret_cast<grpc_sockaddr*>(&addr->addr);
+ // UDS
+ if (type == UDS && grpc_is_unix_socket(addr)) {
+ is_endpoint_local = true;
+ // IPV4
+ } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET) {
+ const grpc_sockaddr_in* addr4 =
+ reinterpret_cast<const grpc_sockaddr_in*>(sock_addr);
+ if (grpc_htonl(addr4->sin_addr.s_addr) == INADDR_LOOPBACK) {
+ is_endpoint_local = true;
+ }
+ // IPv6
+ } else if (type == LOCAL_TCP && sock_addr->sa_family == GRPC_AF_INET6) {
+ const grpc_sockaddr_in6* addr6 =
+ reinterpret_cast<const grpc_sockaddr_in6*>(addr);
+ if (memcmp(&addr6->sin6_addr, &in6addr_loopback,
+ sizeof(in6addr_loopback)) == 0) {
+ is_endpoint_local = true;
+ }
+ }
+ }
+ grpc_error* error = GRPC_ERROR_NONE;
+ if (!is_endpoint_local) {
+ error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "Endpoint is neither UDS or TCP loopback address.");
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
+ return;
+ }
/* Create an auth context which is necessary to pass the santiy check in
* {client, server}_auth_filter that verifies if the peer's auth context is
* obtained during handshakes. The auth context is only checked for its
* existence and not actually used.
*/
- status = local_auth_context_create(auth_context);
- grpc_error* error = status == GRPC_SECURITY_OK
- ? GRPC_ERROR_NONE
- : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "Could not create local auth context");
+ *auth_context = local_auth_context_create();
+ error = *auth_context != nullptr ? GRPC_ERROR_NONE
+ : GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "Could not create local auth context");
GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
-static grpc_security_connector_vtable local_channel_vtable = {
- local_channel_destroy, local_check_peer, local_channel_cmp};
-
-static grpc_security_connector_vtable local_server_vtable = {
- local_server_destroy, local_check_peer, local_server_cmp};
-
-static bool local_check_call_host(grpc_channel_security_connector* sc,
- const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error) {
- grpc_local_channel_security_connector* local_sc =
- reinterpret_cast<grpc_local_channel_security_connector*>(sc);
- if (host == nullptr || local_sc == nullptr ||
- strcmp(host, local_sc->target_name) != 0) {
- *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "local call host does not match target name");
+class grpc_local_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ grpc_local_channel_security_connector(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const char* target_name)
+ : grpc_channel_security_connector(nullptr, std::move(channel_creds),
+ std::move(request_metadata_creds)),
+ target_name_(gpr_strdup(target_name)) {}
+
+ ~grpc_local_channel_security_connector() override { gpr_free(target_name_); }
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) ==
+ TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
}
- return true;
-}
-static void local_cancel_check_call_host(grpc_channel_security_connector* sc,
- grpc_closure* on_call_host_checked,
- grpc_error* error) {
- GRPC_ERROR_UNREF(error);
-}
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_local_channel_security_connector*>(
+ other_sc);
+ int c = channel_security_connector_cmp(other);
+ if (c != 0) return c;
+ return strcmp(target_name_, other->target_name_);
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ grpc_local_credentials* creds =
+ reinterpret_cast<grpc_local_credentials*>(mutable_channel_creds());
+ local_check_peer(this, peer, ep, auth_context, on_peer_checked,
+ creds->connect_type());
+ }
+
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ if (host == nullptr || strcmp(host, target_name_) != 0) {
+ *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "local call host does not match target name");
+ }
+ return true;
+ }
+
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
+ }
-grpc_security_status grpc_local_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds,
- const grpc_channel_args* args, const char* target_name,
- grpc_channel_security_connector** sc) {
- if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) {
+ const char* target_name() const { return target_name_; }
+
+ private:
+ char* target_name_;
+};
+
+class grpc_local_server_security_connector final
+ : public grpc_server_security_connector {
+ public:
+ grpc_local_server_security_connector(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_server_security_connector(nullptr, std::move(server_creds)) {}
+ ~grpc_local_server_security_connector() override = default;
+
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_manager) override {
+ tsi_handshaker* handshaker = nullptr;
+ GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */,
+ &handshaker) == TSI_OK);
+ grpc_handshake_manager_add(
+ handshake_manager, grpc_security_handshaker_create(handshaker, this));
+ }
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ grpc_local_server_credentials* creds =
+ static_cast<grpc_local_server_credentials*>(mutable_server_creds());
+ local_check_peer(this, peer, ep, auth_context, on_peer_checked,
+ creds->connect_type());
+ }
+
+ int cmp(const grpc_security_connector* other) const override {
+ return server_security_connector_cmp(
+ static_cast<const grpc_server_security_connector*>(other));
+ }
+};
+} // namespace
+
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_local_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const grpc_channel_args* args, const char* target_name) {
+ if (channel_creds == nullptr || target_name == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_local_channel_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- // Check if local_connect_type is UDS. Only UDS is supported for now.
+ // Perform sanity check on UDS address. For TCP local connection, the check
+ // will be done during check_peer procedure.
grpc_local_credentials* creds =
- reinterpret_cast<grpc_local_credentials*>(channel_creds);
- if (creds->connect_type != UDS) {
- gpr_log(GPR_ERROR,
- "Invalid local channel type to "
- "grpc_local_channel_security_connector_create()");
- return GRPC_SECURITY_ERROR;
- }
- // Check if target_name is a valid UDS address.
+ static_cast<grpc_local_credentials*>(channel_creds.get());
const grpc_arg* server_uri_arg =
grpc_channel_args_find(args, GRPC_ARG_SERVER_URI);
const char* server_uri_str = grpc_channel_arg_get_string(server_uri_arg);
- if (strncmp(GRPC_UDS_URI_PATTERN, server_uri_str,
+ if (creds->connect_type() == UDS &&
+ strncmp(GRPC_UDS_URI_PATTERN, server_uri_str,
strlen(GRPC_UDS_URI_PATTERN)) != 0) {
gpr_log(GPR_ERROR,
- "Invalid target_name to "
+ "Invalid UDS target name to "
"grpc_local_channel_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_local_channel_security_connector*>(
- gpr_zalloc(sizeof(grpc_local_channel_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &local_channel_vtable;
- c->base.add_handshakers = local_channel_add_handshakers;
- c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
- c->base.request_metadata_creds =
- grpc_call_credentials_ref(request_metadata_creds);
- c->base.check_call_host = local_check_call_host;
- c->base.cancel_check_call_host = local_cancel_check_call_host;
- c->base.base.url_scheme =
- creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr;
- c->target_name = gpr_strdup(target_name);
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_local_channel_security_connector>(
+ channel_creds, request_metadata_creds, target_name);
}
-grpc_security_status grpc_local_server_security_connector_create(
- grpc_server_credentials* server_creds,
- grpc_server_security_connector** sc) {
- if (server_creds == nullptr || sc == nullptr) {
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_local_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds) {
+ if (server_creds == nullptr) {
gpr_log(
GPR_ERROR,
"Invalid arguments to grpc_local_server_security_connector_create()");
- return GRPC_SECURITY_ERROR;
- }
- // Check if local_connect_type is UDS. Only UDS is supported for now.
- grpc_local_server_credentials* creds =
- reinterpret_cast<grpc_local_server_credentials*>(server_creds);
- if (creds->connect_type != UDS) {
- gpr_log(GPR_ERROR,
- "Invalid local server type to "
- "grpc_local_server_security_connector_create()");
- return GRPC_SECURITY_ERROR;
+ return nullptr;
}
- auto c = static_cast<grpc_local_server_security_connector*>(
- gpr_zalloc(sizeof(grpc_local_server_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &local_server_vtable;
- c->base.server_creds = grpc_server_credentials_ref(server_creds);
- c->base.base.url_scheme =
- creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr;
- c->base.add_handshakers = local_server_add_handshakers;
- *sc = &c->base;
- return GRPC_SECURITY_OK;
+ return grpc_core::MakeRefCounted<grpc_local_server_security_connector>(
+ std::move(server_creds));
}
diff --git a/src/core/lib/security/security_connector/local/local_security_connector.h b/src/core/lib/security/security_connector/local/local_security_connector.h
index 5369a2127a..6eee0ca9a6 100644
--- a/src/core/lib/security/security_connector/local/local_security_connector.h
+++ b/src/core/lib/security/security_connector/local/local_security_connector.h
@@ -34,13 +34,13 @@
* - sc: address of local channel security connector instance to be returned
* from the method.
*
- * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure.
+ * It returns nullptr on failure.
*/
-grpc_security_status grpc_local_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds,
- const grpc_channel_args* args, const char* target_name,
- grpc_channel_security_connector** sc);
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_local_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const grpc_channel_args* args, const char* target_name);
/**
* This method creates a local server security connector.
@@ -49,10 +49,11 @@ grpc_security_status grpc_local_channel_security_connector_create(
* - sc: address of local server security connector instance to be returned from
* the method.
*
- * It returns GRPC_SECURITY_OK on success, and an error status code on failure.
+ * It returns nullptr on failure.
*/
-grpc_security_status grpc_local_server_security_connector_create(
- grpc_server_credentials* server_creds, grpc_server_security_connector** sc);
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_local_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_LOCAL_LOCAL_SECURITY_CONNECTOR_H \
*/
diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc
index 02cecb0eb1..96a1960546 100644
--- a/src/core/lib/security/security_connector/security_connector.cc
+++ b/src/core/lib/security/security_connector/security_connector.cc
@@ -35,150 +35,67 @@
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/security_connector/load_system_roots.h"
+#include "src/core/lib/security/security_connector/security_connector.h"
#include "src/core/lib/security/transport/security_handshaker.h"
grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount(
false, "security_connector_refcount");
-void grpc_channel_security_connector_add_handshakers(
- grpc_channel_security_connector* connector,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- if (connector != nullptr) {
- connector->add_handshakers(connector, interested_parties, handshake_mgr);
- }
-}
-
-void grpc_server_security_connector_add_handshakers(
- grpc_server_security_connector* connector,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- if (connector != nullptr) {
- connector->add_handshakers(connector, interested_parties, handshake_mgr);
- }
-}
-
-void grpc_security_connector_check_peer(grpc_security_connector* sc,
- tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- if (sc == nullptr) {
- GRPC_CLOSURE_SCHED(on_peer_checked,
- GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "cannot check peer -- no security connector"));
- tsi_peer_destruct(&peer);
- } else {
- sc->vtable->check_peer(sc, peer, auth_context, on_peer_checked);
- }
-}
-
-int grpc_security_connector_cmp(grpc_security_connector* sc,
- grpc_security_connector* other) {
+grpc_server_security_connector::grpc_server_security_connector(
+ const char* url_scheme,
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_security_connector(url_scheme),
+ server_creds_(std::move(server_creds)) {}
+
+grpc_channel_security_connector::grpc_channel_security_connector(
+ const char* url_scheme,
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds)
+ : grpc_security_connector(url_scheme),
+ channel_creds_(std::move(channel_creds)),
+ request_metadata_creds_(std::move(request_metadata_creds)) {}
+grpc_channel_security_connector::~grpc_channel_security_connector() {}
+
+int grpc_security_connector_cmp(const grpc_security_connector* sc,
+ const grpc_security_connector* other) {
if (sc == nullptr || other == nullptr) return GPR_ICMP(sc, other);
- int c = GPR_ICMP(sc->vtable, other->vtable);
- if (c != 0) return c;
- return sc->vtable->cmp(sc, other);
+ return sc->cmp(other);
}
-int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1,
- grpc_channel_security_connector* sc2) {
- GPR_ASSERT(sc1->channel_creds != nullptr);
- GPR_ASSERT(sc2->channel_creds != nullptr);
- int c = GPR_ICMP(sc1->channel_creds, sc2->channel_creds);
- if (c != 0) return c;
- c = GPR_ICMP(sc1->request_metadata_creds, sc2->request_metadata_creds);
- if (c != 0) return c;
- c = GPR_ICMP((void*)sc1->check_call_host, (void*)sc2->check_call_host);
- if (c != 0) return c;
- c = GPR_ICMP((void*)sc1->cancel_check_call_host,
- (void*)sc2->cancel_check_call_host);
+int grpc_channel_security_connector::channel_security_connector_cmp(
+ const grpc_channel_security_connector* other) const {
+ const grpc_channel_security_connector* other_sc =
+ static_cast<const grpc_channel_security_connector*>(other);
+ GPR_ASSERT(channel_creds() != nullptr);
+ GPR_ASSERT(other_sc->channel_creds() != nullptr);
+ int c = GPR_ICMP(channel_creds(), other_sc->channel_creds());
if (c != 0) return c;
- return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers);
+ return GPR_ICMP(request_metadata_creds(), other_sc->request_metadata_creds());
}
-int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1,
- grpc_server_security_connector* sc2) {
- GPR_ASSERT(sc1->server_creds != nullptr);
- GPR_ASSERT(sc2->server_creds != nullptr);
- int c = GPR_ICMP(sc1->server_creds, sc2->server_creds);
- if (c != 0) return c;
- return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers);
-}
-
-bool grpc_channel_security_connector_check_call_host(
- grpc_channel_security_connector* sc, const char* host,
- grpc_auth_context* auth_context, grpc_closure* on_call_host_checked,
- grpc_error** error) {
- if (sc == nullptr || sc->check_call_host == nullptr) {
- *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "cannot check call host -- no security connector");
- return true;
- }
- return sc->check_call_host(sc, host, auth_context, on_call_host_checked,
- error);
-}
-
-void grpc_channel_security_connector_cancel_check_call_host(
- grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
- grpc_error* error) {
- if (sc == nullptr || sc->cancel_check_call_host == nullptr) {
- GRPC_ERROR_UNREF(error);
- return;
- }
- sc->cancel_check_call_host(sc, on_call_host_checked, error);
-}
-
-#ifndef NDEBUG
-grpc_security_connector* grpc_security_connector_ref(
- grpc_security_connector* sc, const char* file, int line,
- const char* reason) {
- if (sc == nullptr) return nullptr;
- if (grpc_trace_security_connector_refcount.enabled()) {
- gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count);
- gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
- "SECURITY_CONNECTOR:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", sc,
- val, val + 1, reason);
- }
-#else
-grpc_security_connector* grpc_security_connector_ref(
- grpc_security_connector* sc) {
- if (sc == nullptr) return nullptr;
-#endif
- gpr_ref(&sc->refcount);
- return sc;
-}
-
-#ifndef NDEBUG
-void grpc_security_connector_unref(grpc_security_connector* sc,
- const char* file, int line,
- const char* reason) {
- if (sc == nullptr) return;
- if (grpc_trace_security_connector_refcount.enabled()) {
- gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count);
- gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG,
- "SECURITY_CONNECTOR:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", sc,
- val, val - 1, reason);
- }
-#else
-void grpc_security_connector_unref(grpc_security_connector* sc) {
- if (sc == nullptr) return;
-#endif
- if (gpr_unref(&sc->refcount)) sc->vtable->destroy(sc);
+int grpc_server_security_connector::server_security_connector_cmp(
+ const grpc_server_security_connector* other) const {
+ const grpc_server_security_connector* other_sc =
+ static_cast<const grpc_server_security_connector*>(other);
+ GPR_ASSERT(server_creds() != nullptr);
+ GPR_ASSERT(other_sc->server_creds() != nullptr);
+ return GPR_ICMP(server_creds(), other_sc->server_creds());
}
static void connector_arg_destroy(void* p) {
- GRPC_SECURITY_CONNECTOR_UNREF((grpc_security_connector*)p,
- "connector_arg_destroy");
+ static_cast<grpc_security_connector*>(p)->Unref(DEBUG_LOCATION,
+ "connector_arg_destroy");
}
static void* connector_arg_copy(void* p) {
- return GRPC_SECURITY_CONNECTOR_REF((grpc_security_connector*)p,
- "connector_arg_copy");
+ return static_cast<grpc_security_connector*>(p)
+ ->Ref(DEBUG_LOCATION, "connector_arg_copy")
+ .release();
}
static int connector_cmp(void* a, void* b) {
- return grpc_security_connector_cmp(static_cast<grpc_security_connector*>(a),
- static_cast<grpc_security_connector*>(b));
+ return static_cast<grpc_security_connector*>(a)->cmp(
+ static_cast<grpc_security_connector*>(b));
}
static const grpc_arg_pointer_vtable connector_arg_vtable = {
diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h
index 4c921a8793..74b0ef21a6 100644
--- a/src/core/lib/security/security_connector/security_connector.h
+++ b/src/core/lib/security/security_connector/security_connector.h
@@ -26,6 +26,7 @@
#include <grpc/grpc_security.h>
#include "src/core/lib/channel/handshaker.h"
+#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/iomgr/endpoint.h"
#include "src/core/lib/iomgr/pollset.h"
#include "src/core/lib/iomgr/tcp_server.h"
@@ -34,8 +35,6 @@
extern grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount;
-/* --- status enum. --- */
-
typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status;
/* --- security_connector object. ---
@@ -43,54 +42,34 @@ typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status;
A security connector object represents away to configure the underlying
transport security mechanism and check the resulting trusted peer. */
-typedef struct grpc_security_connector grpc_security_connector;
-
#define GRPC_ARG_SECURITY_CONNECTOR "grpc.security_connector"
-typedef struct {
- void (*destroy)(grpc_security_connector* sc);
- void (*check_peer)(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked);
- int (*cmp)(grpc_security_connector* sc, grpc_security_connector* other);
-} grpc_security_connector_vtable;
-
-struct grpc_security_connector {
- const grpc_security_connector_vtable* vtable;
- gpr_refcount refcount;
- const char* url_scheme;
-};
+class grpc_security_connector
+ : public grpc_core::RefCounted<grpc_security_connector> {
+ public:
+ explicit grpc_security_connector(const char* url_scheme)
+ : grpc_core::RefCounted<grpc_security_connector>(
+ &grpc_trace_security_connector_refcount),
+ url_scheme_(url_scheme) {}
+ virtual ~grpc_security_connector() = default;
+
+ /* Check the peer. Callee takes ownership of the peer object.
+ When done, sets *auth_context and invokes on_peer_checked. */
+ virtual void check_peer(
+ tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) GRPC_ABSTRACT;
+
+ /* Compares two security connectors. */
+ virtual int cmp(const grpc_security_connector* other) const GRPC_ABSTRACT;
+
+ const char* url_scheme() const { return url_scheme_; }
-/* Refcounting. */
-#ifndef NDEBUG
-#define GRPC_SECURITY_CONNECTOR_REF(p, r) \
- grpc_security_connector_ref((p), __FILE__, __LINE__, (r))
-#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) \
- grpc_security_connector_unref((p), __FILE__, __LINE__, (r))
-grpc_security_connector* grpc_security_connector_ref(
- grpc_security_connector* policy, const char* file, int line,
- const char* reason);
-void grpc_security_connector_unref(grpc_security_connector* policy,
- const char* file, int line,
- const char* reason);
-#else
-#define GRPC_SECURITY_CONNECTOR_REF(p, r) grpc_security_connector_ref((p))
-#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) grpc_security_connector_unref((p))
-grpc_security_connector* grpc_security_connector_ref(
- grpc_security_connector* policy);
-void grpc_security_connector_unref(grpc_security_connector* policy);
-#endif
-
-/* Check the peer. Callee takes ownership of the peer object.
- When done, sets *auth_context and invokes on_peer_checked. */
-void grpc_security_connector_check_peer(grpc_security_connector* sc,
- tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked);
-
-/* Compares two security connectors. */
-int grpc_security_connector_cmp(grpc_security_connector* sc,
- grpc_security_connector* other);
+ GRPC_ABSTRACT_BASE_CLASS
+
+ private:
+ const char* url_scheme_;
+};
/* Util to encapsulate the connector in a channel arg. */
grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc);
@@ -107,71 +86,89 @@ grpc_security_connector* grpc_security_connector_find_in_args(
A channel security connector object represents a way to configure the
underlying transport security mechanism on the client side. */
-typedef struct grpc_channel_security_connector grpc_channel_security_connector;
-
-struct grpc_channel_security_connector {
- grpc_security_connector base;
- grpc_channel_credentials* channel_creds;
- grpc_call_credentials* request_metadata_creds;
- bool (*check_call_host)(grpc_channel_security_connector* sc, const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error);
- void (*cancel_check_call_host)(grpc_channel_security_connector* sc,
- grpc_closure* on_call_host_checked,
- grpc_error* error);
- void (*add_handshakers)(grpc_channel_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr);
+class grpc_channel_security_connector : public grpc_security_connector {
+ public:
+ grpc_channel_security_connector(
+ const char* url_scheme,
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds);
+ ~grpc_channel_security_connector() override;
+
+ /// Checks that the host that will be set for a call is acceptable.
+ /// Returns true if completed synchronously, in which case \a error will
+ /// be set to indicate the result. Otherwise, \a on_call_host_checked
+ /// will be invoked when complete.
+ virtual bool check_call_host(const char* host,
+ grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) GRPC_ABSTRACT;
+ /// Cancels a pending asychronous call to
+ /// grpc_channel_security_connector_check_call_host() with
+ /// \a on_call_host_checked as its callback.
+ virtual void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) GRPC_ABSTRACT;
+ /// Registers handshakers with \a handshake_mgr.
+ virtual void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr)
+ GRPC_ABSTRACT;
+
+ const grpc_channel_credentials* channel_creds() const {
+ return channel_creds_.get();
+ }
+ grpc_channel_credentials* mutable_channel_creds() {
+ return channel_creds_.get();
+ }
+ const grpc_call_credentials* request_metadata_creds() const {
+ return request_metadata_creds_.get();
+ }
+ grpc_call_credentials* mutable_request_metadata_creds() {
+ return request_metadata_creds_.get();
+ }
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ protected:
+ // Helper methods to be used in subclasses.
+ int channel_security_connector_cmp(
+ const grpc_channel_security_connector* other) const;
+
+ private:
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds_;
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds_;
};
-/// A helper function for use in grpc_security_connector_cmp() implementations.
-int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1,
- grpc_channel_security_connector* sc2);
-
-/// Checks that the host that will be set for a call is acceptable.
-/// Returns true if completed synchronously, in which case \a error will
-/// be set to indicate the result. Otherwise, \a on_call_host_checked
-/// will be invoked when complete.
-bool grpc_channel_security_connector_check_call_host(
- grpc_channel_security_connector* sc, const char* host,
- grpc_auth_context* auth_context, grpc_closure* on_call_host_checked,
- grpc_error** error);
-
-/// Cancels a pending asychronous call to
-/// grpc_channel_security_connector_check_call_host() with
-/// \a on_call_host_checked as its callback.
-void grpc_channel_security_connector_cancel_check_call_host(
- grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
- grpc_error* error);
-
-/* Registers handshakers with \a handshake_mgr. */
-void grpc_channel_security_connector_add_handshakers(
- grpc_channel_security_connector* connector,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr);
-
/* --- server_security_connector object. ---
A server security connector object represents a way to configure the
underlying transport security mechanism on the server side. */
-typedef struct grpc_server_security_connector grpc_server_security_connector;
-
-struct grpc_server_security_connector {
- grpc_security_connector base;
- grpc_server_credentials* server_creds;
- void (*add_handshakers)(grpc_server_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr);
+class grpc_server_security_connector : public grpc_security_connector {
+ public:
+ grpc_server_security_connector(
+ const char* url_scheme,
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds);
+ ~grpc_server_security_connector() override = default;
+
+ virtual void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr)
+ GRPC_ABSTRACT;
+
+ const grpc_server_credentials* server_creds() const {
+ return server_creds_.get();
+ }
+ grpc_server_credentials* mutable_server_creds() {
+ return server_creds_.get();
+ }
+
+ GRPC_ABSTRACT_BASE_CLASS
+
+ protected:
+ // Helper methods to be used in subclasses.
+ int server_security_connector_cmp(
+ const grpc_server_security_connector* other) const;
+
+ private:
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds_;
};
-/// A helper function for use in grpc_security_connector_cmp() implementations.
-int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1,
- grpc_server_security_connector* sc2);
-
-void grpc_server_security_connector_add_handshakers(
- grpc_server_security_connector* sc, grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr);
-
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */
diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
index 20a9533dd1..7414ab1a37 100644
--- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
+++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc
@@ -30,6 +30,7 @@
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/credentials/credentials.h"
#include "src/core/lib/security/credentials/ssl/ssl_credentials.h"
@@ -39,172 +40,10 @@
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security.h"
-typedef struct {
- grpc_channel_security_connector base;
- tsi_ssl_client_handshaker_factory* client_handshaker_factory;
- char* target_name;
- char* overridden_target_name;
- const verify_peer_options* verify_options;
-} grpc_ssl_channel_security_connector;
-
-typedef struct {
- grpc_server_security_connector base;
- tsi_ssl_server_handshaker_factory* server_handshaker_factory;
-} grpc_ssl_server_security_connector;
-
-static bool server_connector_has_cert_config_fetcher(
- grpc_ssl_server_security_connector* c) {
- GPR_ASSERT(c != nullptr);
- grpc_ssl_server_credentials* server_creds =
- reinterpret_cast<grpc_ssl_server_credentials*>(c->base.server_creds);
- GPR_ASSERT(server_creds != nullptr);
- return server_creds->certificate_config_fetcher.cb != nullptr;
-}
-
-static void ssl_channel_destroy(grpc_security_connector* sc) {
- grpc_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
- grpc_channel_credentials_unref(c->base.channel_creds);
- grpc_call_credentials_unref(c->base.request_metadata_creds);
- tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory);
- c->client_handshaker_factory = nullptr;
- if (c->target_name != nullptr) gpr_free(c->target_name);
- if (c->overridden_target_name != nullptr) gpr_free(c->overridden_target_name);
- gpr_free(sc);
-}
-
-static void ssl_server_destroy(grpc_security_connector* sc) {
- grpc_ssl_server_security_connector* c =
- reinterpret_cast<grpc_ssl_server_security_connector*>(sc);
- grpc_server_credentials_unref(c->base.server_creds);
- tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory);
- c->server_handshaker_factory = nullptr;
- gpr_free(sc);
-}
-
-static void ssl_channel_add_handshakers(grpc_channel_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- grpc_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
- // Instantiate TSI handshaker.
- tsi_handshaker* tsi_hs = nullptr;
- tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
- c->client_handshaker_factory,
- c->overridden_target_name != nullptr ? c->overridden_target_name
- : c->target_name,
- &tsi_hs);
- if (result != TSI_OK) {
- gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
- tsi_result_to_string(result));
- return;
- }
- // Create handshakers.
- grpc_handshake_manager_add(
- handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base));
-}
-
-/* Attempts to replace the server_handshaker_factory with a new factory using
- * the provided grpc_ssl_server_certificate_config. Should new factory creation
- * fail, the existing factory will not be replaced. Returns true on success (new
- * factory created). */
-static bool try_replace_server_handshaker_factory(
- grpc_ssl_server_security_connector* sc,
- const grpc_ssl_server_certificate_config* config) {
- if (config == nullptr) {
- gpr_log(GPR_ERROR,
- "Server certificate config callback returned invalid (NULL) "
- "config.");
- return false;
- }
- gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config);
-
- size_t num_alpn_protocols = 0;
- const char** alpn_protocol_strings =
- grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
- tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
- config->pem_key_cert_pairs, config->num_key_cert_pairs);
- tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr;
- grpc_ssl_server_credentials* server_creds =
- reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds);
- tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
- cert_pairs, config->num_key_cert_pairs, config->pem_root_certs,
- grpc_get_tsi_client_certificate_request_type(
- server_creds->config.client_certificate_request),
- grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
- static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory);
- gpr_free(cert_pairs);
- gpr_free((void*)alpn_protocol_strings);
-
- if (result != TSI_OK) {
- gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
- tsi_result_to_string(result));
- return false;
- }
- tsi_ssl_server_handshaker_factory_unref(sc->server_handshaker_factory);
- sc->server_handshaker_factory = new_handshaker_factory;
- return true;
-}
-
-/* Attempts to fetch the server certificate config if a callback is available.
- * Current certificate config will continue to be used if the callback returns
- * an error. Returns true if new credentials were sucessfully loaded. */
-static bool try_fetch_ssl_server_credentials(
- grpc_ssl_server_security_connector* sc) {
- grpc_ssl_server_certificate_config* certificate_config = nullptr;
- bool status;
-
- GPR_ASSERT(sc != nullptr);
- if (!server_connector_has_cert_config_fetcher(sc)) return false;
-
- grpc_ssl_server_credentials* server_creds =
- reinterpret_cast<grpc_ssl_server_credentials*>(sc->base.server_creds);
- grpc_ssl_certificate_config_reload_status cb_result =
- server_creds->certificate_config_fetcher.cb(
- server_creds->certificate_config_fetcher.user_data,
- &certificate_config);
- if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) {
- gpr_log(GPR_DEBUG, "No change in SSL server credentials.");
- status = false;
- } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) {
- status = try_replace_server_handshaker_factory(sc, certificate_config);
- } else {
- // Log error, continue using previously-loaded credentials.
- gpr_log(GPR_ERROR,
- "Failed fetching new server credentials, continuing to "
- "use previously-loaded credentials.");
- status = false;
- }
-
- if (certificate_config != nullptr) {
- grpc_ssl_server_certificate_config_destroy(certificate_config);
- }
- return status;
-}
-
-static void ssl_server_add_handshakers(grpc_server_security_connector* sc,
- grpc_pollset_set* interested_parties,
- grpc_handshake_manager* handshake_mgr) {
- grpc_ssl_server_security_connector* c =
- reinterpret_cast<grpc_ssl_server_security_connector*>(sc);
- // Instantiate TSI handshaker.
- try_fetch_ssl_server_credentials(c);
- tsi_handshaker* tsi_hs = nullptr;
- tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
- c->server_handshaker_factory, &tsi_hs);
- if (result != TSI_OK) {
- gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
- tsi_result_to_string(result));
- return;
- }
- // Create handshakers.
- grpc_handshake_manager_add(
- handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base));
-}
-
-static grpc_error* ssl_check_peer(grpc_security_connector* sc,
- const char* peer_name, const tsi_peer* peer,
- grpc_auth_context** auth_context) {
+namespace {
+grpc_error* ssl_check_peer(
+ const char* peer_name, const tsi_peer* peer,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context) {
#if TSI_OPENSSL_ALPN_SUPPORT
/* Check the ALPN if ALPN is supported. */
const tsi_peer_property* p =
@@ -230,245 +69,384 @@ static grpc_error* ssl_check_peer(grpc_security_connector* sc,
return GRPC_ERROR_NONE;
}
-static void ssl_channel_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
- const char* target_name = c->overridden_target_name != nullptr
- ? c->overridden_target_name
- : c->target_name;
- grpc_error* error = ssl_check_peer(sc, target_name, &peer, auth_context);
- if (error == GRPC_ERROR_NONE &&
- c->verify_options->verify_peer_callback != nullptr) {
- const tsi_peer_property* p =
- tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY);
- if (p == nullptr) {
- error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "Cannot check peer: missing pem cert property.");
- } else {
- char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1));
- memcpy(peer_pem, p->value.data, p->value.length);
- peer_pem[p->value.length] = '\0';
- int callback_status = c->verify_options->verify_peer_callback(
- target_name, peer_pem,
- c->verify_options->verify_peer_callback_userdata);
- gpr_free(peer_pem);
- if (callback_status) {
- char* msg;
- gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)",
- callback_status);
- error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
- gpr_free(msg);
- }
- }
+class grpc_ssl_channel_security_connector final
+ : public grpc_channel_security_connector {
+ public:
+ grpc_ssl_channel_security_connector(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const grpc_ssl_config* config, const char* target_name,
+ const char* overridden_target_name)
+ : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME,
+ std::move(channel_creds),
+ std::move(request_metadata_creds)),
+ overridden_target_name_(overridden_target_name == nullptr
+ ? nullptr
+ : gpr_strdup(overridden_target_name)),
+ verify_options_(&config->verify_options) {
+ char* port;
+ gpr_split_host_port(target_name, &target_name_, &port);
+ gpr_free(port);
}
- GRPC_CLOSURE_SCHED(on_peer_checked, error);
- tsi_peer_destruct(&peer);
-}
-static void ssl_server_check_peer(grpc_security_connector* sc, tsi_peer peer,
- grpc_auth_context** auth_context,
- grpc_closure* on_peer_checked) {
- grpc_error* error = ssl_check_peer(sc, nullptr, &peer, auth_context);
- tsi_peer_destruct(&peer);
- GRPC_CLOSURE_SCHED(on_peer_checked, error);
-}
+ ~grpc_ssl_channel_security_connector() override {
+ tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_);
+ if (target_name_ != nullptr) gpr_free(target_name_);
+ if (overridden_target_name_ != nullptr) gpr_free(overridden_target_name_);
+ }
-static int ssl_channel_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- grpc_ssl_channel_security_connector* c1 =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc1);
- grpc_ssl_channel_security_connector* c2 =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc2);
- int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base);
- if (c != 0) return c;
- c = strcmp(c1->target_name, c2->target_name);
- if (c != 0) return c;
- return (c1->overridden_target_name == nullptr ||
- c2->overridden_target_name == nullptr)
- ? GPR_ICMP(c1->overridden_target_name, c2->overridden_target_name)
- : strcmp(c1->overridden_target_name, c2->overridden_target_name);
-}
+ grpc_security_status InitializeHandshakerFactory(
+ const grpc_ssl_config* config, const char* pem_root_certs,
+ const tsi_ssl_root_certs_store* root_store,
+ tsi_ssl_session_cache* ssl_session_cache) {
+ bool has_key_cert_pair =
+ config->pem_key_cert_pair != nullptr &&
+ config->pem_key_cert_pair->private_key != nullptr &&
+ config->pem_key_cert_pair->cert_chain != nullptr;
+ tsi_ssl_client_handshaker_options options;
+ memset(&options, 0, sizeof(options));
+ GPR_DEBUG_ASSERT(pem_root_certs != nullptr);
+ options.pem_root_certs = pem_root_certs;
+ options.root_store = root_store;
+ options.alpn_protocols =
+ grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
+ if (has_key_cert_pair) {
+ options.pem_key_cert_pair = config->pem_key_cert_pair;
+ }
+ options.cipher_suites = grpc_get_ssl_cipher_suites();
+ options.session_cache = ssl_session_cache;
+ const tsi_result result =
+ tsi_create_ssl_client_handshaker_factory_with_options(
+ &options, &client_handshaker_factory_);
+ gpr_free((void*)options.alpn_protocols);
+ if (result != TSI_OK) {
+ gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
+ tsi_result_to_string(result));
+ return GRPC_SECURITY_ERROR;
+ }
+ return GRPC_SECURITY_OK;
+ }
-static int ssl_server_cmp(grpc_security_connector* sc1,
- grpc_security_connector* sc2) {
- return grpc_server_security_connector_cmp(
- reinterpret_cast<grpc_server_security_connector*>(sc1),
- reinterpret_cast<grpc_server_security_connector*>(sc2));
-}
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr) override {
+ // Instantiate TSI handshaker.
+ tsi_handshaker* tsi_hs = nullptr;
+ tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker(
+ client_handshaker_factory_,
+ overridden_target_name_ != nullptr ? overridden_target_name_
+ : target_name_,
+ &tsi_hs);
+ if (result != TSI_OK) {
+ gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
+ tsi_result_to_string(result));
+ return;
+ }
+ // Create handshakers.
+ grpc_handshake_manager_add(handshake_mgr,
+ grpc_security_handshaker_create(tsi_hs, this));
+ }
-static bool ssl_channel_check_call_host(grpc_channel_security_connector* sc,
- const char* host,
- grpc_auth_context* auth_context,
- grpc_closure* on_call_host_checked,
- grpc_error** error) {
- grpc_ssl_channel_security_connector* c =
- reinterpret_cast<grpc_ssl_channel_security_connector*>(sc);
- grpc_security_status status = GRPC_SECURITY_ERROR;
- tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context);
- if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK;
- /* If the target name was overridden, then the original target_name was
- 'checked' transitively during the previous peer check at the end of the
- handshake. */
- if (c->overridden_target_name != nullptr &&
- strcmp(host, c->target_name) == 0) {
- status = GRPC_SECURITY_OK;
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ const char* target_name = overridden_target_name_ != nullptr
+ ? overridden_target_name_
+ : target_name_;
+ grpc_error* error = ssl_check_peer(target_name, &peer, auth_context);
+ if (error == GRPC_ERROR_NONE &&
+ verify_options_->verify_peer_callback != nullptr) {
+ const tsi_peer_property* p =
+ tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY);
+ if (p == nullptr) {
+ error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "Cannot check peer: missing pem cert property.");
+ } else {
+ char* peer_pem = static_cast<char*>(gpr_malloc(p->value.length + 1));
+ memcpy(peer_pem, p->value.data, p->value.length);
+ peer_pem[p->value.length] = '\0';
+ int callback_status = verify_options_->verify_peer_callback(
+ target_name, peer_pem,
+ verify_options_->verify_peer_callback_userdata);
+ gpr_free(peer_pem);
+ if (callback_status) {
+ char* msg;
+ gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)",
+ callback_status);
+ error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg);
+ gpr_free(msg);
+ }
+ }
+ }
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
+ tsi_peer_destruct(&peer);
}
- if (status != GRPC_SECURITY_OK) {
- *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
- "call host does not match SSL server name");
+
+ int cmp(const grpc_security_connector* other_sc) const override {
+ auto* other =
+ reinterpret_cast<const grpc_ssl_channel_security_connector*>(other_sc);
+ int c = channel_security_connector_cmp(other);
+ if (c != 0) return c;
+ c = strcmp(target_name_, other->target_name_);
+ if (c != 0) return c;
+ return (overridden_target_name_ == nullptr ||
+ other->overridden_target_name_ == nullptr)
+ ? GPR_ICMP(overridden_target_name_,
+ other->overridden_target_name_)
+ : strcmp(overridden_target_name_,
+ other->overridden_target_name_);
}
- grpc_shallow_peer_destruct(&peer);
- return true;
-}
-static void ssl_channel_cancel_check_call_host(
- grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked,
- grpc_error* error) {
- GRPC_ERROR_UNREF(error);
-}
+ bool check_call_host(const char* host, grpc_auth_context* auth_context,
+ grpc_closure* on_call_host_checked,
+ grpc_error** error) override {
+ grpc_security_status status = GRPC_SECURITY_ERROR;
+ tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context);
+ if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK;
+ /* If the target name was overridden, then the original target_name was
+ 'checked' transitively during the previous peer check at the end of the
+ handshake. */
+ if (overridden_target_name_ != nullptr && strcmp(host, target_name_) == 0) {
+ status = GRPC_SECURITY_OK;
+ }
+ if (status != GRPC_SECURITY_OK) {
+ *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+ "call host does not match SSL server name");
+ }
+ grpc_shallow_peer_destruct(&peer);
+ return true;
+ }
-static grpc_security_connector_vtable ssl_channel_vtable = {
- ssl_channel_destroy, ssl_channel_check_peer, ssl_channel_cmp};
+ void cancel_check_call_host(grpc_closure* on_call_host_checked,
+ grpc_error* error) override {
+ GRPC_ERROR_UNREF(error);
+ }
-static grpc_security_connector_vtable ssl_server_vtable = {
- ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp};
+ private:
+ tsi_ssl_client_handshaker_factory* client_handshaker_factory_;
+ char* target_name_;
+ char* overridden_target_name_;
+ const verify_peer_options* verify_options_;
+};
+
+class grpc_ssl_server_security_connector
+ : public grpc_server_security_connector {
+ public:
+ grpc_ssl_server_security_connector(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_creds)
+ : grpc_server_security_connector(GRPC_SSL_URL_SCHEME,
+ std::move(server_creds)) {}
+
+ ~grpc_ssl_server_security_connector() override {
+ tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_);
+ }
-grpc_security_status grpc_ssl_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds,
- const grpc_ssl_config* config, const char* target_name,
- const char* overridden_target_name,
- tsi_ssl_session_cache* ssl_session_cache,
- grpc_channel_security_connector** sc) {
- tsi_result result = TSI_OK;
- grpc_ssl_channel_security_connector* c;
- char* port;
- bool has_key_cert_pair;
- tsi_ssl_client_handshaker_options options;
- memset(&options, 0, sizeof(options));
- options.alpn_protocols =
- grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols);
+ bool has_cert_config_fetcher() const {
+ return static_cast<const grpc_ssl_server_credentials*>(server_creds())
+ ->has_cert_config_fetcher();
+ }
- if (config == nullptr || target_name == nullptr) {
- gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
- goto error;
+ const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const {
+ return server_handshaker_factory_;
}
- if (config->pem_root_certs == nullptr) {
- // Use default root certificates.
- options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
- options.root_store = grpc_core::DefaultSslRootStore::GetRootStore();
- if (options.pem_root_certs == nullptr) {
- gpr_log(GPR_ERROR, "Could not get default pem root certs.");
- goto error;
+
+ grpc_security_status InitializeHandshakerFactory() {
+ if (has_cert_config_fetcher()) {
+ // Load initial credentials from certificate_config_fetcher:
+ if (!try_fetch_ssl_server_credentials()) {
+ gpr_log(GPR_ERROR,
+ "Failed loading SSL server credentials from fetcher.");
+ return GRPC_SECURITY_ERROR;
+ }
+ } else {
+ auto* server_credentials =
+ static_cast<const grpc_ssl_server_credentials*>(server_creds());
+ size_t num_alpn_protocols = 0;
+ const char** alpn_protocol_strings =
+ grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
+ const tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
+ server_credentials->config().pem_key_cert_pairs,
+ server_credentials->config().num_key_cert_pairs,
+ server_credentials->config().pem_root_certs,
+ grpc_get_tsi_client_certificate_request_type(
+ server_credentials->config().client_certificate_request),
+ grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
+ static_cast<uint16_t>(num_alpn_protocols),
+ &server_handshaker_factory_);
+ gpr_free((void*)alpn_protocol_strings);
+ if (result != TSI_OK) {
+ gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
+ tsi_result_to_string(result));
+ return GRPC_SECURITY_ERROR;
+ }
}
- } else {
- options.pem_root_certs = config->pem_root_certs;
- }
- c = static_cast<grpc_ssl_channel_security_connector*>(
- gpr_zalloc(sizeof(grpc_ssl_channel_security_connector)));
-
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.vtable = &ssl_channel_vtable;
- c->base.base.url_scheme = GRPC_SSL_URL_SCHEME;
- c->base.channel_creds = grpc_channel_credentials_ref(channel_creds);
- c->base.request_metadata_creds =
- grpc_call_credentials_ref(request_metadata_creds);
- c->base.check_call_host = ssl_channel_check_call_host;
- c->base.cancel_check_call_host = ssl_channel_cancel_check_call_host;
- c->base.add_handshakers = ssl_channel_add_handshakers;
- gpr_split_host_port(target_name, &c->target_name, &port);
- gpr_free(port);
- if (overridden_target_name != nullptr) {
- c->overridden_target_name = gpr_strdup(overridden_target_name);
+ return GRPC_SECURITY_OK;
}
- c->verify_options = &config->verify_options;
- has_key_cert_pair = config->pem_key_cert_pair != nullptr &&
- config->pem_key_cert_pair->private_key != nullptr &&
- config->pem_key_cert_pair->cert_chain != nullptr;
- if (has_key_cert_pair) {
- options.pem_key_cert_pair = config->pem_key_cert_pair;
+ void add_handshakers(grpc_pollset_set* interested_parties,
+ grpc_handshake_manager* handshake_mgr) override {
+ // Instantiate TSI handshaker.
+ try_fetch_ssl_server_credentials();
+ tsi_handshaker* tsi_hs = nullptr;
+ tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker(
+ server_handshaker_factory_, &tsi_hs);
+ if (result != TSI_OK) {
+ gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.",
+ tsi_result_to_string(result));
+ return;
+ }
+ // Create handshakers.
+ grpc_handshake_manager_add(handshake_mgr,
+ grpc_security_handshaker_create(tsi_hs, this));
}
- options.cipher_suites = grpc_get_ssl_cipher_suites();
- options.session_cache = ssl_session_cache;
- result = tsi_create_ssl_client_handshaker_factory_with_options(
- &options, &c->client_handshaker_factory);
- if (result != TSI_OK) {
- gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
- tsi_result_to_string(result));
- ssl_channel_destroy(&c->base.base);
- *sc = nullptr;
- goto error;
+
+ void check_peer(tsi_peer peer, grpc_endpoint* ep,
+ grpc_core::RefCountedPtr<grpc_auth_context>* auth_context,
+ grpc_closure* on_peer_checked) override {
+ grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context);
+ tsi_peer_destruct(&peer);
+ GRPC_CLOSURE_SCHED(on_peer_checked, error);
}
- *sc = &c->base;
- gpr_free((void*)options.alpn_protocols);
- return GRPC_SECURITY_OK;
-error:
- gpr_free((void*)options.alpn_protocols);
- return GRPC_SECURITY_ERROR;
-}
+ int cmp(const grpc_security_connector* other) const override {
+ return server_security_connector_cmp(
+ static_cast<const grpc_server_security_connector*>(other));
+ }
-static grpc_ssl_server_security_connector*
-grpc_ssl_server_security_connector_initialize(
- grpc_server_credentials* server_creds) {
- grpc_ssl_server_security_connector* c =
- static_cast<grpc_ssl_server_security_connector*>(
- gpr_zalloc(sizeof(grpc_ssl_server_security_connector)));
- gpr_ref_init(&c->base.base.refcount, 1);
- c->base.base.url_scheme = GRPC_SSL_URL_SCHEME;
- c->base.base.vtable = &ssl_server_vtable;
- c->base.add_handshakers = ssl_server_add_handshakers;
- c->base.server_creds = grpc_server_credentials_ref(server_creds);
- return c;
-}
+ private:
+ /* Attempts to fetch the server certificate config if a callback is available.
+ * Current certificate config will continue to be used if the callback returns
+ * an error. Returns true if new credentials were sucessfully loaded. */
+ bool try_fetch_ssl_server_credentials() {
+ grpc_ssl_server_certificate_config* certificate_config = nullptr;
+ bool status;
+
+ if (!has_cert_config_fetcher()) return false;
+
+ grpc_ssl_server_credentials* server_creds =
+ static_cast<grpc_ssl_server_credentials*>(this->mutable_server_creds());
+ grpc_ssl_certificate_config_reload_status cb_result =
+ server_creds->FetchCertConfig(&certificate_config);
+ if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) {
+ gpr_log(GPR_DEBUG, "No change in SSL server credentials.");
+ status = false;
+ } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) {
+ status = try_replace_server_handshaker_factory(certificate_config);
+ } else {
+ // Log error, continue using previously-loaded credentials.
+ gpr_log(GPR_ERROR,
+ "Failed fetching new server credentials, continuing to "
+ "use previously-loaded credentials.");
+ status = false;
+ }
-grpc_security_status grpc_ssl_server_security_connector_create(
- grpc_server_credentials* gsc, grpc_server_security_connector** sc) {
- tsi_result result = TSI_OK;
- grpc_ssl_server_credentials* server_credentials =
- reinterpret_cast<grpc_ssl_server_credentials*>(gsc);
- grpc_security_status retval = GRPC_SECURITY_OK;
+ if (certificate_config != nullptr) {
+ grpc_ssl_server_certificate_config_destroy(certificate_config);
+ }
+ return status;
+ }
- GPR_ASSERT(server_credentials != nullptr);
- GPR_ASSERT(sc != nullptr);
-
- grpc_ssl_server_security_connector* c =
- grpc_ssl_server_security_connector_initialize(gsc);
- if (server_connector_has_cert_config_fetcher(c)) {
- // Load initial credentials from certificate_config_fetcher:
- if (!try_fetch_ssl_server_credentials(c)) {
- gpr_log(GPR_ERROR, "Failed loading SSL server credentials from fetcher.");
- retval = GRPC_SECURITY_ERROR;
+ /* Attempts to replace the server_handshaker_factory with a new factory using
+ * the provided grpc_ssl_server_certificate_config. Should new factory
+ * creation fail, the existing factory will not be replaced. Returns true on
+ * success (new factory created). */
+ bool try_replace_server_handshaker_factory(
+ const grpc_ssl_server_certificate_config* config) {
+ if (config == nullptr) {
+ gpr_log(GPR_ERROR,
+ "Server certificate config callback returned invalid (NULL) "
+ "config.");
+ return false;
}
- } else {
+ gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config);
+
size_t num_alpn_protocols = 0;
const char** alpn_protocol_strings =
grpc_fill_alpn_protocol_strings(&num_alpn_protocols);
- result = tsi_create_ssl_server_handshaker_factory_ex(
- server_credentials->config.pem_key_cert_pairs,
- server_credentials->config.num_key_cert_pairs,
- server_credentials->config.pem_root_certs,
+ tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs(
+ config->pem_key_cert_pairs, config->num_key_cert_pairs);
+ tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr;
+ const grpc_ssl_server_credentials* server_creds =
+ static_cast<const grpc_ssl_server_credentials*>(this->server_creds());
+ GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr);
+ tsi_result result = tsi_create_ssl_server_handshaker_factory_ex(
+ cert_pairs, config->num_key_cert_pairs, config->pem_root_certs,
grpc_get_tsi_client_certificate_request_type(
- server_credentials->config.client_certificate_request),
+ server_creds->config().client_certificate_request),
grpc_get_ssl_cipher_suites(), alpn_protocol_strings,
- static_cast<uint16_t>(num_alpn_protocols),
- &c->server_handshaker_factory);
+ static_cast<uint16_t>(num_alpn_protocols), &new_handshaker_factory);
+ gpr_free(cert_pairs);
gpr_free((void*)alpn_protocol_strings);
+
if (result != TSI_OK) {
gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.",
tsi_result_to_string(result));
- retval = GRPC_SECURITY_ERROR;
+ return false;
}
+ set_server_handshaker_factory(new_handshaker_factory);
+ return true;
+ }
+
+ void set_server_handshaker_factory(
+ tsi_ssl_server_handshaker_factory* new_factory) {
+ if (server_handshaker_factory_) {
+ tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_);
+ }
+ server_handshaker_factory_ = new_factory;
+ }
+
+ tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr;
+};
+} // namespace
+
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_ssl_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
+ const grpc_ssl_config* config, const char* target_name,
+ const char* overridden_target_name,
+ tsi_ssl_session_cache* ssl_session_cache) {
+ if (config == nullptr || target_name == nullptr) {
+ gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name.");
+ return nullptr;
}
- if (retval == GRPC_SECURITY_OK) {
- *sc = &c->base;
+ const char* pem_root_certs;
+ const tsi_ssl_root_certs_store* root_store;
+ if (config->pem_root_certs == nullptr) {
+ // Use default root certificates.
+ pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts();
+ if (pem_root_certs == nullptr) {
+ gpr_log(GPR_ERROR, "Could not get default pem root certs.");
+ return nullptr;
+ }
+ root_store = grpc_core::DefaultSslRootStore::GetRootStore();
} else {
- if (c != nullptr) ssl_server_destroy(&c->base.base);
- if (sc != nullptr) *sc = nullptr;
+ pem_root_certs = config->pem_root_certs;
+ root_store = nullptr;
+ }
+
+ grpc_core::RefCountedPtr<grpc_ssl_channel_security_connector> c =
+ grpc_core::MakeRefCounted<grpc_ssl_channel_security_connector>(
+ std::move(channel_creds), std::move(request_metadata_creds), config,
+ target_name, overridden_target_name);
+ const grpc_security_status result = c->InitializeHandshakerFactory(
+ config, pem_root_certs, root_store, ssl_session_cache);
+ if (result != GRPC_SECURITY_OK) {
+ return nullptr;
}
- return retval;
+ return c;
+}
+
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_ssl_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials) {
+ GPR_ASSERT(server_credentials != nullptr);
+ grpc_core::RefCountedPtr<grpc_ssl_server_security_connector> c =
+ grpc_core::MakeRefCounted<grpc_ssl_server_security_connector>(
+ std::move(server_credentials));
+ const grpc_security_status retval = c->InitializeHandshakerFactory();
+ if (retval != GRPC_SECURITY_OK) {
+ return nullptr;
+ }
+ return c;
}
diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h
index 9b80590606..70e26e338a 100644
--- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h
+++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h
@@ -25,6 +25,7 @@
#include "src/core/lib/security/security_connector/security_connector.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"
@@ -47,20 +48,21 @@ typedef struct {
This function returns GRPC_SECURITY_OK in case of success or a
specific error code otherwise.
*/
-grpc_security_status grpc_ssl_channel_security_connector_create(
- grpc_channel_credentials* channel_creds,
- grpc_call_credentials* request_metadata_creds,
+grpc_core::RefCountedPtr<grpc_channel_security_connector>
+grpc_ssl_channel_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_channel_credentials> channel_creds,
+ grpc_core::RefCountedPtr<grpc_call_credentials> request_metadata_creds,
const grpc_ssl_config* config, const char* target_name,
const char* overridden_target_name,
- tsi_ssl_session_cache* ssl_session_cache,
- grpc_channel_security_connector** sc);
+ tsi_ssl_session_cache* ssl_session_cache);
/* Config for ssl servers. */
typedef struct {
- tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs;
- size_t num_key_cert_pairs;
- char* pem_root_certs;
- grpc_ssl_client_certificate_request_type client_certificate_request;
+ tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr;
+ size_t num_key_cert_pairs = 0;
+ char* pem_root_certs = nullptr;
+ grpc_ssl_client_certificate_request_type client_certificate_request =
+ GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE;
} grpc_ssl_server_config;
/* Creates an SSL server_security_connector.
@@ -69,9 +71,9 @@ typedef struct {
This function returns GRPC_SECURITY_OK in case of success or a
specific error code otherwise.
*/
-grpc_security_status grpc_ssl_server_security_connector_create(
- grpc_server_credentials* server_credentials,
- grpc_server_security_connector** sc);
+grpc_core::RefCountedPtr<grpc_server_security_connector>
+grpc_ssl_server_security_connector_create(
+ grpc_core::RefCountedPtr<grpc_server_credentials> server_credentials);
#endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SSL_SSL_SECURITY_CONNECTOR_H \
*/
diff --git a/src/core/lib/security/security_connector/ssl_utils.cc b/src/core/lib/security/security_connector/ssl_utils.cc
index fbf41cfbc7..29030f07ad 100644
--- a/src/core/lib/security/security_connector/ssl_utils.cc
+++ b/src/core/lib/security/security_connector/ssl_utils.cc
@@ -30,6 +30,7 @@
#include "src/core/lib/gpr/env.h"
#include "src/core/lib/gpr/host_port.h"
#include "src/core/lib/gpr/string.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/iomgr/load_file.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/security_connector/load_system_roots.h"
@@ -141,16 +142,17 @@ int grpc_ssl_host_matches_name(const tsi_peer* peer, const char* peer_name) {
return r;
}
-grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) {
+grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context(
+ const tsi_peer* peer) {
size_t i;
- grpc_auth_context* ctx = nullptr;
const char* peer_identity_property_name = nullptr;
/* The caller has checked the certificate type property. */
GPR_ASSERT(peer->property_count >= 1);
- ctx = grpc_auth_context_create(nullptr);
+ grpc_core::RefCountedPtr<grpc_auth_context> ctx =
+ grpc_core::MakeRefCounted<grpc_auth_context>(nullptr);
grpc_auth_context_add_cstring_property(
- ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
+ ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME,
GRPC_SSL_TRANSPORT_SECURITY_TYPE);
for (i = 0; i < peer->property_count; i++) {
const tsi_peer_property* prop = &peer->properties[i];
@@ -160,24 +162,26 @@ grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) {
if (peer_identity_property_name == nullptr) {
peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME;
}
- grpc_auth_context_add_property(ctx, GRPC_X509_CN_PROPERTY_NAME,
+ grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name,
TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME;
- grpc_auth_context_add_property(ctx, GRPC_X509_SAN_PROPERTY_NAME,
+ grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) {
- grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME,
+ grpc_auth_context_add_property(ctx.get(),
+ GRPC_X509_PEM_CERT_PROPERTY_NAME,
prop->value.data, prop->value.length);
} else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) {
- grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY,
+ grpc_auth_context_add_property(ctx.get(),
+ GRPC_SSL_SESSION_REUSED_PROPERTY,
prop->value.data, prop->value.length);
}
}
if (peer_identity_property_name != nullptr) {
GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(
- ctx, peer_identity_property_name) == 1);
+ ctx.get(), peer_identity_property_name) == 1);
}
return ctx;
}
diff --git a/src/core/lib/security/security_connector/ssl_utils.h b/src/core/lib/security/security_connector/ssl_utils.h
index 6f6d473311..c9cd1a1d9c 100644
--- a/src/core/lib/security/security_connector/ssl_utils.h
+++ b/src/core/lib/security/security_connector/ssl_utils.h
@@ -26,6 +26,7 @@
#include <grpc/grpc_security.h>
#include <grpc/slice_buffer.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/tsi/ssl_transport_security.h"
#include "src/core/tsi/transport_security_interface.h"
@@ -47,7 +48,8 @@ grpc_get_tsi_client_certificate_request_type(
const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols);
/* Exposed for testing only. */
-grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer);
+grpc_core::RefCountedPtr<grpc_auth_context> grpc_ssl_peer_to_auth_context(
+ const tsi_peer* peer);
tsi_peer grpc_shallow_peer_from_ssl_auth_context(
const grpc_auth_context* auth_context);
void grpc_shallow_peer_destruct(tsi_peer* peer);
diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc
index 6955e8698e..66f86b8bc5 100644
--- a/src/core/lib/security/transport/client_auth_filter.cc
+++ b/src/core/lib/security/transport/client_auth_filter.cc
@@ -55,7 +55,7 @@ struct call_data {
// that the memory is not initialized.
void destroy() {
grpc_credentials_mdelem_array_destroy(&md_array);
- grpc_call_credentials_unref(creds);
+ creds.reset();
grpc_slice_unref_internal(host);
grpc_slice_unref_internal(method);
grpc_auth_metadata_context_reset(&auth_md_context);
@@ -64,7 +64,7 @@ struct call_data {
gpr_arena* arena;
grpc_call_stack* owning_call;
grpc_call_combiner* call_combiner;
- grpc_call_credentials* creds = nullptr;
+ grpc_core::RefCountedPtr<grpc_call_credentials> creds;
grpc_slice host = grpc_empty_slice();
grpc_slice method = grpc_empty_slice();
/* pollset{_set} bound to this call; if we need to make external
@@ -83,8 +83,18 @@ struct call_data {
/* We can have a per-channel credentials. */
struct channel_data {
- grpc_channel_security_connector* security_connector;
- grpc_auth_context* auth_context;
+ channel_data(grpc_channel_security_connector* security_connector,
+ grpc_auth_context* auth_context)
+ : security_connector(
+ security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")),
+ auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {}
+ ~channel_data() {
+ security_connector.reset(DEBUG_LOCATION, "client_auth_filter");
+ auth_context.reset(DEBUG_LOCATION, "client_auth_filter");
+ }
+
+ grpc_core::RefCountedPtr<grpc_channel_security_connector> security_connector;
+ grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
};
} // namespace
@@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset(
gpr_free(const_cast<char*>(auth_md_context->method_name));
auth_md_context->method_name = nullptr;
}
- GRPC_AUTH_CONTEXT_UNREF(
- (grpc_auth_context*)auth_md_context->channel_auth_context,
- "grpc_auth_metadata_context");
- auth_md_context->channel_auth_context = nullptr;
+ if (auth_md_context->channel_auth_context != nullptr) {
+ const_cast<grpc_auth_context*>(auth_md_context->channel_auth_context)
+ ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context");
+ auth_md_context->channel_auth_context = nullptr;
+ }
}
static void add_error(grpc_error** combined, grpc_error* error) {
@@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build(
auth_md_context->service_url = service_url;
auth_md_context->method_name = method_name;
auth_md_context->channel_auth_context =
- GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context");
+ auth_context == nullptr
+ ? nullptr
+ : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context")
+ .release();
gpr_free(service);
gpr_free(host_and_port);
}
@@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) {
grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
call_data* calld = static_cast<call_data*>(elem->call_data);
if (error != GRPC_ERROR_NONE) {
- grpc_call_credentials_cancel_get_request_metadata(
- calld->creds, &calld->md_array, GRPC_ERROR_REF(error));
+ calld->creds->cancel_get_request_metadata(&calld->md_array,
+ GRPC_ERROR_REF(error));
}
}
@@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem,
static_cast<grpc_client_security_context*>(
batch->payload->context[GRPC_CONTEXT_SECURITY].value);
grpc_call_credentials* channel_call_creds =
- chand->security_connector->request_metadata_creds;
+ chand->security_connector->mutable_request_metadata_creds();
int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr);
if (channel_call_creds == nullptr && !call_creds_has_md) {
@@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem,
}
if (channel_call_creds != nullptr && call_creds_has_md) {
- calld->creds = grpc_composite_call_credentials_create(channel_call_creds,
- ctx->creds, nullptr);
+ calld->creds = grpc_core::RefCountedPtr<grpc_call_credentials>(
+ grpc_composite_call_credentials_create(channel_call_creds,
+ ctx->creds.get(), nullptr));
if (calld->creds == nullptr) {
grpc_transport_stream_op_batch_finish_with_failure(
batch,
@@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem,
return;
}
} else {
- calld->creds = grpc_call_credentials_ref(
- call_creds_has_md ? ctx->creds : channel_call_creds);
+ calld->creds =
+ call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref();
}
grpc_auth_metadata_context_build(
- chand->security_connector->base.url_scheme, calld->host, calld->method,
- chand->auth_context, &calld->auth_md_context);
+ chand->security_connector->url_scheme(), calld->host, calld->method,
+ chand->auth_context.get(), &calld->auth_md_context);
GPR_ASSERT(calld->pollent != nullptr);
GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata");
GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata,
batch, grpc_schedule_on_exec_ctx);
grpc_error* error = GRPC_ERROR_NONE;
- if (grpc_call_credentials_get_request_metadata(
- calld->creds, calld->pollent, calld->auth_md_context,
- &calld->md_array, &calld->async_result_closure, &error)) {
+ if (calld->creds->get_request_metadata(
+ calld->pollent, calld->auth_md_context, &calld->md_array,
+ &calld->async_result_closure, &error)) {
// Synchronous return; invoke on_credentials_metadata() directly.
on_credentials_metadata(batch, error);
GRPC_ERROR_UNREF(error);
@@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) {
call_data* calld = static_cast<call_data*>(elem->call_data);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
if (error != GRPC_ERROR_NONE) {
- grpc_channel_security_connector_cancel_check_call_host(
- chand->security_connector, &calld->async_result_closure,
- GRPC_ERROR_REF(error));
+ chand->security_connector->cancel_check_call_host(
+ &calld->async_result_closure, GRPC_ERROR_REF(error));
}
}
@@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch(
GPR_ASSERT(batch->payload->context != nullptr);
if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) {
batch->payload->context[GRPC_CONTEXT_SECURITY].value =
- grpc_client_security_context_create(calld->arena);
+ grpc_client_security_context_create(calld->arena, /*creds=*/nullptr);
batch->payload->context[GRPC_CONTEXT_SECURITY].destroy =
grpc_client_security_context_destroy;
}
grpc_client_security_context* sec_ctx =
static_cast<grpc_client_security_context*>(
batch->payload->context[GRPC_CONTEXT_SECURITY].value);
- GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter");
+ sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter");
sec_ctx->auth_context =
- GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter");
+ chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter");
}
if (batch->send_initial_metadata) {
@@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch(
grpc_schedule_on_exec_ctx);
char* call_host = grpc_slice_to_c_string(calld->host);
grpc_error* error = GRPC_ERROR_NONE;
- if (grpc_channel_security_connector_check_call_host(
- chand->security_connector, call_host, chand->auth_context,
+ if (chand->security_connector->check_call_host(
+ call_host, chand->auth_context.get(),
&calld->async_result_closure, &error)) {
// Synchronous return; invoke on_host_checked() directly.
on_host_checked(batch, error);
@@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem,
/* Constructor for channel_data */
static grpc_error* init_channel_elem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
+ /* The first and the last filters tend to be implemented differently to
+ handle the case that there's no 'next' filter to call on the up or down
+ path */
+ GPR_ASSERT(!args->is_last);
grpc_security_connector* sc =
grpc_security_connector_find_in_args(args->channel_args);
if (sc == nullptr) {
@@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem,
return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
"Auth context missing from client auth filter args");
}
-
- /* grab pointers to our data from the channel element */
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
-
- /* The first and the last filters tend to be implemented differently to
- handle the case that there's no 'next' filter to call on the up or down
- path */
- GPR_ASSERT(!args->is_last);
-
- /* initialize members */
- chand->security_connector =
- reinterpret_cast<grpc_channel_security_connector*>(
- GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter"));
- chand->auth_context =
- GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter");
+ new (elem->channel_data) channel_data(
+ static_cast<grpc_channel_security_connector*>(sc), auth_context);
return GRPC_ERROR_NONE;
}
/* Destructor for channel data */
static void destroy_channel_elem(grpc_channel_element* elem) {
- /* grab pointers to our data from the channel element */
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- grpc_channel_security_connector* sc = chand->security_connector;
- if (sc != nullptr) {
- GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter");
- }
- GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter");
+ chand->~channel_data();
}
const grpc_channel_filter grpc_client_auth_filter = {
diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc
index 854a1c4af9..01831dab10 100644
--- a/src/core/lib/security/transport/security_handshaker.cc
+++ b/src/core/lib/security/transport/security_handshaker.cc
@@ -30,6 +30,7 @@
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/handshaker.h"
#include "src/core/lib/channel/handshaker_registry.h"
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/security/context/security_context.h"
#include "src/core/lib/security/transport/secure_endpoint.h"
#include "src/core/lib/security/transport/tsi_error.h"
@@ -38,34 +39,62 @@
#define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
-typedef struct {
+namespace {
+struct security_handshaker {
+ security_handshaker(tsi_handshaker* handshaker,
+ grpc_security_connector* connector);
+ ~security_handshaker() {
+ gpr_mu_destroy(&mu);
+ tsi_handshaker_destroy(handshaker);
+ tsi_handshaker_result_destroy(handshaker_result);
+ if (endpoint_to_destroy != nullptr) {
+ grpc_endpoint_destroy(endpoint_to_destroy);
+ }
+ if (read_buffer_to_destroy != nullptr) {
+ grpc_slice_buffer_destroy_internal(read_buffer_to_destroy);
+ gpr_free(read_buffer_to_destroy);
+ }
+ gpr_free(handshake_buffer);
+ grpc_slice_buffer_destroy_internal(&outgoing);
+ auth_context.reset(DEBUG_LOCATION, "handshake");
+ connector.reset(DEBUG_LOCATION, "handshake");
+ }
+
+ void Ref() { refs.Ref(); }
+ void Unref() {
+ if (refs.Unref()) {
+ grpc_core::Delete(this);
+ }
+ }
+
grpc_handshaker base;
// State set at creation time.
tsi_handshaker* handshaker;
- grpc_security_connector* connector;
+ grpc_core::RefCountedPtr<grpc_security_connector> connector;
gpr_mu mu;
- gpr_refcount refs;
+ grpc_core::RefCount refs;
- bool shutdown;
+ bool shutdown = false;
// Endpoint and read buffer to destroy after a shutdown.
- grpc_endpoint* endpoint_to_destroy;
- grpc_slice_buffer* read_buffer_to_destroy;
+ grpc_endpoint* endpoint_to_destroy = nullptr;
+ grpc_slice_buffer* read_buffer_to_destroy = nullptr;
// State saved while performing the handshake.
- grpc_handshaker_args* args;
- grpc_closure* on_handshake_done;
+ grpc_handshaker_args* args = nullptr;
+ grpc_closure* on_handshake_done = nullptr;
- unsigned char* handshake_buffer;
size_t handshake_buffer_size;
+ unsigned char* handshake_buffer;
grpc_slice_buffer outgoing;
grpc_closure on_handshake_data_sent_to_peer;
grpc_closure on_handshake_data_received_from_peer;
grpc_closure on_peer_checked;
- grpc_auth_context* auth_context;
- tsi_handshaker_result* handshaker_result;
-} security_handshaker;
+ grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
+ tsi_handshaker_result* handshaker_result = nullptr;
+};
+} // namespace
static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) {
size_t bytes_in_read_buffer = h->args->read_buffer->length;
@@ -85,26 +114,6 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) {
return bytes_in_read_buffer;
}
-static void security_handshaker_unref(security_handshaker* h) {
- if (gpr_unref(&h->refs)) {
- gpr_mu_destroy(&h->mu);
- tsi_handshaker_destroy(h->handshaker);
- tsi_handshaker_result_destroy(h->handshaker_result);
- if (h->endpoint_to_destroy != nullptr) {
- grpc_endpoint_destroy(h->endpoint_to_destroy);
- }
- if (h->read_buffer_to_destroy != nullptr) {
- grpc_slice_buffer_destroy_internal(h->read_buffer_to_destroy);
- gpr_free(h->read_buffer_to_destroy);
- }
- gpr_free(h->handshake_buffer);
- grpc_slice_buffer_destroy_internal(&h->outgoing);
- GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake");
- GRPC_SECURITY_CONNECTOR_UNREF(h->connector, "handshake");
- gpr_free(h);
- }
-}
-
// Set args fields to NULL, saving the endpoint and read buffer for
// later destruction.
static void cleanup_args_for_failure_locked(security_handshaker* h) {
@@ -194,7 +203,7 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) {
tsi_handshaker_result_destroy(h->handshaker_result);
h->handshaker_result = nullptr;
// Add auth context to channel args.
- grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context);
+ grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get());
grpc_channel_args* tmp_args = h->args->args;
h->args->args =
grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1);
@@ -211,7 +220,7 @@ static void on_peer_checked(void* arg, grpc_error* error) {
gpr_mu_lock(&h->mu);
on_peer_checked_inner(h, error);
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
}
static grpc_error* check_peer_locked(security_handshaker* h) {
@@ -222,8 +231,8 @@ static grpc_error* check_peer_locked(security_handshaker* h) {
return grpc_set_tsi_error_result(
GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result);
}
- grpc_security_connector_check_peer(h->connector, peer, &h->auth_context,
- &h->on_peer_checked);
+ h->connector->check_peer(peer, h->args->endpoint, &h->auth_context,
+ &h->on_peer_checked);
return GRPC_ERROR_NONE;
}
@@ -281,7 +290,7 @@ static void on_handshake_next_done_grpc_wrapper(
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
} else {
gpr_mu_unlock(&h->mu);
}
@@ -317,7 +326,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake read failed", &error, 1));
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
return;
}
// Copy all slices received.
@@ -329,7 +338,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) {
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
} else {
gpr_mu_unlock(&h->mu);
}
@@ -343,7 +352,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING(
"Handshake write failed", &error, 1));
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
return;
}
// We may be done.
@@ -355,7 +364,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
return;
}
}
@@ -368,7 +377,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) {
static void security_handshaker_destroy(grpc_handshaker* handshaker) {
security_handshaker* h = reinterpret_cast<security_handshaker*>(handshaker);
- security_handshaker_unref(h);
+ h->Unref();
}
static void security_handshaker_shutdown(grpc_handshaker* handshaker,
@@ -393,14 +402,14 @@ static void security_handshaker_do_handshake(grpc_handshaker* handshaker,
gpr_mu_lock(&h->mu);
h->args = args;
h->on_handshake_done = on_handshake_done;
- gpr_ref(&h->refs);
+ h->Ref();
size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h);
grpc_error* error =
do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size);
if (error != GRPC_ERROR_NONE) {
security_handshake_failed_locked(h, error);
gpr_mu_unlock(&h->mu);
- security_handshaker_unref(h);
+ h->Unref();
return;
}
gpr_mu_unlock(&h->mu);
@@ -410,27 +419,32 @@ static const grpc_handshaker_vtable security_handshaker_vtable = {
security_handshaker_destroy, security_handshaker_shutdown,
security_handshaker_do_handshake, "security"};
-static grpc_handshaker* security_handshaker_create(
- tsi_handshaker* handshaker, grpc_security_connector* connector) {
- security_handshaker* h = static_cast<security_handshaker*>(
- gpr_zalloc(sizeof(security_handshaker)));
- grpc_handshaker_init(&security_handshaker_vtable, &h->base);
- h->handshaker = handshaker;
- h->connector = GRPC_SECURITY_CONNECTOR_REF(connector, "handshake");
- gpr_mu_init(&h->mu);
- gpr_ref_init(&h->refs, 1);
- h->handshake_buffer_size = GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE;
- h->handshake_buffer =
- static_cast<uint8_t*>(gpr_malloc(h->handshake_buffer_size));
- GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer,
- on_handshake_data_sent_to_peer, h,
+namespace {
+security_handshaker::security_handshaker(tsi_handshaker* handshaker,
+ grpc_security_connector* connector)
+ : handshaker(handshaker),
+ connector(connector->Ref(DEBUG_LOCATION, "handshake")),
+ handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
+ handshake_buffer(
+ static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size))) {
+ grpc_handshaker_init(&security_handshaker_vtable, &base);
+ gpr_mu_init(&mu);
+ grpc_slice_buffer_init(&outgoing);
+ GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer,
+ ::on_handshake_data_sent_to_peer, this,
grpc_schedule_on_exec_ctx);
- GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer,
- on_handshake_data_received_from_peer, h,
+ GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer,
+ ::on_handshake_data_received_from_peer, this,
grpc_schedule_on_exec_ctx);
- GRPC_CLOSURE_INIT(&h->on_peer_checked, on_peer_checked, h,
+ GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this,
grpc_schedule_on_exec_ctx);
- grpc_slice_buffer_init(&h->outgoing);
+}
+} // namespace
+
+static grpc_handshaker* security_handshaker_create(
+ tsi_handshaker* handshaker, grpc_security_connector* connector) {
+ security_handshaker* h =
+ grpc_core::New<security_handshaker>(handshaker, connector);
return &h->base;
}
@@ -477,8 +491,9 @@ static void client_handshaker_factory_add_handshakers(
grpc_channel_security_connector* security_connector =
reinterpret_cast<grpc_channel_security_connector*>(
grpc_security_connector_find_in_args(args));
- grpc_channel_security_connector_add_handshakers(
- security_connector, interested_parties, handshake_mgr);
+ if (security_connector) {
+ security_connector->add_handshakers(interested_parties, handshake_mgr);
+ }
}
static void server_handshaker_factory_add_handshakers(
@@ -488,8 +503,9 @@ static void server_handshaker_factory_add_handshakers(
grpc_server_security_connector* security_connector =
reinterpret_cast<grpc_server_security_connector*>(
grpc_security_connector_find_in_args(args));
- grpc_server_security_connector_add_handshakers(
- security_connector, interested_parties, handshake_mgr);
+ if (security_connector) {
+ security_connector->add_handshakers(interested_parties, handshake_mgr);
+ }
}
static void handshaker_factory_destroy(
diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc
index 362f49a584..f93eb4275e 100644
--- a/src/core/lib/security/transport/server_auth_filter.cc
+++ b/src/core/lib/security/transport/server_auth_filter.cc
@@ -39,8 +39,12 @@ enum async_state {
};
struct channel_data {
- grpc_auth_context* auth_context;
- grpc_server_credentials* creds;
+ channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds)
+ : auth_context(auth_context->Ref()), creds(creds->Ref()) {}
+ ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); }
+
+ grpc_core::RefCountedPtr<grpc_auth_context> auth_context;
+ grpc_core::RefCountedPtr<grpc_server_credentials> creds;
};
struct call_data {
@@ -58,7 +62,7 @@ struct call_data {
grpc_server_security_context_create(args.arena);
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
server_ctx->auth_context =
- GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter");
+ chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter");
if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) {
args.context[GRPC_CONTEXT_SECURITY].destroy(
args.context[GRPC_CONTEXT_SECURITY].value);
@@ -208,7 +212,8 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) {
call_data* calld = static_cast<call_data*>(elem->call_data);
grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
if (error == GRPC_ERROR_NONE) {
- if (chand->creds != nullptr && chand->creds->processor.process != nullptr) {
+ if (chand->creds != nullptr &&
+ chand->creds->auth_metadata_processor().process != nullptr) {
// We're calling out to the application, so we need to make sure
// to drop the call combiner early if we get cancelled.
GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem,
@@ -218,9 +223,10 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) {
GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata");
calld->md = metadata_batch_to_md_array(
batch->payload->recv_initial_metadata.recv_initial_metadata);
- chand->creds->processor.process(
- chand->creds->processor.state, chand->auth_context,
- calld->md.metadata, calld->md.count, on_md_processing_done, elem);
+ chand->creds->auth_metadata_processor().process(
+ chand->creds->auth_metadata_processor().state,
+ chand->auth_context.get(), calld->md.metadata, calld->md.count,
+ on_md_processing_done, elem);
return;
}
}
@@ -290,23 +296,19 @@ static void destroy_call_elem(grpc_call_element* elem,
static grpc_error* init_channel_elem(grpc_channel_element* elem,
grpc_channel_element_args* args) {
GPR_ASSERT(!args->is_last);
- channel_data* chand = static_cast<channel_data*>(elem->channel_data);
grpc_auth_context* auth_context =
grpc_find_auth_context_in_args(args->channel_args);
GPR_ASSERT(auth_context != nullptr);
- chand->auth_context =
- GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter");
grpc_server_credentials* creds =
grpc_find_server_credentials_in_args(args->channel_args);
- chand->creds = grpc_server_credentials_ref(creds);
+ new (elem->channel_data) channel_data(auth_context, creds);
return GRPC_ERROR_NONE;
}
/* Destructor for channel data */
static void destroy_channel_elem(grpc_channel_element* elem) {
channel_data* chand = static_cast<channel_data*>(elem->channel_data);
- GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter");
- grpc_server_credentials_unref(chand->creds);
+ chand->~channel_data();
}
const grpc_channel_filter grpc_server_auth_filter = {
diff --git a/src/core/lib/surface/server.cc b/src/core/lib/surface/server.cc
index 67b38e6f0c..7ae6e51a5f 100644
--- a/src/core/lib/surface/server.cc
+++ b/src/core/lib/surface/server.cc
@@ -194,13 +194,10 @@ struct call_data {
};
struct request_matcher {
- request_matcher(grpc_server* server);
- ~request_matcher();
-
grpc_server* server;
- std::atomic<call_data*> pending_head{nullptr};
- call_data* pending_tail = nullptr;
- gpr_locked_mpscq* requests_per_cq = nullptr;
+ call_data* pending_head;
+ call_data* pending_tail;
+ gpr_locked_mpscq* requests_per_cq;
};
struct registered_method {
@@ -349,30 +346,22 @@ static void channel_broadcaster_shutdown(channel_broadcaster* cb,
* request_matcher
*/
-namespace {
-request_matcher::request_matcher(grpc_server* server) : server(server) {
- requests_per_cq = static_cast<gpr_locked_mpscq*>(
- gpr_malloc(sizeof(*requests_per_cq) * server->cq_count));
- for (size_t i = 0; i < server->cq_count; i++) {
- gpr_locked_mpscq_init(&requests_per_cq[i]);
- }
-}
-
-request_matcher::~request_matcher() {
+static void request_matcher_init(request_matcher* rm, grpc_server* server) {
+ memset(rm, 0, sizeof(*rm));
+ rm->server = server;
+ rm->requests_per_cq = static_cast<gpr_locked_mpscq*>(
+ gpr_malloc(sizeof(*rm->requests_per_cq) * server->cq_count));
for (size_t i = 0; i < server->cq_count; i++) {
- GPR_ASSERT(gpr_locked_mpscq_pop(&requests_per_cq[i]) == nullptr);
- gpr_locked_mpscq_destroy(&requests_per_cq[i]);
+ gpr_locked_mpscq_init(&rm->requests_per_cq[i]);
}
- gpr_free(requests_per_cq);
-}
-} // namespace
-
-static void request_matcher_init(request_matcher* rm, grpc_server* server) {
- new (rm) request_matcher(server);
}
static void request_matcher_destroy(request_matcher* rm) {
- rm->~request_matcher();
+ for (size_t i = 0; i < rm->server->cq_count; i++) {
+ GPR_ASSERT(gpr_locked_mpscq_pop(&rm->requests_per_cq[i]) == nullptr);
+ gpr_locked_mpscq_destroy(&rm->requests_per_cq[i]);
+ }
+ gpr_free(rm->requests_per_cq);
}
static void kill_zombie(void* elem, grpc_error* error) {
@@ -381,10 +370,9 @@ static void kill_zombie(void* elem, grpc_error* error) {
}
static void request_matcher_zombify_all_pending_calls(request_matcher* rm) {
- call_data* calld;
- while ((calld = rm->pending_head.load(std::memory_order_relaxed)) !=
- nullptr) {
- rm->pending_head.store(calld->pending_next, std::memory_order_relaxed);
+ while (rm->pending_head) {
+ call_data* calld = rm->pending_head;
+ rm->pending_head = calld->pending_next;
gpr_atm_no_barrier_store(&calld->state, ZOMBIED);
GRPC_CLOSURE_INIT(
&calld->kill_zombie_closure, kill_zombie,
@@ -582,9 +570,8 @@ static void publish_new_rpc(void* arg, grpc_error* error) {
}
gpr_atm_no_barrier_store(&calld->state, PENDING);
- if (rm->pending_head.load(std::memory_order_relaxed) == nullptr) {
- rm->pending_head.store(calld, std::memory_order_relaxed);
- rm->pending_tail = calld;
+ if (rm->pending_head == nullptr) {
+ rm->pending_tail = rm->pending_head = calld;
} else {
rm->pending_tail->pending_next = calld;
rm->pending_tail = calld;
@@ -1448,39 +1435,30 @@ static grpc_call_error queue_call_request(grpc_server* server, size_t cq_idx,
rm = &rc->data.registered.method->matcher;
break;
}
-
- // Fast path: if there is no pending request to be processed, immediately
- // return.
- if (!gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link) ||
- // Note: We are reading the pending_head without holding the server's call
- // mutex. Even if we read a non-null value here due to reordering,
- // we will check it below again after grabbing the lock.
- rm->pending_head.load(std::memory_order_relaxed) == nullptr) {
- return GRPC_CALL_OK;
- }
- // Slow path: This was the first queued request and there are pendings:
- // We need to lock and start matching calls.
- gpr_mu_lock(&server->mu_call);
- while ((calld = rm->pending_head.load(std::memory_order_relaxed)) !=
- nullptr) {
- rc = reinterpret_cast<requested_call*>(
- gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx]));
- if (rc == nullptr) break;
- rm->pending_head.store(calld->pending_next, std::memory_order_relaxed);
- gpr_mu_unlock(&server->mu_call);
- if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) {
- // Zombied Call
- GRPC_CLOSURE_INIT(
- &calld->kill_zombie_closure, kill_zombie,
- grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0),
- grpc_schedule_on_exec_ctx);
- GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE);
- } else {
- publish_call(server, calld, cq_idx, rc);
- }
+ if (gpr_locked_mpscq_push(&rm->requests_per_cq[cq_idx], &rc->request_link)) {
+ /* this was the first queued request: we need to lock and start
+ matching calls */
gpr_mu_lock(&server->mu_call);
+ while ((calld = rm->pending_head) != nullptr) {
+ rc = reinterpret_cast<requested_call*>(
+ gpr_locked_mpscq_pop(&rm->requests_per_cq[cq_idx]));
+ if (rc == nullptr) break;
+ rm->pending_head = calld->pending_next;
+ gpr_mu_unlock(&server->mu_call);
+ if (!gpr_atm_full_cas(&calld->state, PENDING, ACTIVATED)) {
+ // Zombied Call
+ GRPC_CLOSURE_INIT(
+ &calld->kill_zombie_closure, kill_zombie,
+ grpc_call_stack_element(grpc_call_get_call_stack(calld->call), 0),
+ grpc_schedule_on_exec_ctx);
+ GRPC_CLOSURE_SCHED(&calld->kill_zombie_closure, GRPC_ERROR_NONE);
+ } else {
+ publish_call(server, calld, cq_idx, rc);
+ }
+ gpr_mu_lock(&server->mu_call);
+ }
+ gpr_mu_unlock(&server->mu_call);
}
- gpr_mu_unlock(&server->mu_call);
return GRPC_CALL_OK;
}
diff --git a/src/core/lib/surface/version.cc b/src/core/lib/surface/version.cc
index 4829cc80a5..70d7580bec 100644
--- a/src/core/lib/surface/version.cc
+++ b/src/core/lib/surface/version.cc
@@ -25,4 +25,4 @@
const char* grpc_version_string(void) { return "7.0.0-dev"; }
-const char* grpc_g_stands_for(void) { return "goose"; }
+const char* grpc_g_stands_for(void) { return "gold"; }
diff --git a/src/core/lib/transport/metadata.cc b/src/core/lib/transport/metadata.cc
index 60af22393e..30482a1b3b 100644
--- a/src/core/lib/transport/metadata.cc
+++ b/src/core/lib/transport/metadata.cc
@@ -187,6 +187,7 @@ static void gc_mdtab(mdtab_shard* shard) {
((destroy_user_data_func)gpr_atm_no_barrier_load(
&md->destroy_user_data))(user_data);
}
+ gpr_mu_destroy(&md->mu_user_data);
gpr_free(md);
*prev_next = next;
num_freed++;
diff --git a/src/core/tsi/ssl_transport_security.cc b/src/core/tsi/ssl_transport_security.cc
index d6a72ada0d..fb6ea19210 100644
--- a/src/core/tsi/ssl_transport_security.cc
+++ b/src/core/tsi/ssl_transport_security.cc
@@ -156,9 +156,13 @@ static unsigned long openssl_thread_id_cb(void) {
#endif
static void init_openssl(void) {
+#if OPENSSL_API_COMPAT >= 0x10100000L
+ OPENSSL_init_ssl(0, NULL);
+#else
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_algorithms();
+#endif
#if OPENSSL_VERSION_NUMBER < 0x10100000
if (!CRYPTO_get_locking_callback()) {
int num_locks = CRYPTO_num_locks();
@@ -1649,7 +1653,11 @@ tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
return TSI_INVALID_ARGUMENT;
}
+#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L
+ ssl_context = SSL_CTX_new(TLS_method());
+#else
ssl_context = SSL_CTX_new(TLSv1_2_method());
+#endif
if (ssl_context == nullptr) {
gpr_log(GPR_ERROR, "Could not create ssl context.");
return TSI_INVALID_ARGUMENT;
@@ -1806,7 +1814,11 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
for (i = 0; i < options->num_key_cert_pairs; i++) {
do {
+#if defined(OPENSSL_NO_TLS1_2_METHOD) || OPENSSL_API_COMPAT >= 0x10100000L
+ impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
+#else
impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
+#endif
if (impl->ssl_contexts[i] == nullptr) {
gpr_log(GPR_ERROR, "Could not create ssl context.");
result = TSI_OUT_OF_RESOURCES;
@@ -1850,31 +1862,30 @@ tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
break;
}
SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
- switch (options->client_certificate_request) {
- case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
- SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
- break;
- case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
- SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
- NullVerifyCallback);
- break;
- case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
- SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
- break;
- case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
- SSL_CTX_set_verify(
- impl->ssl_contexts[i],
- SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
- NullVerifyCallback);
- break;
- case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
- SSL_CTX_set_verify(
- impl->ssl_contexts[i],
- SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
- break;
- }
- /* TODO(jboeuf): Add revocation verification. */
}
+ switch (options->client_certificate_request) {
+ case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
+ SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
+ break;
+ case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
+ SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
+ NullVerifyCallback);
+ break;
+ case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
+ SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
+ break;
+ case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
+ SSL_CTX_set_verify(impl->ssl_contexts[i],
+ SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
+ NullVerifyCallback);
+ break;
+ case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
+ SSL_CTX_set_verify(impl->ssl_contexts[i],
+ SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
+ nullptr);
+ break;
+ }
+ /* TODO(jboeuf): Add revocation verification. */
result = extract_x509_subject_names_from_pem_cert(
options->pem_key_cert_pairs[i].cert_chain,
diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc
index d0abe441a6..4d0ed355ab 100644
--- a/src/cpp/client/secure_credentials.cc
+++ b/src/cpp/client/secure_credentials.cc
@@ -261,10 +261,10 @@ void MetadataCredentialsPluginWrapper::InvokePlugin(
grpc_status_code* status_code, const char** error_details) {
std::multimap<grpc::string, grpc::string> metadata;
- // const_cast is safe since the SecureAuthContext does not take owndership and
- // the object is passed as a const ref to plugin_->GetMetadata.
+ // const_cast is safe since the SecureAuthContext only inc/dec the refcount
+ // and the object is passed as a const ref to plugin_->GetMetadata.
SecureAuthContext cpp_channel_auth_context(
- const_cast<grpc_auth_context*>(context.channel_auth_context), false);
+ const_cast<grpc_auth_context*>(context.channel_auth_context));
Status status = plugin_->GetMetadata(context.service_url, context.method_name,
cpp_channel_auth_context, &metadata);
diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h
index 613f1d6dc2..4918bd5a4d 100644
--- a/src/cpp/client/secure_credentials.h
+++ b/src/cpp/client/secure_credentials.h
@@ -24,6 +24,7 @@
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/config.h>
+#include "src/core/lib/security/credentials/credentials.h"
#include "src/cpp/server/thread_pool_interface.h"
namespace grpc {
@@ -31,7 +32,9 @@ namespace grpc {
class SecureChannelCredentials final : public ChannelCredentials {
public:
explicit SecureChannelCredentials(grpc_channel_credentials* c_creds);
- ~SecureChannelCredentials() { grpc_channel_credentials_release(c_creds_); }
+ ~SecureChannelCredentials() {
+ if (c_creds_ != nullptr) c_creds_->Unref();
+ }
grpc_channel_credentials* GetRawCreds() { return c_creds_; }
std::shared_ptr<grpc::Channel> CreateChannel(
@@ -51,7 +54,9 @@ class SecureChannelCredentials final : public ChannelCredentials {
class SecureCallCredentials final : public CallCredentials {
public:
explicit SecureCallCredentials(grpc_call_credentials* c_creds);
- ~SecureCallCredentials() { grpc_call_credentials_release(c_creds_); }
+ ~SecureCallCredentials() {
+ if (c_creds_ != nullptr) c_creds_->Unref();
+ }
grpc_call_credentials* GetRawCreds() { return c_creds_; }
bool ApplyToCall(grpc_call* call) override;
diff --git a/src/cpp/common/alarm.cc b/src/cpp/common/alarm.cc
index 5819a4210b..148f0b9bc9 100644
--- a/src/cpp/common/alarm.cc
+++ b/src/cpp/common/alarm.cc
@@ -31,10 +31,10 @@
#include <grpc/support/log.h>
#include "src/core/lib/debug/trace.h"
-namespace grpc {
+namespace grpc_impl {
namespace internal {
-class AlarmImpl : public CompletionQueueTag {
+class AlarmImpl : public ::grpc::internal::CompletionQueueTag {
public:
AlarmImpl() : cq_(nullptr), tag_(nullptr) {
gpr_ref_init(&refs_, 1);
@@ -51,7 +51,7 @@ class AlarmImpl : public CompletionQueueTag {
Unref();
return true;
}
- void Set(CompletionQueue* cq, gpr_timespec deadline, void* tag) {
+ void Set(::grpc::CompletionQueue* cq, gpr_timespec deadline, void* tag) {
grpc_core::ExecCtx exec_ctx;
GRPC_CQ_INTERNAL_REF(cq->cq(), "alarm");
cq_ = cq->cq();
@@ -114,13 +114,14 @@ class AlarmImpl : public CompletionQueueTag {
};
} // namespace internal
-static internal::GrpcLibraryInitializer g_gli_initializer;
+static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer;
Alarm::Alarm() : alarm_(new internal::AlarmImpl()) {
g_gli_initializer.summon();
}
-void Alarm::SetInternal(CompletionQueue* cq, gpr_timespec deadline, void* tag) {
+void Alarm::SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline,
+ void* tag) {
// Note that we know that alarm_ is actually an internal::AlarmImpl
// but we declared it as the base pointer to avoid a forward declaration
// or exposing core data structures in the C++ public headers.
@@ -145,4 +146,4 @@ Alarm::~Alarm() {
}
void Alarm::Cancel() { static_cast<internal::AlarmImpl*>(alarm_)->Cancel(); }
-} // namespace grpc
+} // namespace grpc_impl
diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc
index 50ee9d871f..214d72f853 100644
--- a/src/cpp/common/channel_arguments.cc
+++ b/src/cpp/common/channel_arguments.cc
@@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) {
}
if (!replaced) {
+ strings_.push_back(grpc::string(mutator_arg.key));
args_.push_back(mutator_arg);
+ args_.back().key = const_cast<char*>(strings_.back().c_str());
}
}
diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc
index 1d66dd3d1f..7a2b5afed6 100644
--- a/src/cpp/common/secure_auth_context.cc
+++ b/src/cpp/common/secure_auth_context.cc
@@ -22,19 +22,12 @@
namespace grpc {
-SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx,
- bool take_ownership)
- : ctx_(ctx), take_ownership_(take_ownership) {}
-
-SecureAuthContext::~SecureAuthContext() {
- if (take_ownership_) grpc_auth_context_release(ctx_);
-}
-
std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const {
- if (!ctx_) {
+ if (ctx_ == nullptr) {
return std::vector<grpc::string_ref>();
}
- grpc_auth_property_iterator iter = grpc_auth_context_peer_identity(ctx_);
+ grpc_auth_property_iterator iter =
+ grpc_auth_context_peer_identity(ctx_.get());
std::vector<grpc::string_ref> identity;
const grpc_auth_property* property = nullptr;
while ((property = grpc_auth_property_iterator_next(&iter))) {
@@ -45,20 +38,20 @@ std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const {
}
grpc::string SecureAuthContext::GetPeerIdentityPropertyName() const {
- if (!ctx_) {
+ if (ctx_ == nullptr) {
return "";
}
- const char* name = grpc_auth_context_peer_identity_property_name(ctx_);
+ const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get());
return name == nullptr ? "" : name;
}
std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues(
const grpc::string& name) const {
- if (!ctx_) {
+ if (ctx_ == nullptr) {
return std::vector<grpc::string_ref>();
}
grpc_auth_property_iterator iter =
- grpc_auth_context_find_properties_by_name(ctx_, name.c_str());
+ grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str());
const grpc_auth_property* property = nullptr;
std::vector<grpc::string_ref> values;
while ((property = grpc_auth_property_iterator_next(&iter))) {
@@ -68,9 +61,9 @@ std::vector<grpc::string_ref> SecureAuthContext::FindPropertyValues(
}
AuthPropertyIterator SecureAuthContext::begin() const {
- if (ctx_) {
+ if (ctx_ != nullptr) {
grpc_auth_property_iterator iter =
- grpc_auth_context_property_iterator(ctx_);
+ grpc_auth_context_property_iterator(ctx_.get());
const grpc_auth_property* property =
grpc_auth_property_iterator_next(&iter);
return AuthPropertyIterator(property, &iter);
@@ -85,19 +78,20 @@ AuthPropertyIterator SecureAuthContext::end() const {
void SecureAuthContext::AddProperty(const grpc::string& key,
const grpc::string_ref& value) {
- if (!ctx_) return;
- grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size());
+ if (ctx_ == nullptr) return;
+ grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(),
+ value.size());
}
bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) {
- if (!ctx_) return false;
- return grpc_auth_context_set_peer_identity_property_name(ctx_,
+ if (ctx_ == nullptr) return false;
+ return grpc_auth_context_set_peer_identity_property_name(ctx_.get(),
name.c_str()) != 0;
}
bool SecureAuthContext::IsPeerAuthenticated() const {
- if (!ctx_) return false;
- return grpc_auth_context_peer_is_authenticated(ctx_) != 0;
+ if (ctx_ == nullptr) return false;
+ return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0;
}
} // namespace grpc
diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h
index 142617959c..2e8f793721 100644
--- a/src/cpp/common/secure_auth_context.h
+++ b/src/cpp/common/secure_auth_context.h
@@ -21,15 +21,17 @@
#include <grpcpp/security/auth_context.h>
-struct grpc_auth_context;
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
+#include "src/core/lib/security/context/security_context.h"
namespace grpc {
class SecureAuthContext final : public AuthContext {
public:
- SecureAuthContext(grpc_auth_context* ctx, bool take_ownership);
+ explicit SecureAuthContext(grpc_auth_context* ctx)
+ : ctx_(ctx != nullptr ? ctx->Ref() : nullptr) {}
- ~SecureAuthContext() override;
+ ~SecureAuthContext() override = default;
bool IsPeerAuthenticated() const override;
@@ -50,8 +52,7 @@ class SecureAuthContext final : public AuthContext {
virtual bool SetPeerIdentityPropertyName(const grpc::string& name) override;
private:
- grpc_auth_context* ctx_;
- bool take_ownership_;
+ grpc_core::RefCountedPtr<grpc_auth_context> ctx_;
};
} // namespace grpc
diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc
index bc1387c8d7..908c46629e 100644
--- a/src/cpp/common/secure_create_auth_context.cc
+++ b/src/cpp/common/secure_create_auth_context.cc
@@ -20,6 +20,7 @@
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
#include <grpcpp/security/auth_context.h>
+#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/cpp/common/secure_auth_context.h"
namespace grpc {
@@ -28,8 +29,8 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) {
if (call == nullptr) {
return std::shared_ptr<const AuthContext>();
}
- return std::shared_ptr<const AuthContext>(
- new SecureAuthContext(grpc_call_auth_context(call), true));
+ grpc_core::RefCountedPtr<grpc_auth_context> ctx(grpc_call_auth_context(call));
+ return std::make_shared<SecureAuthContext>(ctx.get());
}
} // namespace grpc
diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc
index 55da89e6c8..358131c7c4 100644
--- a/src/cpp/common/version_cc.cc
+++ b/src/cpp/common/version_cc.cc
@@ -22,5 +22,5 @@
#include <grpcpp/grpcpp.h>
namespace grpc {
-grpc::string Version() { return "1.18.0-dev"; }
+grpc::string Version() { return "1.19.0-dev"; }
} // namespace grpc
diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc
index 78fc69a805..160590353a 100644
--- a/src/cpp/ext/filters/census/context.cc
+++ b/src/cpp/ext/filters/census/context.cc
@@ -28,6 +28,9 @@ using ::opencensus::trace::SpanContext;
void GenerateServerContext(absl::string_view tracing, absl::string_view stats,
absl::string_view primary_role,
absl::string_view method, CensusContext* context) {
+ // Destruct the current CensusContext to free the Span memory before
+ // overwriting it below.
+ context->~CensusContext();
GrpcTraceContext trace_ctxt;
if (TraceContextEncoding::Decode(tracing, &trace_ctxt) !=
TraceContextEncoding::kEncodeDecodeFailure) {
@@ -42,6 +45,9 @@ void GenerateServerContext(absl::string_view tracing, absl::string_view stats,
void GenerateClientContext(absl::string_view method, CensusContext* ctxt,
CensusContext* parent_ctxt) {
+ // Destruct the current CensusContext to free the Span memory before
+ // overwriting it below.
+ ctxt->~CensusContext();
if (parent_ctxt != nullptr) {
SpanContext span_ctxt = parent_ctxt->Context();
Span span = parent_ctxt->Span();
diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc
index ebb17def32..453e76eb25 100644
--- a/src/cpp/server/secure_server_credentials.cc
+++ b/src/cpp/server/secure_server_credentials.cc
@@ -61,7 +61,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor(
metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key),
StringRefFromSlice(&md[i].value)));
}
- SecureAuthContext context(ctx, false);
+ SecureAuthContext context(ctx);
AuthMetadataProcessor::OutputMetadata consumed_metadata;
AuthMetadataProcessor::OutputMetadata response_metadata;
diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc
index 1e3c57446f..13741ce7aa 100644
--- a/src/cpp/server/server_cc.cc
+++ b/src/cpp/server/server_cc.cc
@@ -278,7 +278,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
request_payload_ = nullptr;
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
- interceptor_methods_.SetRecvMessage(request_);
+ interceptor_methods_.SetRecvMessage(request_, nullptr);
}
if (interceptor_methods_.RunInterceptors(
@@ -446,7 +446,7 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
req_->request_payload_ = nullptr;
req_->interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
- req_->interceptor_methods_.SetRecvMessage(req_->request_);
+ req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr);
}
if (req_->interceptor_methods_.RunInterceptors(
diff --git a/src/csharp/Grpc.Core/Version.csproj.include b/src/csharp/Grpc.Core/Version.csproj.include
index 4fffe4f644..52ab2215eb 100755
--- a/src/csharp/Grpc.Core/Version.csproj.include
+++ b/src/csharp/Grpc.Core/Version.csproj.include
@@ -1,7 +1,7 @@
<!-- This file is generated -->
<Project>
<PropertyGroup>
- <GrpcCsharpVersion>1.18.0-dev</GrpcCsharpVersion>
+ <GrpcCsharpVersion>1.19.0-dev</GrpcCsharpVersion>
<GoogleProtobufVersion>3.6.1</GoogleProtobufVersion>
</PropertyGroup>
</Project>
diff --git a/src/csharp/Grpc.Core/VersionInfo.cs b/src/csharp/Grpc.Core/VersionInfo.cs
index 633880189c..8f3be310ee 100644
--- a/src/csharp/Grpc.Core/VersionInfo.cs
+++ b/src/csharp/Grpc.Core/VersionInfo.cs
@@ -33,11 +33,11 @@ namespace Grpc.Core
/// <summary>
/// Current <c>AssemblyFileVersion</c> of gRPC C# assemblies
/// </summary>
- public const string CurrentAssemblyFileVersion = "1.18.0.0";
+ public const string CurrentAssemblyFileVersion = "1.19.0.0";
/// <summary>
/// Current version of gRPC C#
/// </summary>
- public const string CurrentVersion = "1.18.0-dev";
+ public const string CurrentVersion = "1.19.0-dev";
}
}
diff --git a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs
index e83a8a7274..4750353082 100644
--- a/src/csharp/Grpc.IntegrationTesting/InteropClient.cs
+++ b/src/csharp/Grpc.IntegrationTesting/InteropClient.cs
@@ -45,7 +45,7 @@ namespace Grpc.IntegrationTesting
[Option("server_host", Default = "localhost")]
public string ServerHost { get; set; }
- [Option("server_host_override", Default = TestCredentials.DefaultHostOverride)]
+ [Option("server_host_override")]
public string ServerHostOverride { get; set; }
[Option("server_port", Required = true)]
diff --git a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets
index 5f76c03ce5..3fe1ccc918 100644
--- a/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets
+++ b/src/csharp/Grpc.Tools/build/_grpc/_Grpc.Tools.targets
@@ -22,9 +22,8 @@
<Target Name="gRPC_ResolvePluginFullPath" AfterTargets="Protobuf_ResolvePlatform">
<PropertyGroup>
<!-- TODO(kkm): Do not use Protobuf_PackagedToolsPath, roll gRPC's own. -->
- <!-- TODO(kkm): Do not package windows x64 builds (#13098). -->
<gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' "
- >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath>
+ >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\$(gRPC_PluginFileName).exe</gRPC_PluginFullPath>
<gRPC_PluginFullPath Condition=" '$(gRPC_PluginFullPath)' == '' "
>$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/$(gRPC_PluginFileName)</gRPC_PluginFullPath>
</PropertyGroup>
diff --git a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets
index 1d233d23a8..26f9efb5a8 100644
--- a/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets
+++ b/src/csharp/Grpc.Tools/build/_protobuf/Google.Protobuf.Tools.targets
@@ -74,9 +74,8 @@
<!-- Next try OS and CPU resolved by ProtoToolsPlatform. -->
<Protobuf_ToolsOs Condition=" '$(Protobuf_ToolsOs)' == '' ">$(_Protobuf_ToolsOs)</Protobuf_ToolsOs>
<Protobuf_ToolsCpu Condition=" '$(Protobuf_ToolsCpu)' == '' ">$(_Protobuf_ToolsCpu)</Protobuf_ToolsCpu>
- <!-- TODO(kkm): Do not package windows x64 builds (#13098). -->
<Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' and '$(Protobuf_ToolsOs)' == 'windows' "
- >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_x86\protoc.exe</Protobuf_ProtocFullPath>
+ >$(Protobuf_PackagedToolsPath)\$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)\protoc.exe</Protobuf_ProtocFullPath>
<Protobuf_ProtocFullPath Condition=" '$(Protobuf_ProtocFullPath)' == '' "
>$(Protobuf_PackagedToolsPath)/$(Protobuf_ToolsOs)_$(Protobuf_ToolsCpu)/protoc</Protobuf_ProtocFullPath>
</PropertyGroup>
diff --git a/src/csharp/build_packages_dotnetcli.bat b/src/csharp/build_packages_dotnetcli.bat
index 76d4f14390..fef1a43bb8 100755
--- a/src/csharp/build_packages_dotnetcli.bat
+++ b/src/csharp/build_packages_dotnetcli.bat
@@ -13,7 +13,7 @@
@rem limitations under the License.
@rem Current package versions
-set VERSION=1.18.0-dev
+set VERSION=1.19.0-dev
@rem Adjust the location of nuget.exe
set NUGET=C:\nuget\nuget.exe
diff --git a/src/csharp/build_unitypackage.bat b/src/csharp/build_unitypackage.bat
index 3334d24c11..6b66b941a8 100644
--- a/src/csharp/build_unitypackage.bat
+++ b/src/csharp/build_unitypackage.bat
@@ -13,7 +13,7 @@
@rem limitations under the License.
@rem Current package versions
-set VERSION=1.18.0-dev
+set VERSION=1.19.0-dev
@rem Adjust the location of nuget.exe
set NUGET=C:\nuget\nuget.exe
diff --git a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec
index 55ca6048bc..659cfebbdc 100644
--- a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec
+++ b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec
@@ -42,7 +42,7 @@ Pod::Spec.new do |s|
# exclamation mark ensures that other "regular" pods will be able to find it as it'll be installed
# before them.
s.name = '!ProtoCompiler-gRPCPlugin'
- v = '1.18.0-dev'
+ v = '1.19.0-dev'
s.version = v
s.summary = 'The gRPC ProtoC plugin generates Objective-C files from .proto services.'
s.description = <<-DESC
diff --git a/src/objective-c/GRPCClient/private/version.h b/src/objective-c/GRPCClient/private/version.h
index 0be0e3c9a0..5e089fde31 100644
--- a/src/objective-c/GRPCClient/private/version.h
+++ b/src/objective-c/GRPCClient/private/version.h
@@ -22,4 +22,4 @@
// instead. This file can be regenerated from the template by running
// `tools/buildgen/generate_projects.sh`.
-#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev"
+#define GRPC_OBJC_VERSION_STRING @"1.19.0-dev"
diff --git a/src/objective-c/README.md b/src/objective-c/README.md
index 32e3956a1e..83775f86e1 100644
--- a/src/objective-c/README.md
+++ b/src/objective-c/README.md
@@ -242,3 +242,12 @@ pod `gRPC-Core`, :podspec => "." # assuming gRPC-Core.podspec is in the same dir
These steps should allow gRPC to use OpenSSL and drop BoringSSL dependency. If you see any issue,
file an issue to us.
+
+## Upgrade issue with BoringSSL
+If you were using an old version of gRPC (<= v1.14) which depended on pod `BoringSSL` rather than
+`BoringSSL-GRPC` and meet issue with the library like:
+```
+ld: framework not found openssl
+```
+updating `-framework openssl` in Other Linker Flags to `-framework openssl_grpc` in your project
+may resolve this issue (see [#16821](https://github.com/grpc/grpc/issues/16821)).
diff --git a/src/objective-c/tests/version.h b/src/objective-c/tests/version.h
index f2fd692070..54f95ad16a 100644
--- a/src/objective-c/tests/version.h
+++ b/src/objective-c/tests/version.h
@@ -22,5 +22,5 @@
// instead. This file can be regenerated from the template by running
// `tools/buildgen/generate_projects.sh`.
-#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev"
+#define GRPC_OBJC_VERSION_STRING @"1.19.0-dev"
#define GRPC_C_VERSION_STRING @"7.0.0-dev"
diff --git a/src/php/composer.json b/src/php/composer.json
index 9c298c0e85..75fab483f1 100644
--- a/src/php/composer.json
+++ b/src/php/composer.json
@@ -2,7 +2,7 @@
"name": "grpc/grpc-dev",
"description": "gRPC library for PHP - for Developement use only",
"license": "Apache-2.0",
- "version": "1.18.0",
+ "version": "1.19.0",
"require": {
"php": ">=5.5.0",
"google/protobuf": "^v3.3.0"
diff --git a/src/php/ext/grpc/version.h b/src/php/ext/grpc/version.h
index 1ddf90a667..c85ee4d315 100644
--- a/src/php/ext/grpc/version.h
+++ b/src/php/ext/grpc/version.h
@@ -20,6 +20,6 @@
#ifndef VERSION_H
#define VERSION_H
-#define PHP_GRPC_VERSION "1.18.0dev"
+#define PHP_GRPC_VERSION "1.19.0dev"
#endif /* VERSION_H */
diff --git a/src/php/tests/interop/interop_client.php b/src/php/tests/interop/interop_client.php
index c865678f70..19cbf21bc2 100755
--- a/src/php/tests/interop/interop_client.php
+++ b/src/php/tests/interop/interop_client.php
@@ -530,7 +530,7 @@ function _makeStub($args)
throw new Exception('Missing argument: --test_case is required');
}
- if ($args['server_port'] === 443) {
+ if ($args['server_port'] === '443') {
$server_address = $args['server_host'];
} else {
$server_address = $args['server_host'].':'.$args['server_port'];
@@ -538,7 +538,7 @@ function _makeStub($args)
$test_case = $args['test_case'];
- $host_override = 'foo.test.google.fr';
+ $host_override = '';
if (array_key_exists('server_host_override', $args)) {
$host_override = $args['server_host_override'];
}
@@ -565,7 +565,9 @@ function _makeStub($args)
$ssl_credentials = Grpc\ChannelCredentials::createSsl();
}
$opts['credentials'] = $ssl_credentials;
- $opts['grpc.ssl_target_name_override'] = $host_override;
+ if (!empty($host_override)) {
+ $opts['grpc.ssl_target_name_override'] = $host_override;
+ }
} else {
$opts['credentials'] = Grpc\ChannelCredentials::createInsecure();
}
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index 6022fc3ef2..70d7618e05 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -23,6 +23,11 @@ from grpc._cython import cygrpc as _cygrpc
logging.getLogger(__name__).addHandler(logging.NullHandler())
+try:
+ from ._grpcio_metadata import __version__
+except ImportError:
+ __version__ = "dev0"
+
############################## Future Interface ###############################
@@ -266,6 +271,22 @@ class StatusCode(enum.Enum):
UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated')
+############################# gRPC Status ################################
+
+
+class Status(six.with_metaclass(abc.ABCMeta)):
+ """Describes the status of an RPC.
+
+ This is an EXPERIMENTAL API.
+
+ Attributes:
+ code: A StatusCode object to be sent to the client.
+ details: An ASCII-encodable string to be sent to the client upon
+ termination of the RPC.
+ trailing_metadata: The trailing :term:`metadata` in the RPC.
+ """
+
+
############################# gRPC Exceptions ################################
@@ -1119,6 +1140,25 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
raise NotImplementedError()
@abc.abstractmethod
+ def abort_with_status(self, status):
+ """Raises an exception to terminate the RPC with a non-OK status.
+
+ The status passed as argument will supercede any existing status code,
+ status message and trailing metadata.
+
+ This is an EXPERIMENTAL API.
+
+ Args:
+ status: A grpc.Status object. The status code in it must not be
+ StatusCode.OK.
+
+ Raises:
+ Exception: An exception is always raised to signal the abortion the
+ RPC to the gRPC runtime.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
def set_code(self, code):
"""Sets the value to be used as status code upon RPC completion.
@@ -1747,6 +1787,7 @@ __all__ = (
'Future',
'ChannelConnectivity',
'StatusCode',
+ 'Status',
'RpcError',
'RpcContext',
'Call',
diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py
index c17824563d..9b990f490d 100644
--- a/src/python/grpcio/grpc/_auth.py
+++ b/src/python/grpcio/grpc/_auth.py
@@ -46,7 +46,7 @@ class GoogleCallCredentials(grpc.AuthMetadataPlugin):
# Hack to determine if these are JWT creds and we need to pass
# additional_claims when getting a token
- self._is_jwt = 'additional_claims' in inspect.getargspec(
+ self._is_jwt = 'additional_claims' in inspect.getargspec( # pylint: disable=deprecated-method
credentials.get_access_token).args
def __call__(self, context, callback):
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index 35fa82d56b..8051fb306c 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -499,6 +499,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
+ self._context = cygrpc.build_context()
def _prepare(self, request, timeout, metadata, wait_for_ready):
deadline, serialized_request, rendezvous = _start_unary_request(
@@ -525,17 +526,18 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready)
if state is None:
- raise rendezvous
+ raise rendezvous # pylint: disable-msg=raising-bad-type
else:
call = self._channel.segregated_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
+ self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials, ((
operations,
None,
- ),))
+ ),), self._context)
event = call.next_event()
_handle_event(event, state, self._response_deserializer)
- return state, call,
+ return state, call
def __call__(self,
request,
@@ -566,13 +568,14 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
state, operations, deadline, rendezvous = self._prepare(
request, timeout, metadata, wait_for_ready)
if state is None:
- raise rendezvous
+ raise rendezvous # pylint: disable-msg=raising-bad-type
else:
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
+ self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
- (operations,), event_handler)
+ (operations,), event_handler, self._context)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
@@ -587,6 +590,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
+ self._context = cygrpc.build_context()
def __call__(self,
request,
@@ -599,7 +603,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
if serialized_request is None:
- raise rendezvous
+ raise rendezvous # pylint: disable-msg=raising-bad-type
else:
state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None)
operationses = (
@@ -615,9 +619,10 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS,
+ self._method, None, deadline, metadata, None
if credentials is None else credentials._credentials,
- operationses, event_handler)
+ operationses, event_handler, self._context)
return _Rendezvous(state, call, self._response_deserializer,
deadline)
@@ -632,6 +637,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
+ self._context = cygrpc.build_context()
def _blocking(self, request_iterator, timeout, metadata, credentials,
wait_for_ready):
@@ -640,10 +646,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
call = self._channel.segregated_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
+ None, deadline, metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses_and_tags(
- metadata, initial_metadata_flags))
+ metadata, initial_metadata_flags), self._context)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, None)
while True:
@@ -653,7 +660,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
state.condition.notify_all()
if not state.due:
break
- return state, call,
+ return state, call
def __call__(self,
request_iterator,
@@ -687,10 +694,11 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready(
wait_for_ready)
call = self._managed_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
+ None, deadline, metadata, None
if credentials is None else credentials._credentials,
_stream_unary_invocation_operationses(
- metadata, initial_metadata_flags), event_handler)
+ metadata, initial_metadata_flags), event_handler, self._context)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@@ -706,6 +714,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
+ self._context = cygrpc.build_context()
def __call__(self,
request_iterator,
@@ -727,9 +736,10 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
)
event_handler = _event_handler(state, self._response_deserializer)
call = self._managed_call(
- 0, self._method, None, deadline, metadata, None
+ cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method,
+ None, deadline, metadata, None
if credentials is None else credentials._credentials, operationses,
- event_handler)
+ event_handler, self._context)
_consume_request_iterator(request_iterator, state, call,
self._request_serializer, event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@@ -745,10 +755,10 @@ class _InitialMetadataFlags(int):
def with_wait_for_ready(self, wait_for_ready):
if wait_for_ready is not None:
if wait_for_ready:
- self = self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \
+ return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \
cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set)
elif not wait_for_ready:
- self = self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \
+ return self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \
cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set)
return self
@@ -789,7 +799,7 @@ def _channel_managed_call_management(state):
# pylint: disable=too-many-arguments
def create(flags, method, host, deadline, metadata, credentials,
- operationses, event_handler):
+ operationses, event_handler, context):
"""Creates a cygrpc.IntegratedCall.
Args:
@@ -804,7 +814,7 @@ def _channel_managed_call_management(state):
started on the call.
event_handler: A behavior to call to handle the events resultant from
the operations on the call.
-
+ context: Context object for distributed tracing.
Returns:
A cygrpc.IntegratedCall with which to conduct an RPC.
"""
@@ -815,7 +825,7 @@ def _channel_managed_call_management(state):
with state.lock:
call = state.channel.integrated_call(flags, method, host, deadline,
metadata, credentials,
- operationses_and_tags)
+ operationses_and_tags, context)
if state.managed_calls == 0:
state.managed_calls = 1
_run_channel_spin_thread(state)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi
index e0e068e452..01b8237484 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pxd.pxi
@@ -28,19 +28,22 @@ cdef tuple _wrap_grpc_arg(grpc_arg arg)
cdef grpc_arg _unwrap_grpc_arg(tuple wrapped_arg)
-cdef class _ArgumentProcessor:
+cdef class _ChannelArg:
cdef grpc_arg c_argument
cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references) except *
-cdef class _ArgumentsProcessor:
+cdef class _ChannelArgs:
cdef readonly tuple _arguments
- cdef list _argument_processors
+ cdef list _channel_args
cdef readonly list _references
cdef grpc_channel_args _c_arguments
- cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable) except *
- cdef un_c(self)
+ cdef void _c(self, grpc_arg_pointer_vtable *vtable) except *
+ cdef grpc_channel_args *c_args(self) except *
+
+ @staticmethod
+ cdef _ChannelArgs from_args(object arguments, grpc_arg_pointer_vtable * vtable)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi
index b7a4277ff6..bf12871015 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/arguments.pyx.pxi
@@ -50,7 +50,7 @@ cdef grpc_arg _unwrap_grpc_arg(tuple wrapped_arg):
return wrapped.arg
-cdef class _ArgumentProcessor:
+cdef class _ChannelArg:
cdef void c(self, argument, grpc_arg_pointer_vtable *vtable, references) except *:
key, value = argument
@@ -82,27 +82,34 @@ cdef class _ArgumentProcessor:
'Expected int, bytes, or behavior, got {}'.format(type(value)))
-cdef class _ArgumentsProcessor:
+cdef class _ChannelArgs:
def __cinit__(self, arguments):
self._arguments = () if arguments is None else tuple(arguments)
- self._argument_processors = []
+ self._channel_args = []
self._references = []
+ self._c_arguments.arguments = NULL
- cdef grpc_channel_args *c(self, grpc_arg_pointer_vtable *vtable) except *:
+ cdef void _c(self, grpc_arg_pointer_vtable *vtable) except *:
self._c_arguments.arguments_length = len(self._arguments)
- if self._c_arguments.arguments_length == 0:
- return NULL
- else:
+ if self._c_arguments.arguments_length != 0:
self._c_arguments.arguments = <grpc_arg *>gpr_malloc(
self._c_arguments.arguments_length * sizeof(grpc_arg))
for index, argument in enumerate(self._arguments):
- argument_processor = _ArgumentProcessor()
- argument_processor.c(argument, vtable, self._references)
- self._c_arguments.arguments[index] = argument_processor.c_argument
- self._argument_processors.append(argument_processor)
- return &self._c_arguments
-
- cdef un_c(self):
- if self._arguments:
+ channel_arg = _ChannelArg()
+ channel_arg.c(argument, vtable, self._references)
+ self._c_arguments.arguments[index] = channel_arg.c_argument
+ self._channel_args.append(channel_arg)
+
+ cdef grpc_channel_args *c_args(self) except *:
+ return &self._c_arguments
+
+ def __dealloc__(self):
+ if self._c_arguments.arguments != NULL:
gpr_free(self._c_arguments.arguments)
+
+ @staticmethod
+ cdef _ChannelArgs from_args(object arguments, grpc_arg_pointer_vtable * vtable):
+ cdef _ChannelArgs channel_args = _ChannelArgs(arguments)
+ channel_args._c(vtable)
+ return channel_args
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
index 135d224095..70d4abb730 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
@@ -423,16 +423,15 @@ cdef class Channel:
self._vtable.copy = &_copy_pointer
self._vtable.destroy = &_destroy_pointer
self._vtable.cmp = &_compare_pointer
- cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor(
- arguments)
- cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable)
+ cdef _ChannelArgs channel_args = _ChannelArgs.from_args(
+ arguments, &self._vtable)
if channel_credentials is None:
self._state.c_channel = grpc_insecure_channel_create(
- <char *>target, c_arguments, NULL)
+ <char *>target, channel_args.c_args(), NULL)
else:
c_channel_credentials = channel_credentials.c()
self._state.c_channel = grpc_secure_channel_create(
- c_channel_credentials, <char *>target, c_arguments, NULL)
+ c_channel_credentials, <char *>target, channel_args.c_args(), NULL)
grpc_channel_credentials_release(c_channel_credentials)
self._state.c_call_completion_queue = (
grpc_completion_queue_create_for_next(NULL))
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
index 141116df5d..3c33b46dbb 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/completion_queue.pyx.pxi
@@ -49,7 +49,7 @@ cdef grpc_event _next(grpc_completion_queue *c_completion_queue, deadline):
cdef _interpret_event(grpc_event c_event):
cdef _Tag tag
if c_event.type == GRPC_QUEUE_TIMEOUT:
- # NOTE(nathaniel): For now we coopt ConnectivityEvent here.
+ # TODO(ericgribkoff) Do not coopt ConnectivityEvent here.
return None, ConnectivityEvent(GRPC_QUEUE_TIMEOUT, False, None)
elif c_event.type == GRPC_QUEUE_SHUTDOWN:
# NOTE(nathaniel): For now we coopt ConnectivityEvent here.
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi
index 52cfccb677..4a6fbe0f96 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pxd.pxi
@@ -16,7 +16,6 @@
cdef class Server:
cdef grpc_arg_pointer_vtable _vtable
- cdef readonly _ArgumentsProcessor _arguments_processor
cdef grpc_server *c_server
cdef bint is_started # start has been called
cdef bint is_shutting_down # shutdown has been called
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
index ce701724fd..d72648a35d 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/server.pyx.pxi
@@ -29,11 +29,9 @@ cdef class Server:
self._vtable.copy = &_copy_pointer
self._vtable.destroy = &_destroy_pointer
self._vtable.cmp = &_compare_pointer
- cdef _ArgumentsProcessor arguments_processor = _ArgumentsProcessor(
- arguments)
- cdef grpc_channel_args *c_arguments = arguments_processor.c(&self._vtable)
- self.c_server = grpc_server_create(c_arguments, NULL)
- arguments_processor.un_c()
+ cdef _ChannelArgs channel_args = _ChannelArgs.from_args(
+ arguments, &self._vtable)
+ self.c_server = grpc_server_create(channel_args.c_args(), NULL)
self.references.append(arguments)
self.is_started = False
self.is_shutting_down = False
@@ -128,7 +126,10 @@ cdef class Server:
with nogil:
grpc_server_cancel_all_calls(self.c_server)
- def __dealloc__(self):
+ # TODO(https://github.com/grpc/grpc/issues/17515) Determine what, if any,
+ # portion of this is safe to call from __dealloc__, and potentially remove
+ # backup_shutdown_queue.
+ def destroy(self):
if self.c_server != NULL:
if not self.is_started:
pass
@@ -146,4 +147,8 @@ cdef class Server:
while not self.is_shutdown:
time.sleep(0)
grpc_server_destroy(self.c_server)
- grpc_shutdown()
+ self.c_server = NULL
+
+ def __dealloc(self):
+ if self.c_server == NULL:
+ grpc_shutdown()
diff --git a/src/python/grpcio/grpc/_grpcio_metadata.py b/src/python/grpcio/grpc/_grpcio_metadata.py
index 7a9f173947..dd9d436c3f 100644
--- a/src/python/grpcio/grpc/_grpcio_metadata.py
+++ b/src/python/grpcio/grpc/_grpcio_metadata.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc/_grpcio_metadata.py.template`!!!
-__version__ = """1.18.0.dev0"""
+__version__ = """1.19.0.dev0"""
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index e939f615df..eb750ef1a8 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -48,7 +48,7 @@ _CANCELLED = 'cancelled'
_EMPTY_FLAGS = 0
-_UNEXPECTED_EXIT_SERVER_GRACE = 1.0
+_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
def _serialized_request(request_event):
@@ -291,6 +291,10 @@ class _Context(grpc.ServicerContext):
self._state.abortion = Exception()
raise self._state.abortion
+ def abort_with_status(self, status):
+ self._state.trailing_metadata = status.trailing_metadata
+ self.abort(status.code, status.details)
+
def set_code(self, code):
with self._state.condition:
self._state.code = code
@@ -672,6 +676,9 @@ class _ServerState(object):
self.rpc_states = set()
self.due = set()
+ # A "volatile" flag to interrupt the daemon serving thread
+ self.server_deallocated = False
+
def _add_generic_handlers(state, generic_handlers):
with state.lock:
@@ -698,6 +705,7 @@ def _request_call(state):
# TODO(https://github.com/grpc/grpc/issues/6597): delete this function.
def _stop_serving(state):
if not state.rpc_states and not state.due:
+ state.server.destroy()
for shutdown_event in state.shutdown_events:
shutdown_event.set()
state.stage = _ServerStage.STOPPED
@@ -711,49 +719,69 @@ def _on_call_completed(state):
state.active_rpc_count -= 1
-def _serve(state):
- while True:
- event = state.completion_queue.poll()
- if event.tag is _SHUTDOWN_TAG:
+def _process_event_and_continue(state, event):
+ should_continue = True
+ if event.tag is _SHUTDOWN_TAG:
+ with state.lock:
+ state.due.remove(_SHUTDOWN_TAG)
+ if _stop_serving(state):
+ should_continue = False
+ elif event.tag is _REQUEST_CALL_TAG:
+ with state.lock:
+ state.due.remove(_REQUEST_CALL_TAG)
+ concurrency_exceeded = (
+ state.maximum_concurrent_rpcs is not None and
+ state.active_rpc_count >= state.maximum_concurrent_rpcs)
+ rpc_state, rpc_future = _handle_call(
+ event, state.generic_handlers, state.interceptor_pipeline,
+ state.thread_pool, concurrency_exceeded)
+ if rpc_state is not None:
+ state.rpc_states.add(rpc_state)
+ if rpc_future is not None:
+ state.active_rpc_count += 1
+ rpc_future.add_done_callback(
+ lambda unused_future: _on_call_completed(state))
+ if state.stage is _ServerStage.STARTED:
+ _request_call(state)
+ elif _stop_serving(state):
+ should_continue = False
+ else:
+ rpc_state, callbacks = event.tag(event)
+ for callback in callbacks:
+ callable_util.call_logging_exceptions(callback,
+ 'Exception calling callback!')
+ if rpc_state is not None:
with state.lock:
- state.due.remove(_SHUTDOWN_TAG)
+ state.rpc_states.remove(rpc_state)
if _stop_serving(state):
- return
- elif event.tag is _REQUEST_CALL_TAG:
- with state.lock:
- state.due.remove(_REQUEST_CALL_TAG)
- concurrency_exceeded = (
- state.maximum_concurrent_rpcs is not None and
- state.active_rpc_count >= state.maximum_concurrent_rpcs)
- rpc_state, rpc_future = _handle_call(
- event, state.generic_handlers, state.interceptor_pipeline,
- state.thread_pool, concurrency_exceeded)
- if rpc_state is not None:
- state.rpc_states.add(rpc_state)
- if rpc_future is not None:
- state.active_rpc_count += 1
- rpc_future.add_done_callback(
- lambda unused_future: _on_call_completed(state))
- if state.stage is _ServerStage.STARTED:
- _request_call(state)
- elif _stop_serving(state):
- return
- else:
- rpc_state, callbacks = event.tag(event)
- for callback in callbacks:
- callable_util.call_logging_exceptions(
- callback, 'Exception calling callback!')
- if rpc_state is not None:
- with state.lock:
- state.rpc_states.remove(rpc_state)
- if _stop_serving(state):
- return
+ should_continue = False
+ return should_continue
+
+
+def _serve(state):
+ while True:
+ timeout = time.time() + _DEALLOCATED_SERVER_CHECK_PERIOD_S
+ event = state.completion_queue.poll(timeout)
+ if state.server_deallocated:
+ _begin_shutdown_once(state)
+ if event.completion_type != cygrpc.CompletionType.queue_timeout:
+ if not _process_event_and_continue(state, event):
+ return
# We want to force the deletion of the previous event
# ~before~ we poll again; if the event has a reference
# to a shutdown Call object, this can induce spinlock.
event = None
+def _begin_shutdown_once(state):
+ with state.lock:
+ if state.stage is _ServerStage.STARTED:
+ state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
+ state.stage = _ServerStage.GRACE
+ state.shutdown_events = []
+ state.due.add(_SHUTDOWN_TAG)
+
+
def _stop(state, grace):
with state.lock:
if state.stage is _ServerStage.STOPPED:
@@ -761,11 +789,7 @@ def _stop(state, grace):
shutdown_event.set()
return shutdown_event
else:
- if state.stage is _ServerStage.STARTED:
- state.server.shutdown(state.completion_queue, _SHUTDOWN_TAG)
- state.stage = _ServerStage.GRACE
- state.shutdown_events = []
- state.due.add(_SHUTDOWN_TAG)
+ _begin_shutdown_once(state)
shutdown_event = threading.Event()
state.shutdown_events.append(shutdown_event)
if grace is None:
@@ -836,7 +860,9 @@ class _Server(grpc.Server):
return _stop(self._state, grace)
def __del__(self):
- _stop(self._state, None)
+ # We can not grab a lock in __del__(), so set a flag to signal the
+ # serving daemon thread (if it exists) to initiate shutdown.
+ self._state.server_deallocated = True
def create_server(thread_pool, generic_rpc_handlers, interceptors, options,
diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py
index d90b34bcbd..2938a38b44 100644
--- a/src/python/grpcio/grpc/_utilities.py
+++ b/src/python/grpcio/grpc/_utilities.py
@@ -132,15 +132,12 @@ class _ChannelReadyFuture(grpc.Future):
def result(self, timeout=None):
self._block(timeout)
- return None
def exception(self, timeout=None):
self._block(timeout)
- return None
def traceback(self, timeout=None):
self._block(timeout)
- return None
def add_done_callback(self, fn):
with self._condition:
diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py
index c6ca970bee..6a1fd676ca 100644
--- a/src/python/grpcio/grpc_core_dependencies.py
+++ b/src/python/grpcio/grpc_core_dependencies.py
@@ -326,6 +326,7 @@ CORE_SOURCE_FILES = [
'src/core/ext/filters/client_channel/parse_address.cc',
'src/core/ext/filters/client_channel/proxy_mapper.cc',
'src/core/ext/filters/client_channel/proxy_mapper_registry.cc',
+ 'src/core/ext/filters/client_channel/request_routing.cc',
'src/core/ext/filters/client_channel/resolver.cc',
'src/core/ext/filters/client_channel/resolver_registry.cc',
'src/core/ext/filters/client_channel/resolver_result_parsing.cc',
diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py
index 2e91818d2c..8e2f4d30bb 100644
--- a/src/python/grpcio/grpc_version.py
+++ b/src/python/grpcio/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_channelz/grpc_version.py b/src/python/grpcio_channelz/grpc_version.py
index 16356ea402..5f3a894a2a 100644
--- a/src/python/grpcio_channelz/grpc_version.py
+++ b/src/python/grpcio_channelz/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_channelz/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py
index 0583659428..0a5bbb5504 100644
--- a/src/python/grpcio_health_checking/grpc_health/v1/health.py
+++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py
@@ -23,15 +23,61 @@ from grpc_health.v1 import health_pb2_grpc as _health_pb2_grpc
SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name['Health'].full_name
+class _Watcher():
+
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._responses = list()
+ self._open = True
+
+ def __iter__(self):
+ return self
+
+ def _next(self):
+ with self._condition:
+ while not self._responses and self._open:
+ self._condition.wait()
+ if self._responses:
+ return self._responses.pop(0)
+ else:
+ raise StopIteration()
+
+ def next(self):
+ return self._next()
+
+ def __next__(self):
+ return self._next()
+
+ def add(self, response):
+ with self._condition:
+ self._responses.append(response)
+ self._condition.notify()
+
+ def close(self):
+ with self._condition:
+ self._open = False
+ self._condition.notify()
+
+
class HealthServicer(_health_pb2_grpc.HealthServicer):
"""Servicer handling RPCs for service statuses."""
def __init__(self):
- self._server_status_lock = threading.Lock()
+ self._lock = threading.RLock()
self._server_status = {}
+ self._watchers = {}
+
+ def _on_close_callback(self, watcher, service):
+
+ def callback():
+ with self._lock:
+ self._watchers[service].remove(watcher)
+ watcher.close()
+
+ return callback
def Check(self, request, context):
- with self._server_status_lock:
+ with self._lock:
status = self._server_status.get(request.service)
if status is None:
context.set_code(grpc.StatusCode.NOT_FOUND)
@@ -39,14 +85,30 @@ class HealthServicer(_health_pb2_grpc.HealthServicer):
else:
return _health_pb2.HealthCheckResponse(status=status)
+ def Watch(self, request, context):
+ service = request.service
+ with self._lock:
+ status = self._server_status.get(service)
+ if status is None:
+ status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member
+ watcher = _Watcher()
+ watcher.add(_health_pb2.HealthCheckResponse(status=status))
+ if service not in self._watchers:
+ self._watchers[service] = set()
+ self._watchers[service].add(watcher)
+ context.add_callback(self._on_close_callback(watcher, service))
+ return watcher
+
def set(self, service, status):
"""Sets the status of a service.
- Args:
- service: string, the name of the service.
- NOTE, '' must be set.
- status: HealthCheckResponse.status enum value indicating
- the status of the service
- """
- with self._server_status_lock:
+ Args:
+ service: string, the name of the service. NOTE, '' must be set.
+ status: HealthCheckResponse.status enum value indicating the status of
+ the service
+ """
+ with self._lock:
self._server_status[service] = status
+ if service in self._watchers:
+ for watcher in self._watchers[service]:
+ watcher.add(_health_pb2.HealthCheckResponse(status=status))
diff --git a/src/python/grpcio_health_checking/grpc_version.py b/src/python/grpcio_health_checking/grpc_version.py
index 85fa762f7e..4c2d434066 100644
--- a/src/python/grpcio_health_checking/grpc_version.py
+++ b/src/python/grpcio_health_checking/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_health_checking/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_reflection/grpc_version.py b/src/python/grpcio_reflection/grpc_version.py
index e62ab169a2..6b88b2dfc5 100644
--- a/src/python/grpcio_reflection/grpc_version.py
+++ b/src/python/grpcio_reflection/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_reflection/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_status/.gitignore b/src/python/grpcio_status/.gitignore
new file mode 100644
index 0000000000..19d1523efd
--- /dev/null
+++ b/src/python/grpcio_status/.gitignore
@@ -0,0 +1,3 @@
+build/
+grpcio_status.egg-info/
+dist/
diff --git a/src/python/grpcio_status/MANIFEST.in b/src/python/grpcio_status/MANIFEST.in
new file mode 100644
index 0000000000..09b8ea721e
--- /dev/null
+++ b/src/python/grpcio_status/MANIFEST.in
@@ -0,0 +1,4 @@
+include grpc_version.py
+recursive-include grpc_status *.py
+global-exclude *.pyc
+include LICENSE
diff --git a/src/python/grpcio_status/README.rst b/src/python/grpcio_status/README.rst
new file mode 100644
index 0000000000..dc2f7b1dab
--- /dev/null
+++ b/src/python/grpcio_status/README.rst
@@ -0,0 +1,9 @@
+gRPC Python Status Proto
+===========================
+
+Reference package for GRPC Python status proto mapping.
+
+Dependencies
+------------
+
+Depends on the `grpcio` package, available from PyPI via `pip install grpcio`.
diff --git a/src/python/grpcio_status/grpc_status/BUILD.bazel b/src/python/grpcio_status/grpc_status/BUILD.bazel
new file mode 100644
index 0000000000..223a077c3f
--- /dev/null
+++ b/src/python/grpcio_status/grpc_status/BUILD.bazel
@@ -0,0 +1,14 @@
+load("@grpc_python_dependencies//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "grpc_status",
+ srcs = ["rpc_status.py",],
+ deps = [
+ "//src/python/grpcio/grpc:grpcio",
+ requirement('protobuf'),
+ requirement('googleapis-common-protos'),
+ ],
+ imports=["../",],
+)
diff --git a/src/python/grpcio_status/grpc_status/__init__.py b/src/python/grpcio_status/grpc_status/__init__.py
new file mode 100644
index 0000000000..38fdfc9c5c
--- /dev/null
+++ b/src/python/grpcio_status/grpc_status/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/python/grpcio_status/grpc_status/rpc_status.py b/src/python/grpcio_status/grpc_status/rpc_status.py
new file mode 100644
index 0000000000..87618fa541
--- /dev/null
+++ b/src/python/grpcio_status/grpc_status/rpc_status.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Reference implementation for status mapping in gRPC Python."""
+
+import collections
+
+import grpc
+
+# TODO(https://github.com/bazelbuild/bazel/issues/6844)
+# Due to Bazel issue, the namespace packages won't resolve correctly.
+# Adding this unused-import as a workaround to avoid module-not-found error
+# under Bazel builds.
+import google.protobuf # pylint: disable=unused-import
+from google.rpc import status_pb2
+
+_CODE_TO_GRPC_CODE_MAPPING = {x.value[0]: x for x in grpc.StatusCode}
+
+_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin'
+
+
+class _Status(
+ collections.namedtuple(
+ '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status):
+ pass
+
+
+def _code_to_grpc_status_code(code):
+ try:
+ return _CODE_TO_GRPC_CODE_MAPPING[code]
+ except KeyError:
+ raise ValueError('Invalid status code %s' % code)
+
+
+def from_call(call):
+ """Returns a google.rpc.status.Status message corresponding to a given grpc.Call.
+
+ This is an EXPERIMENTAL API.
+
+ Args:
+ call: A grpc.Call instance.
+
+ Returns:
+ A google.rpc.status.Status message representing the status of the RPC.
+
+ Raises:
+ ValueError: If the gRPC call's code or details are inconsistent with the
+ status code and message inside of the google.rpc.status.Status.
+ """
+ for key, value in call.trailing_metadata():
+ if key == _GRPC_DETAILS_METADATA_KEY:
+ rich_status = status_pb2.Status.FromString(value)
+ if call.code().value[0] != rich_status.code:
+ raise ValueError(
+ 'Code in Status proto (%s) doesn\'t match status code (%s)'
+ % (_code_to_grpc_status_code(rich_status.code),
+ call.code()))
+ if call.details() != rich_status.message:
+ raise ValueError(
+ 'Message in Status proto (%s) doesn\'t match status details (%s)'
+ % (rich_status.message, call.details()))
+ return rich_status
+ return None
+
+
+def to_status(status):
+ """Convert a google.rpc.status.Status message to grpc.Status.
+
+ This is an EXPERIMENTAL API.
+
+ Args:
+ status: a google.rpc.status.Status message representing the non-OK status
+ to terminate the RPC with and communicate it to the client.
+
+ Returns:
+ A grpc.Status instance representing the input google.rpc.status.Status message.
+ """
+ return _Status(
+ code=_code_to_grpc_status_code(status.code),
+ details=status.message,
+ trailing_metadata=((_GRPC_DETAILS_METADATA_KEY,
+ status.SerializeToString()),))
diff --git a/src/python/grpcio_status/grpc_version.py b/src/python/grpcio_status/grpc_version.py
new file mode 100644
index 0000000000..2e58eb3b26
--- /dev/null
+++ b/src/python/grpcio_status/grpc_version.py
@@ -0,0 +1,17 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_status/grpc_version.py.template`!!!
+
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_status/setup.py b/src/python/grpcio_status/setup.py
new file mode 100644
index 0000000000..983d3ea430
--- /dev/null
+++ b/src/python/grpcio_status/setup.py
@@ -0,0 +1,93 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Setup module for the GRPC Python package's status mapping."""
+
+import os
+
+import setuptools
+
+# Ensure we're in the proper directory whether or not we're being used by pip.
+os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+# Break import-style to ensure we can actually find our local modules.
+import grpc_version
+
+
+class _NoOpCommand(setuptools.Command):
+ """No-op command."""
+
+ description = ''
+ user_options = []
+
+ def initialize_options(self):
+ pass
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ pass
+
+
+CLASSIFIERS = [
+ 'Development Status :: 5 - Production/Stable',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'License :: OSI Approved :: Apache Software License',
+]
+
+PACKAGE_DIRECTORIES = {
+ '': '.',
+}
+
+INSTALL_REQUIRES = (
+ 'protobuf>=3.6.0',
+ 'grpcio>={version}'.format(version=grpc_version.VERSION),
+ 'googleapis-common-protos>=1.5.5',
+)
+
+try:
+ import status_commands as _status_commands
+ # we are in the build environment, otherwise the above import fails
+ COMMAND_CLASS = {
+ # Run preprocess from the repository *before* doing any packaging!
+ 'preprocess': _status_commands.Preprocess,
+ 'build_package_protos': _NoOpCommand,
+ }
+except ImportError:
+ COMMAND_CLASS = {
+ # wire up commands to no-op not to break the external dependencies
+ 'preprocess': _NoOpCommand,
+ 'build_package_protos': _NoOpCommand,
+ }
+
+setuptools.setup(
+ name='grpcio-status',
+ version=grpc_version.VERSION,
+ description='Status proto mapping for gRPC',
+ author='The gRPC Authors',
+ author_email='grpc-io@googlegroups.com',
+ url='https://grpc.io',
+ license='Apache License 2.0',
+ classifiers=CLASSIFIERS,
+ package_dir=PACKAGE_DIRECTORIES,
+ packages=setuptools.find_packages('.'),
+ install_requires=INSTALL_REQUIRES,
+ cmdclass=COMMAND_CLASS)
diff --git a/src/python/grpcio_status/status_commands.py b/src/python/grpcio_status/status_commands.py
new file mode 100644
index 0000000000..78cd497f62
--- /dev/null
+++ b/src/python/grpcio_status/status_commands.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Provides distutils command classes for the GRPC Python setup process."""
+
+import os
+import shutil
+
+import setuptools
+
+ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
+LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE')
+
+
+class Preprocess(setuptools.Command):
+ """Command to copy LICENSE from root directory."""
+
+ description = ''
+ user_options = []
+
+ def initialize_options(self):
+ pass
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ if os.path.isfile(LICENSE):
+ shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE'))
diff --git a/src/python/grpcio_testing/grpc_testing/_server/_handler.py b/src/python/grpcio_testing/grpc_testing/_server/_handler.py
index 0e3404b0d0..100d8195f6 100644
--- a/src/python/grpcio_testing/grpc_testing/_server/_handler.py
+++ b/src/python/grpcio_testing/grpc_testing/_server/_handler.py
@@ -185,7 +185,7 @@ class _Handler(Handler):
elif self._code is None:
self._condition.wait()
else:
- return self._trailing_metadata, self._code, self._details,
+ return self._trailing_metadata, self._code, self._details
def expire(self):
with self._condition:
diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
index 90eeb130d3..5b1dfeacdf 100644
--- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
+++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
@@ -70,6 +70,9 @@ class ServicerContext(grpc.ServicerContext):
def abort(self, code, details):
raise NotImplementedError()
+ def abort_with_status(self, status):
+ raise NotImplementedError()
+
def set_code(self, code):
self._rpc.set_code(code)
diff --git a/src/python/grpcio_testing/grpc_version.py b/src/python/grpcio_testing/grpc_version.py
index 7b4c1695fa..d4c5d94ecb 100644
--- a/src/python/grpcio_testing/grpc_version.py
+++ b/src/python/grpcio_testing/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_testing/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py
index 65e9a99950..582ce898de 100644
--- a/src/python/grpcio_tests/commands.py
+++ b/src/python/grpcio_tests/commands.py
@@ -22,7 +22,6 @@ import re
import shutil
import subprocess
import sys
-import traceback
import setuptools
from setuptools.command import build_ext
@@ -133,6 +132,16 @@ class TestGevent(setuptools.Command):
# TODO(https://github.com/grpc/grpc/issues/15411) unpin gevent version
# This test will stuck while running higher version of gevent
'unit._auth_context_test.AuthContextTest.testSessionResumption',
+ # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests
+ 'unit._metadata_flags_test',
+ 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call',
+ 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call',
+ 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call',
+ 'unit._exit_test.ExitTest.test_in_flight_stream_stream_call',
+ 'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call',
+ 'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call',
+ 'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call',
+ 'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list',
# TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests
'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels',
'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets',
diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py
index 2fcd1ad617..e1645ab1b8 100644
--- a/src/python/grpcio_tests/grpc_version.py
+++ b/src/python/grpcio_tests/grpc_version.py
@@ -14,4 +14,4 @@
# AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!!
-VERSION = '1.18.0.dev0'
+VERSION = '1.19.0.dev0'
diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py
index f56425ac6d..800b865da6 100644
--- a/src/python/grpcio_tests/setup.py
+++ b/src/python/grpcio_tests/setup.py
@@ -37,12 +37,19 @@ PACKAGE_DIRECTORIES = {
}
INSTALL_REQUIRES = (
- 'coverage>=4.0', 'enum34>=1.0.4',
+ 'coverage>=4.0',
+ 'enum34>=1.0.4',
'grpcio>={version}'.format(version=grpc_version.VERSION),
- 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION),
+ # TODO(https://github.com/pypa/warehouse/issues/5196)
+ # Re-enable it once we got the name back
+ # 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION),
+ 'grpcio-status>={version}'.format(version=grpc_version.VERSION),
'grpcio-tools>={version}'.format(version=grpc_version.VERSION),
'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION),
- 'oauth2client>=1.4.7', 'protobuf>=3.6.0', 'six>=1.10', 'google-auth>=1.0.0',
+ 'oauth2client>=1.4.7',
+ 'protobuf>=3.6.0',
+ 'six>=1.10',
+ 'google-auth>=1.0.0',
'requests>=2.14.2')
if not PY3:
diff --git a/src/python/grpcio_tests/tests/_runner.py b/src/python/grpcio_tests/tests/_runner.py
index eaaa027e61..9ef0f17684 100644
--- a/src/python/grpcio_tests/tests/_runner.py
+++ b/src/python/grpcio_tests/tests/_runner.py
@@ -203,7 +203,7 @@ class Runner(object):
check_kill_self()
time.sleep(0)
case_thread.join()
- except:
+ except: # pylint: disable=try-except-raise
# re-raise the exception after forcing the with-block to end
raise
result.set_output(augmented_case.case, stdout_pipe.output(),
diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
index 8ca5189522..c63ff5cd84 100644
--- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
+++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
@@ -88,11 +88,10 @@ def _generate_channel_server_pairs(n):
def _close_channel_server_pairs(pairs):
for pair in pairs:
pair.server.stop(None)
- # TODO(ericgribkoff) This del should not be required
- del pair.server
pair.channel.close()
+@unittest.skip('https://github.com/pypa/warehouse/issues/5196')
class ChannelzServicerTest(unittest.TestCase):
def _send_successful_unary_unary(self, idx):
diff --git a/src/python/grpcio_tests/tests/health_check/BUILD.bazel b/src/python/grpcio_tests/tests/health_check/BUILD.bazel
index 19e1e1b2e1..77bc61aa30 100644
--- a/src/python/grpcio_tests/tests/health_check/BUILD.bazel
+++ b/src/python/grpcio_tests/tests/health_check/BUILD.bazel
@@ -9,6 +9,7 @@ py_test(
"//src/python/grpcio/grpc:grpcio",
"//src/python/grpcio_health_checking/grpc_health/v1:grpc_health",
"//src/python/grpcio_tests/tests/unit:test_common",
+ "//src/python/grpcio_tests/tests/unit/framework/common:common",
],
imports = ["../../",],
)
diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
index 350b5eebe5..35794987bc 100644
--- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
+++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Tests of grpc_health.v1.health."""
+import threading
+import time
import unittest
import grpc
@@ -21,58 +23,199 @@ from grpc_health.v1 import health_pb2
from grpc_health.v1 import health_pb2_grpc
from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+from six.moves import queue
+
+_SERVING_SERVICE = 'grpc.test.TestServiceServing'
+_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown'
+_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing'
+_WATCH_SERVICE = 'grpc.test.WatchService'
+
+
+def _consume_responses(response_iterator, response_queue):
+ for response in response_iterator:
+ response_queue.put(response)
class HealthServicerTest(unittest.TestCase):
def setUp(self):
- servicer = health.HealthServicer()
- servicer.set('', health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceServing',
- health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceUnknown',
- health_pb2.HealthCheckResponse.UNKNOWN)
- servicer.set('grpc.test.TestServiceNotServing',
- health_pb2.HealthCheckResponse.NOT_SERVING)
+ self._servicer = health.HealthServicer()
+ self._servicer.set('', health_pb2.HealthCheckResponse.SERVING)
+ self._servicer.set(_SERVING_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ self._servicer.set(_UNKNOWN_SERVICE,
+ health_pb2.HealthCheckResponse.UNKNOWN)
+ self._servicer.set(_NOT_SERVING_SERVICE,
+ health_pb2.HealthCheckResponse.NOT_SERVING)
self._server = test_common.test_server()
port = self._server.add_insecure_port('[::]:0')
- health_pb2_grpc.add_HealthServicer_to_server(servicer, self._server)
+ health_pb2_grpc.add_HealthServicer_to_server(self._servicer,
+ self._server)
self._server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = health_pb2_grpc.HealthStub(channel)
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = health_pb2_grpc.HealthStub(self._channel)
- def test_empty_service(self):
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
+
+ def test_check_empty_service(self):
request = health_pb2.HealthCheckRequest()
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceServing')
+ def test_check_serving_service(self):
+ request = health_pb2.HealthCheckRequest(service=_SERVING_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_unknown_serivce(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceUnknown')
+ def test_check_unknown_serivce(self):
+ request = health_pb2.HealthCheckRequest(service=_UNKNOWN_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status)
- def test_not_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceNotServing')
+ def test_check_not_serving_service(self):
+ request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE)
resp = self._stub.Check(request)
self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
resp.status)
- def test_not_found_service(self):
+ def test_check_not_found_service(self):
request = health_pb2.HealthCheckRequest(service='not-found')
with self.assertRaises(grpc.RpcError) as context:
resp = self._stub.Check(request)
self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
+ def test_watch_empty_service(self):
+ request = health_pb2.HealthCheckRequest(service='')
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response.status)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_watch_new_service(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.NOT_SERVING)
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
+ response.status)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_watch_service_isolation(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ self._servicer.set('some-other-service',
+ health_pb2.HealthCheckResponse.SERVING)
+ with self.assertRaises(queue.Empty):
+ response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+
+ rendezvous.cancel()
+ thread.join()
+ self.assertTrue(response_queue.empty())
+
+ def test_two_watchers(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue1 = queue.Queue()
+ response_queue2 = queue.Queue()
+ rendezvous1 = self._stub.Watch(request)
+ rendezvous2 = self._stub.Watch(request)
+ thread1 = threading.Thread(
+ target=_consume_responses, args=(rendezvous1, response_queue1))
+ thread2 = threading.Thread(
+ target=_consume_responses, args=(rendezvous2, response_queue2))
+ thread1.start()
+ thread2.start()
+
+ response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
+ response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response1.status)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response2.status)
+
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ response1 = response_queue1.get(timeout=test_constants.SHORT_TIMEOUT)
+ response2 = response_queue2.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response1.status)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING,
+ response2.status)
+
+ rendezvous1.cancel()
+ rendezvous2.cancel()
+ thread1.join()
+ thread2.join()
+ self.assertTrue(response_queue1.empty())
+ self.assertTrue(response_queue2.empty())
+
+ def test_cancelled_watch_removed_from_watch_list(self):
+ request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE)
+ response_queue = queue.Queue()
+ rendezvous = self._stub.Watch(request)
+ thread = threading.Thread(
+ target=_consume_responses, args=(rendezvous, response_queue))
+ thread.start()
+
+ response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
+ response.status)
+
+ rendezvous.cancel()
+ self._servicer.set(_WATCH_SERVICE,
+ health_pb2.HealthCheckResponse.SERVING)
+ thread.join()
+
+ # Wait, if necessary, for serving thread to process client cancellation
+ timeout = time.time() + test_constants.SHORT_TIMEOUT
+ while time.time() < timeout and self._servicer._watchers[_WATCH_SERVICE]:
+ time.sleep(1)
+ self.assertFalse(self._servicer._watchers[_WATCH_SERVICE],
+ 'watch set should be empty')
+ self.assertTrue(response_queue.empty())
+
def test_health_service_name(self):
self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health')
diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py
index 698c37017f..56cb29477c 100644
--- a/src/python/grpcio_tests/tests/interop/client.py
+++ b/src/python/grpcio_tests/tests/interop/client.py
@@ -54,7 +54,6 @@ def _args():
help='replace platform root CAs with ca.pem')
parser.add_argument(
'--server_host_override',
- default="foo.test.google.fr",
type=str,
help='the server host to which to claim to connect')
parser.add_argument(
@@ -100,10 +99,13 @@ def _stub(args):
channel_credentials = grpc.composite_channel_credentials(
channel_credentials, call_credentials)
- channel = grpc.secure_channel(target, channel_credentials, ((
- 'grpc.ssl_target_name_override',
- args.server_host_override,
- ),))
+ channel_opts = None
+ if args.server_host_override:
+ channel_opts = ((
+ 'grpc.ssl_target_name_override',
+ args.server_host_override,
+ ),)
+ channel = grpc.secure_channel(target, channel_credentials, channel_opts)
else:
channel = grpc.insecure_channel(target)
if args.test_case == "unimplemented_service":
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
index e21ea0010a..2b735526cb 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
@@ -144,7 +144,7 @@ class _ProtoBeforeGrpcProtocStyle(object):
absolute_proto_file_names)
pb2_grpc_protoc_exit_code = _protoc(
proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names)
- return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code,
+ return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code
class _GrpcBeforeProtoProtocStyle(object):
@@ -160,7 +160,7 @@ class _GrpcBeforeProtoProtocStyle(object):
proto_path, None, 'grpc_2_0', python_out, absolute_proto_file_names)
pb2_protoc_exit_code = _protoc(proto_path, python_out, None, None,
absolute_proto_file_names)
- return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code,
+ return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code
_PROTOC_STYLES = (
@@ -243,9 +243,9 @@ class _Test(six.with_metaclass(abc.ABCMeta, unittest.TestCase)):
def _services_modules(self):
if self.PROTOC_STYLE.grpc_in_pb2_expected():
- return self._services_pb2, self._services_pb2_grpc,
+ return self._services_pb2, self._services_pb2_grpc
else:
- return self._services_pb2_grpc,
+ return (self._services_pb2_grpc,)
def test_imported_attributes(self):
self._protoc()
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
index b46e53315e..43c90af6a7 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
@@ -223,7 +223,7 @@ def _CreateService(payload_pb2, responses_pb2, service_pb2):
server.start()
channel = implementations.insecure_channel('localhost', port)
stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
- yield servicer_methods, stub,
+ yield servicer_methods, stub
server.stop(0)
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py
index 0488450740..fac0e44e5a 100644
--- a/src/python/grpcio_tests/tests/qps/benchmark_client.py
+++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py
@@ -180,7 +180,7 @@ class StreamingSyncBenchmarkClient(BenchmarkClient):
self._streams = [
_SyncStream(self._stub, self._generic, self._request,
self._handle_response)
- for _ in xrange(config.outstanding_rpcs_per_channel)
+ for _ in range(config.outstanding_rpcs_per_channel)
]
self._curr_stream = 0
diff --git a/src/python/grpcio_tests/tests/qps/client_runner.py b/src/python/grpcio_tests/tests/qps/client_runner.py
index e79abab3c7..a57524c74e 100644
--- a/src/python/grpcio_tests/tests/qps/client_runner.py
+++ b/src/python/grpcio_tests/tests/qps/client_runner.py
@@ -77,7 +77,7 @@ class ClosedLoopClientRunner(ClientRunner):
def start(self):
self._is_running = True
self._client.start()
- for _ in xrange(self._request_count):
+ for _ in range(self._request_count):
self._client.send_request()
def stop(self):
diff --git a/src/python/grpcio_tests/tests/qps/worker_server.py b/src/python/grpcio_tests/tests/qps/worker_server.py
index 337a94b546..a03367ec63 100644
--- a/src/python/grpcio_tests/tests/qps/worker_server.py
+++ b/src/python/grpcio_tests/tests/qps/worker_server.py
@@ -109,7 +109,7 @@ class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
start_time = time.time()
# Create a client for each channel
- for i in xrange(config.client_channels):
+ for i in range(config.client_channels):
server = config.server_targets[i % len(config.server_targets)]
runner = self._create_client_runner(server, config, qps_data)
client_runners.append(runner)
diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
index bcd9e14a38..560f6d3ddb 100644
--- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
+++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
@@ -56,8 +56,12 @@ class ReflectionServicerTest(unittest.TestCase):
port = self._server.add_insecure_port('[::]:0')
self._server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = reflection_pb2_grpc.ServerReflectionStub(channel)
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel)
+
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
def testFileByName(self):
requests = (
diff --git a/src/python/grpcio_tests/tests/status/BUILD.bazel b/src/python/grpcio_tests/tests/status/BUILD.bazel
new file mode 100644
index 0000000000..937e50498e
--- /dev/null
+++ b/src/python/grpcio_tests/tests/status/BUILD.bazel
@@ -0,0 +1,19 @@
+load("@grpc_python_dependencies//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_test(
+ name = "grpc_status_test",
+ srcs = ["_grpc_status_test.py"],
+ main = "_grpc_status_test.py",
+ size = "small",
+ deps = [
+ "//src/python/grpcio/grpc:grpcio",
+ "//src/python/grpcio_status/grpc_status:grpc_status",
+ "//src/python/grpcio_tests/tests/unit:test_common",
+ "//src/python/grpcio_tests/tests/unit/framework/common:common",
+ requirement('protobuf'),
+ requirement('googleapis-common-protos'),
+ ],
+ imports = ["../../",],
+)
diff --git a/src/python/grpcio_tests/tests/status/__init__.py b/src/python/grpcio_tests/tests/status/__init__.py
new file mode 100644
index 0000000000..38fdfc9c5c
--- /dev/null
+++ b/src/python/grpcio_tests/tests/status/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py
new file mode 100644
index 0000000000..519c372a96
--- /dev/null
+++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py
@@ -0,0 +1,173 @@
+# Copyright 2018 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests of grpc_status."""
+
+import unittest
+
+import logging
+import traceback
+
+import grpc
+from grpc_status import rpc_status
+
+from tests.unit import test_common
+
+from google.protobuf import any_pb2
+from google.rpc import code_pb2, status_pb2, error_details_pb2
+
+_STATUS_OK = '/test/StatusOK'
+_STATUS_NOT_OK = '/test/StatusNotOk'
+_ERROR_DETAILS = '/test/ErrorDetails'
+_INCONSISTENT = '/test/Inconsistent'
+_INVALID_CODE = '/test/InvalidCode'
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x01\x01\x01'
+
+_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin'
+
+_STATUS_DETAILS = 'This is an error detail'
+_STATUS_DETAILS_ANOTHER = 'This is another error detail'
+
+
+def _ok_unary_unary(request, servicer_context):
+ return _RESPONSE
+
+
+def _not_ok_unary_unary(request, servicer_context):
+ servicer_context.abort(grpc.StatusCode.INTERNAL, _STATUS_DETAILS)
+
+
+def _error_details_unary_unary(request, servicer_context):
+ details = any_pb2.Any()
+ details.Pack(
+ error_details_pb2.DebugInfo(
+ stack_entries=traceback.format_stack(),
+ detail='Intentionally invoked'))
+ rich_status = status_pb2.Status(
+ code=code_pb2.INTERNAL,
+ message=_STATUS_DETAILS,
+ details=[details],
+ )
+ servicer_context.abort_with_status(rpc_status.to_status(rich_status))
+
+
+def _inconsistent_unary_unary(request, servicer_context):
+ rich_status = status_pb2.Status(
+ code=code_pb2.INTERNAL,
+ message=_STATUS_DETAILS,
+ )
+ servicer_context.set_code(grpc.StatusCode.NOT_FOUND)
+ servicer_context.set_details(_STATUS_DETAILS_ANOTHER)
+ # User put inconsistent status information in trailing metadata
+ servicer_context.set_trailing_metadata(((_GRPC_DETAILS_METADATA_KEY,
+ rich_status.SerializeToString()),))
+
+
+def _invalid_code_unary_unary(request, servicer_context):
+ rich_status = status_pb2.Status(
+ code=42,
+ message='Invalid code',
+ )
+ servicer_context.abort_with_status(rpc_status.to_status(rich_status))
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _STATUS_OK:
+ return grpc.unary_unary_rpc_method_handler(_ok_unary_unary)
+ elif handler_call_details.method == _STATUS_NOT_OK:
+ return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary)
+ elif handler_call_details.method == _ERROR_DETAILS:
+ return grpc.unary_unary_rpc_method_handler(
+ _error_details_unary_unary)
+ elif handler_call_details.method == _INCONSISTENT:
+ return grpc.unary_unary_rpc_method_handler(
+ _inconsistent_unary_unary)
+ elif handler_call_details.method == _INVALID_CODE:
+ return grpc.unary_unary_rpc_method_handler(
+ _invalid_code_unary_unary)
+ else:
+ return None
+
+
+class StatusTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server = test_common.test_server()
+ self._server.add_generic_rpc_handlers((_GenericHandler(),))
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
+
+ def test_status_ok(self):
+ _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST)
+
+ # Succeed RPC doesn't have status
+ status = rpc_status.from_call(call)
+ self.assertIs(status, None)
+
+ def test_status_not_ok(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST)
+ rpc_error = exception_context.exception
+
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
+ # Failed RPC doesn't automatically generate status
+ status = rpc_status.from_call(rpc_error)
+ self.assertIs(status, None)
+
+ def test_error_details(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST)
+ rpc_error = exception_context.exception
+
+ status = rpc_status.from_call(rpc_error)
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
+ self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL'))
+
+ # Check if the underlying proto message is intact
+ self.assertEqual(status.details[0].Is(
+ error_details_pb2.DebugInfo.DESCRIPTOR), True)
+ info = error_details_pb2.DebugInfo()
+ status.details[0].Unpack(info)
+ self.assertIn('_error_details_unary_unary', info.stack_entries[-1])
+
+ def test_code_message_validation(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST)
+ rpc_error = exception_context.exception
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND)
+
+ # Code/Message validation failed
+ self.assertRaises(ValueError, rpc_status.from_call, rpc_error)
+
+ def test_invalid_code(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST)
+ rpc_error = exception_context.exception
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)
+ # Invalid status code exception raised during coversion
+ self.assertIn('Invalid status code', rpc_error.details())
+
+
+if __name__ == '__main__':
+ logging.basicConfig()
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py
index 41f2e1b6c2..4c35b05044 100644
--- a/src/python/grpcio_tests/tests/stress/client.py
+++ b/src/python/grpcio_tests/tests/stress/client.py
@@ -71,7 +71,6 @@ def _args():
'--use_tls', help='Whether to use TLS', default=False, type=bool)
parser.add_argument(
'--server_host_override',
- default="foo.test.google.fr",
help='the server host to which to claim to connect',
type=str)
return parser.parse_args()
@@ -132,9 +131,9 @@ def run_test(args):
server.start()
for test_server_target in test_server_targets:
- for _ in xrange(args.num_channels_per_server):
+ for _ in range(args.num_channels_per_server):
channel = _get_channel(test_server_target, args)
- for _ in xrange(args.num_stubs_per_channel):
+ for _ in range(args.num_stubs_per_channel):
stub = test_pb2_grpc.TestServiceStub(channel)
runner = test_runner.TestRunner(stub, test_cases, hist,
exception_queue, stop_event)
diff --git a/src/python/grpcio_tests/tests/testing/_client_application.py b/src/python/grpcio_tests/tests/testing/_client_application.py
index 3ddeba2373..4d42df0389 100644
--- a/src/python/grpcio_tests/tests/testing/_client_application.py
+++ b/src/python/grpcio_tests/tests/testing/_client_application.py
@@ -130,9 +130,9 @@ def _run_stream_stream(stub):
request_pipe = _Pipe()
response_iterator = stub.StreStre(iter(request_pipe))
request_pipe.add(_application_common.STREAM_STREAM_REQUEST)
- first_responses = next(response_iterator), next(response_iterator),
+ first_responses = next(response_iterator), next(response_iterator)
request_pipe.add(_application_common.STREAM_STREAM_REQUEST)
- second_responses = next(response_iterator), next(response_iterator),
+ second_responses = next(response_iterator), next(response_iterator)
request_pipe.close()
try:
next(response_iterator)
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index 9cffd3df19..de4c2c1fdd 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -15,10 +15,12 @@
"protoc_plugin._split_definitions_test.SplitProtoSingleProtocExecutionProtocStyleTest",
"protoc_plugin.beta_python_plugin_test.PythonPluginTest",
"reflection._reflection_servicer_test.ReflectionServicerTest",
+ "status._grpc_status_test.StatusTest",
"testing._client_test.ClientTest",
"testing._server_test.FirstServiceServicerTest",
"testing._time_test.StrictFakeTimeTest",
"testing._time_test.StrictRealTimeTest",
+ "unit._abort_test.AbortTest",
"unit._api_test.AllTest",
"unit._api_test.ChannelConnectivityTest",
"unit._api_test.ChannelTest",
@@ -55,12 +57,14 @@
"unit._reconnect_test.ReconnectTest",
"unit._resource_exhausted_test.ResourceExhaustedTest",
"unit._rpc_test.RPCTest",
+ "unit._server_shutdown_test.ServerShutdown",
"unit._server_ssl_cert_config_test.ServerSSLCertConfigFetcherParamsChecks",
"unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse",
"unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth",
"unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithoutClientAuth",
"unit._server_test.ServerTest",
"unit._session_cache_test.SSLSessionCacheTest",
+ "unit._version_test.VersionTest",
"unit.beta._beta_features_test.BetaFeaturesTest",
"unit.beta._beta_features_test.ContextManagementAndLifecycleTest",
"unit.beta._connectivity_channel_test.ConnectivityStatesTest",
diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel
index de33b81e32..a9bcd9f304 100644
--- a/src/python/grpcio_tests/tests/unit/BUILD.bazel
+++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel
@@ -3,9 +3,11 @@ load("@grpc_python_dependencies//:requirements.bzl", "requirement")
package(default_visibility = ["//visibility:public"])
GRPCIO_TESTS_UNIT = [
+ "_abort_test.py",
"_api_test.py",
"_auth_context_test.py",
"_auth_test.py",
+ "_version_test.py",
"_channel_args_test.py",
"_channel_close_test.py",
"_channel_connectivity_test.py",
@@ -27,6 +29,7 @@ GRPCIO_TESTS_UNIT = [
# TODO(ghostwriternr): To be added later.
# "_server_ssl_cert_config_test.py",
"_server_test.py",
+ "_server_shutdown_test.py",
"_session_cache_test.py",
]
@@ -49,6 +52,11 @@ py_library(
)
py_library(
+ name = "_server_shutdown_scenarios",
+ srcs = ["_server_shutdown_scenarios.py"],
+)
+
+py_library(
name = "_thread_pool",
srcs = ["_thread_pool.py"],
)
@@ -69,6 +77,7 @@ py_library(
":resources",
":test_common",
":_exit_scenarios",
+ ":_server_shutdown_scenarios",
":_thread_pool",
":_from_grpc_import_star",
"//src/python/grpcio_tests/tests/unit/framework/common",
diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py
new file mode 100644
index 0000000000..6438f6897a
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_abort_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests server context abort mechanism"""
+
+import unittest
+import collections
+import logging
+
+import grpc
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_ABORT = '/test/abort'
+_ABORT_WITH_STATUS = '/test/AbortWithStatus'
+_INVALID_CODE = '/test/InvalidCode'
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x00\x00\x00'
+
+_ABORT_DETAILS = 'Abandon ship!'
+_ABORT_METADATA = (('a-trailing-metadata', '42'),)
+
+
+class _Status(
+ collections.namedtuple(
+ '_Status', ('code', 'details', 'trailing_metadata')), grpc.Status):
+ pass
+
+
+def abort_unary_unary(request, servicer_context):
+ servicer_context.abort(
+ grpc.StatusCode.INTERNAL,
+ _ABORT_DETAILS,
+ )
+ raise Exception('This line should not be executed!')
+
+
+def abort_with_status_unary_unary(request, servicer_context):
+ servicer_context.abort_with_status(
+ _Status(
+ code=grpc.StatusCode.INTERNAL,
+ details=_ABORT_DETAILS,
+ trailing_metadata=_ABORT_METADATA,
+ ))
+ raise Exception('This line should not be executed!')
+
+
+def invalid_code_unary_unary(request, servicer_context):
+ servicer_context.abort(
+ 42,
+ _ABORT_DETAILS,
+ )
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _ABORT:
+ return grpc.unary_unary_rpc_method_handler(abort_unary_unary)
+ elif handler_call_details.method == _ABORT_WITH_STATUS:
+ return grpc.unary_unary_rpc_method_handler(
+ abort_with_status_unary_unary)
+ elif handler_call_details.method == _INVALID_CODE:
+ return grpc.stream_stream_rpc_method_handler(
+ invalid_code_unary_unary)
+ else:
+ return None
+
+
+class AbortTest(unittest.TestCase):
+
+ def setUp(self):
+ self._server = test_common.test_server()
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(),))
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._channel.close()
+ self._server.stop(0)
+
+ def test_abort(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_ABORT)(_REQUEST)
+ rpc_error = exception_context.exception
+
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
+ self.assertEqual(rpc_error.details(), _ABORT_DETAILS)
+
+ def test_abort_with_status(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST)
+ rpc_error = exception_context.exception
+
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
+ self.assertEqual(rpc_error.details(), _ABORT_DETAILS)
+ self.assertEqual(rpc_error.trailing_metadata(), _ABORT_METADATA)
+
+ def test_invalid_code(self):
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary(_INVALID_CODE)(_REQUEST)
+ rpc_error = exception_context.exception
+
+ self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)
+ self.assertEqual(rpc_error.details(), _ABORT_DETAILS)
+
+
+if __name__ == '__main__':
+ logging.basicConfig()
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py
index 38072861a4..0dc6a8718c 100644
--- a/src/python/grpcio_tests/tests/unit/_api_test.py
+++ b/src/python/grpcio_tests/tests/unit/_api_test.py
@@ -32,6 +32,7 @@ class AllTest(unittest.TestCase):
'Future',
'ChannelConnectivity',
'StatusCode',
+ 'Status',
'RpcError',
'RpcContext',
'Call',
@@ -100,6 +101,7 @@ class ChannelTest(unittest.TestCase):
def test_secure_channel(self):
channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel('google.com:443', channel_credentials)
+ channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py
index b1b5bbdcab..96c4e9ec76 100644
--- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py
+++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py
@@ -71,8 +71,8 @@ class AuthContextTest(unittest.TestCase):
port = server.add_insecure_port('[::]:0')
server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
- response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ with grpc.insecure_channel('localhost:%d' % port) as channel:
+ response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
server.stop(None)
auth_data = pickle.loads(response)
@@ -98,6 +98,7 @@ class AuthContextTest(unittest.TestCase):
channel_creds,
options=_PROPERTY_OPTIONS)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ channel.close()
server.stop(None)
auth_data = pickle.loads(response)
@@ -132,6 +133,7 @@ class AuthContextTest(unittest.TestCase):
options=_PROPERTY_OPTIONS)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ channel.close()
server.stop(None)
auth_data = pickle.loads(response)
diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
index 727fb7d65f..565bd39b3a 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
@@ -75,6 +75,8 @@ class ChannelConnectivityTest(unittest.TestCase):
channel.unsubscribe(callback.update)
fifth_connectivities = callback.connectivities()
+ channel.close()
+
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities)
@@ -108,7 +110,8 @@ class ChannelConnectivityTest(unittest.TestCase):
_ready_in_connectivities)
second_callback.block_until_connectivities_satisfy(
_ready_in_connectivities)
- del channel
+ channel.close()
+ server.stop(None)
self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
first_connectivities)
@@ -139,6 +142,7 @@ class ChannelConnectivityTest(unittest.TestCase):
callback.block_until_connectivities_satisfy(
_last_connectivity_is_not_ready)
channel.unsubscribe(callback.update)
+ channel.close()
self.assertFalse(thread_pool.was_used())
diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
index 345460ef40..46a4eb9bb6 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
@@ -60,6 +60,8 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertTrue(ready_future.done())
self.assertFalse(ready_future.running())
+ channel.close()
+
def test_immediately_connectable_channel_connectivity(self):
thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
server = grpc.server(thread_pool, options=(('grpc.so_reuseport', 0),))
@@ -84,6 +86,9 @@ class ChannelReadyFutureTest(unittest.TestCase):
self.assertFalse(ready_future.running())
self.assertFalse(thread_pool.was_used())
+ channel.close()
+ server.stop(None)
+
if __name__ == '__main__':
logging.basicConfig()
diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py
index 876d8e827e..87884a19dc 100644
--- a/src/python/grpcio_tests/tests/unit/_compression_test.py
+++ b/src/python/grpcio_tests/tests/unit/_compression_test.py
@@ -77,6 +77,9 @@ class CompressionTest(unittest.TestCase):
self._port = self._server.add_insecure_port('[::]:0')
self._server.start()
+ def tearDown(self):
+ self._server.stop(None)
+
def testUnary(self):
request = b'\x00' * 100
@@ -102,6 +105,7 @@ class CompressionTest(unittest.TestCase):
response = multi_callable(
request, metadata=[('grpc-internal-encoding-request', 'gzip')])
self.assertEqual(request, response)
+ compressed_channel.close()
def testStreaming(self):
request = b'\x00' * 100
@@ -115,6 +119,7 @@ class CompressionTest(unittest.TestCase):
call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))
for response in call:
self.assertEqual(request, response)
+ compressed_channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
index aeb02458a7..5a5dedd5f2 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py
@@ -27,6 +27,7 @@ def _get_number_active_threads():
class ForkPosixTester(unittest.TestCase):
def setUp(self):
+ self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT
cygrpc._GRPC_ENABLE_FORK_SUPPORT = True
def testForkManagedThread(self):
@@ -50,6 +51,9 @@ class ForkPosixTester(unittest.TestCase):
thread.join()
self.assertEqual(0, _get_number_active_threads())
+ def tearDown(self):
+ cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag
+
@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests')
class ForkWindowsTester(unittest.TestCase):
diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
index 3e8393b53c..f27ea422d0 100644
--- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py
+++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
@@ -96,6 +96,7 @@ class EmptyMessageTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
index 6c551df3ec..81de1dae1d 100644
--- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
+++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
@@ -71,6 +71,7 @@ class ErrorMessageEncodingTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testMessageEncoding(self):
for message in _UNICODE_ERROR_MESSAGES:
diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py
index 5226537579..b429ee089f 100644
--- a/src/python/grpcio_tests/tests/unit/_exit_test.py
+++ b/src/python/grpcio_tests/tests/unit/_exit_test.py
@@ -71,7 +71,6 @@ def wait(process):
process.wait()
-@unittest.skip('https://github.com/grpc/grpc/issues/7311')
class ExitTest(unittest.TestCase):
def test_unstarted_server(self):
@@ -130,6 +129,8 @@ class ExitTest(unittest.TestCase):
stderr=sys.stderr)
interrupt_and_wait(process)
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_unary_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
@@ -138,6 +139,8 @@ class ExitTest(unittest.TestCase):
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
@@ -145,6 +148,8 @@ class ExitTest(unittest.TestCase):
stderr=sys.stderr)
interrupt_and_wait(process)
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
@@ -153,6 +158,8 @@ class ExitTest(unittest.TestCase):
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
@@ -161,6 +168,8 @@ class ExitTest(unittest.TestCase):
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_partial_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND +
@@ -169,6 +178,8 @@ class ExitTest(unittest.TestCase):
stderr=sys.stderr)
interrupt_and_wait(process)
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_partial_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND +
@@ -178,6 +189,8 @@ class ExitTest(unittest.TestCase):
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ @unittest.skipIf(os.name == 'nt',
+ 'os.kill does not have required permission on Windows')
def test_in_flight_partial_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND +
diff --git a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py
index ad847ae03e..1ada25382d 100644
--- a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py
+++ b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py
@@ -14,7 +14,7 @@
_BEFORE_IMPORT = tuple(globals())
-from grpc import * # pylint: disable=wildcard-import
+from grpc import * # pylint: disable=wildcard-import,unused-wildcard-import
_AFTER_IMPORT = tuple(globals())
diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
index 99db0ac58b..a647e5e720 100644
--- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py
+++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py
@@ -337,6 +337,7 @@ class InterceptorTest(unittest.TestCase):
def tearDown(self):
self._server.stop(None)
self._server_pool.shutdown(wait=True)
+ self._channel.close()
def testTripleRequestMessagesClientInterceptor(self):
diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
index 0ff49490d5..7ed7c83893 100644
--- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
@@ -62,6 +62,9 @@ class InvalidMetadataTest(unittest.TestCase):
self._stream_unary = _stream_unary_multi_callable(self._channel)
self._stream_stream = _stream_stream_multi_callable(self._channel)
+ def tearDown(self):
+ self._channel.close()
+
def testUnaryRequestBlockingUnaryResponse(self):
request = b'\x07\x08'
metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)
diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
index 00949e2236..e89b521cc5 100644
--- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
@@ -215,6 +215,7 @@ class InvocationDefectsTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testIterableStreamRequestBlockingUnaryResponse(self):
requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
diff --git a/src/python/grpcio_tests/tests/unit/_logging_test.py b/src/python/grpcio_tests/tests/unit/_logging_test.py
index 631b9de9db..8ff127f506 100644
--- a/src/python/grpcio_tests/tests/unit/_logging_test.py
+++ b/src/python/grpcio_tests/tests/unit/_logging_test.py
@@ -14,66 +14,86 @@
"""Test of gRPC Python's interaction with the python logging module"""
import unittest
-import six
-from six.moves import reload_module
import logging
import grpc
-import functools
+import subprocess
import sys
+INTERPRETER = sys.executable
-def patch_stderr(f):
- @functools.wraps(f)
- def _impl(*args, **kwargs):
- old_stderr = sys.stderr
- sys.stderr = six.StringIO()
- try:
- f(*args, **kwargs)
- finally:
- sys.stderr = old_stderr
+class LoggingTest(unittest.TestCase):
- return _impl
+ def test_logger_not_occupied(self):
+ script = """if True:
+ import logging
+ import grpc
-def isolated_logging(f):
+ if len(logging.getLogger().handlers) != 0:
+ raise Exception('expected 0 logging handlers')
- @functools.wraps(f)
- def _impl(*args, **kwargs):
- reload_module(logging)
- reload_module(grpc)
- try:
- f(*args, **kwargs)
- finally:
- reload_module(logging)
+ """
+ self._verifyScriptSucceeds(script)
- return _impl
+ def test_handler_found(self):
+ script = """if True:
+ import logging
+ import grpc
+ """
+ out, err = self._verifyScriptSucceeds(script)
+ self.assertEqual(0, len(err), 'unexpected output to stderr')
-class LoggingTest(unittest.TestCase):
+ def test_can_configure_logger(self):
+ script = """if True:
+ import logging
+ import six
- @isolated_logging
- def test_logger_not_occupied(self):
- self.assertEqual(0, len(logging.getLogger().handlers))
+ import grpc
- @patch_stderr
- @isolated_logging
- def test_handler_found(self):
- self.assertEqual(0, len(sys.stderr.getvalue()))
- @isolated_logging
- def test_can_configure_logger(self):
- intended_stream = six.StringIO()
- logging.basicConfig(stream=intended_stream)
- self.assertEqual(1, len(logging.getLogger().handlers))
- self.assertIs(logging.getLogger().handlers[0].stream, intended_stream)
+ intended_stream = six.StringIO()
+ logging.basicConfig(stream=intended_stream)
+
+ if len(logging.getLogger().handlers) != 1:
+ raise Exception('expected 1 logging handler')
+
+ if logging.getLogger().handlers[0].stream is not intended_stream:
+ raise Exception('wrong handler stream')
+
+ """
+ self._verifyScriptSucceeds(script)
- @isolated_logging
def test_grpc_logger(self):
- self.assertIn("grpc", logging.Logger.manager.loggerDict)
- root_logger = logging.getLogger("grpc")
- self.assertEqual(1, len(root_logger.handlers))
- self.assertIsInstance(root_logger.handlers[0], logging.NullHandler)
+ script = """if True:
+ import logging
+
+ import grpc
+
+ if "grpc" not in logging.Logger.manager.loggerDict:
+ raise Exception('grpc logger not found')
+
+ root_logger = logging.getLogger("grpc")
+ if len(root_logger.handlers) != 1:
+ raise Exception('expected 1 root logger handler')
+ if not isinstance(root_logger.handlers[0], logging.NullHandler):
+ raise Exception('expected logging.NullHandler')
+
+ """
+ self._verifyScriptSucceeds(script)
+
+ def _verifyScriptSucceeds(self, script):
+ process = subprocess.Popen(
+ [INTERPRETER, '-c', script],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ out, err = process.communicate()
+ self.assertEqual(
+ 0, process.returncode,
+ 'process failed with exit code %d (stdout: %s, stderr: %s)' %
+ (process.returncode, out, err))
+ return out, err
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
index 0dafab827a..a63664ac5d 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
@@ -198,8 +198,8 @@ class MetadataCodeDetailsTest(unittest.TestCase):
port = self._server.add_insecure_port('[::]:0')
self._server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- self._unary_unary = channel.unary_unary(
+ self._channel = grpc.insecure_channel('localhost:{}'.format(port))
+ self._unary_unary = self._channel.unary_unary(
'/'.join((
'',
_SERVICE,
@@ -208,17 +208,17 @@ class MetadataCodeDetailsTest(unittest.TestCase):
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
)
- self._unary_stream = channel.unary_stream('/'.join((
+ self._unary_stream = self._channel.unary_stream('/'.join((
'',
_SERVICE,
_UNARY_STREAM,
)),)
- self._stream_unary = channel.stream_unary('/'.join((
+ self._stream_unary = self._channel.stream_unary('/'.join((
'',
_SERVICE,
_STREAM_UNARY,
)),)
- self._stream_stream = channel.stream_stream(
+ self._stream_stream = self._channel.stream_stream(
'/'.join((
'',
_SERVICE,
@@ -228,6 +228,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
response_deserializer=_RESPONSE_DESERIALIZER,
)
+ def tearDown(self):
+ self._server.stop(None)
+ self._channel.close()
+
def testSuccessfulUnaryUnary(self):
self._servicer.set_details(_DETAILS)
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
index 2d352e99d4..7b32b5b5f3 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
@@ -187,13 +187,14 @@ class MetadataFlagsTest(unittest.TestCase):
def test_call_wait_for_ready_default(self):
for perform_call in _ALL_CALL_CASES:
- self.check_connection_does_failfast(perform_call,
- create_dummy_channel())
+ with create_dummy_channel() as channel:
+ self.check_connection_does_failfast(perform_call, channel)
def test_call_wait_for_ready_disabled(self):
for perform_call in _ALL_CALL_CASES:
- self.check_connection_does_failfast(
- perform_call, create_dummy_channel(), wait_for_ready=False)
+ with create_dummy_channel() as channel:
+ self.check_connection_does_failfast(
+ perform_call, channel, wait_for_ready=False)
def test_call_wait_for_ready_enabled(self):
# To test the wait mechanism, Python thread is required to make
@@ -210,16 +211,16 @@ class MetadataFlagsTest(unittest.TestCase):
wg.done()
def test_call(perform_call):
- try:
- channel = grpc.insecure_channel(addr)
- channel.subscribe(wait_for_transient_failure)
- perform_call(channel, wait_for_ready=True)
- except BaseException as e: # pylint: disable=broad-except
- # If the call failed, the thread would be destroyed. The channel
- # object can be collected before calling the callback, which
- # will result in a deadlock.
- wg.done()
- unhandled_exceptions.put(e, True)
+ with grpc.insecure_channel(addr) as channel:
+ try:
+ channel.subscribe(wait_for_transient_failure)
+ perform_call(channel, wait_for_ready=True)
+ except BaseException as e: # pylint: disable=broad-except
+ # If the call failed, the thread would be destroyed. The
+ # channel object can be collected before calling the
+ # callback, which will result in a deadlock.
+ wg.done()
+ unhandled_exceptions.put(e, True)
test_threads = []
for perform_call in _ALL_CALL_CASES:
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py
index 777ab683e3..892df3df08 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py
@@ -186,6 +186,7 @@ class MetadataTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py
index f6d4fcbd0a..d4ea126e2b 100644
--- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py
+++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py
@@ -98,6 +98,8 @@ class ReconnectTest(unittest.TestCase):
server.add_insecure_port('[::]:{}'.format(port))
server.start()
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
+ server.stop(None)
+ channel.close()
if __name__ == '__main__':
diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
index 4fead8fcd5..517c2d2f97 100644
--- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
+++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
@@ -148,6 +148,7 @@ class ResourceExhaustedTest(unittest.TestCase):
def tearDown(self):
self._server.stop(0)
+ self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py
index a768d6c7c1..a99121cee5 100644
--- a/src/python/grpcio_tests/tests/unit/_rpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py
@@ -193,6 +193,7 @@ class RPCTest(unittest.TestCase):
def tearDown(self):
self._server.stop(None)
+ self._channel.close()
def testUnrecognizedMethod(self):
request = b'abc'
diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py
new file mode 100644
index 0000000000..1d1fdba11e
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py
@@ -0,0 +1,97 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Defines a number of module-scope gRPC scenarios to test server shutdown."""
+
+import argparse
+import os
+import threading
+import time
+import logging
+
+import grpc
+from tests.unit import test_common
+
+from concurrent import futures
+from six.moves import queue
+
+WAIT_TIME = 1000
+
+REQUEST = b'request'
+RESPONSE = b'response'
+
+SERVER_RAISES_EXCEPTION = 'server_raises_exception'
+SERVER_DEALLOCATED = 'server_deallocated'
+SERVER_FORK_CAN_EXIT = 'server_fork_can_exit'
+
+FORK_EXIT = '/test/ForkExit'
+
+
+def fork_and_exit(request, servicer_context):
+ pid = os.fork()
+ if pid == 0:
+ os._exit(0)
+ return RESPONSE
+
+
+class GenericHandler(grpc.GenericRpcHandler):
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == FORK_EXIT:
+ return grpc.unary_unary_rpc_method_handler(fork_and_exit)
+ else:
+ return None
+
+
+def run_server(port_queue):
+ server = test_common.test_server()
+ port = server.add_insecure_port('[::]:0')
+ port_queue.put(port)
+ server.add_generic_rpc_handlers((GenericHandler(),))
+ server.start()
+ # threading.Event.wait() does not exhibit the bug identified in
+ # https://github.com/grpc/grpc/issues/17093, sleep instead
+ time.sleep(WAIT_TIME)
+
+
+def run_test(args):
+ if args.scenario == SERVER_RAISES_EXCEPTION:
+ server = test_common.test_server()
+ server.start()
+ raise Exception()
+ elif args.scenario == SERVER_DEALLOCATED:
+ server = test_common.test_server()
+ server.start()
+ server.__del__()
+ while server._state.stage != grpc._server._ServerStage.STOPPED:
+ pass
+ elif args.scenario == SERVER_FORK_CAN_EXIT:
+ port_queue = queue.Queue()
+ thread = threading.Thread(target=run_server, args=(port_queue,))
+ thread.daemon = True
+ thread.start()
+ port = port_queue.get()
+ channel = grpc.insecure_channel('localhost:%d' % port)
+ multi_callable = channel.unary_unary(FORK_EXIT)
+ result, call = multi_callable.with_call(REQUEST, wait_for_ready=True)
+ os.wait()
+ else:
+ raise ValueError('unknown test scenario')
+
+
+if __name__ == '__main__':
+ logging.basicConfig()
+ parser = argparse.ArgumentParser()
+ parser.add_argument('scenario', type=str)
+ args = parser.parse_args()
+ run_test(args)
diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py
new file mode 100644
index 0000000000..47446d65a5
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py
@@ -0,0 +1,90 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests clean shutdown of server on various interpreter exit conditions.
+
+The tests in this module spawn a subprocess for each test case, the
+test is considered successful if it doesn't hang/timeout.
+"""
+
+import atexit
+import os
+import subprocess
+import sys
+import threading
+import unittest
+import logging
+
+from tests.unit import _server_shutdown_scenarios
+
+SCENARIO_FILE = os.path.abspath(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)),
+ '_server_shutdown_scenarios.py'))
+INTERPRETER = sys.executable
+BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
+
+processes = []
+process_lock = threading.Lock()
+
+
+# Make sure we attempt to clean up any
+# processes we may have left running
+def cleanup_processes():
+ with process_lock:
+ for process in processes:
+ try:
+ process.kill()
+ except Exception: # pylint: disable=broad-except
+ pass
+
+
+atexit.register(cleanup_processes)
+
+
+def wait(process):
+ with process_lock:
+ processes.append(process)
+ process.wait()
+
+
+class ServerShutdown(unittest.TestCase):
+
+ # Currently we shut down a server (if possible) after the Python server
+ # instance is garbage collected. This behavior may change in the future.
+ def test_deallocated_server_stops(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ def test_server_exception_exits(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ @unittest.skipIf(os.name == 'nt', 'fork not supported on windows')
+ def test_server_fork_can_exit(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+
+if __name__ == '__main__':
+ logging.basicConfig()
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_version_test.py b/src/python/grpcio_tests/tests/unit/_version_test.py
new file mode 100644
index 0000000000..3d37b319e5
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_version_test.py
@@ -0,0 +1,30 @@
+# Copyright 2018 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test for grpc.__version__"""
+
+import unittest
+import grpc
+import logging
+from grpc import _grpcio_metadata
+
+
+class VersionTest(unittest.TestCase):
+
+ def test_get_version(self):
+ self.assertEqual(grpc.__version__, _grpcio_metadata.__version__)
+
+
+if __name__ == '__main__':
+ logging.basicConfig()
+ unittest.main(verbosity=2)
diff --git a/src/ruby/end2end/graceful_sig_handling_client.rb b/src/ruby/end2end/graceful_sig_handling_client.rb
new file mode 100755
index 0000000000..14a67a62cc
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_handling_client.rb
@@ -0,0 +1,61 @@
+#!/usr/bin/env ruby
+
+# Copyright 2015 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+require_relative './end2end_common'
+
+# Test client. Sends RPC's as normal but process also has signal handlers
+class SigHandlingClientController < ClientControl::ClientController::Service
+ def initialize(stub)
+ @stub = stub
+ end
+
+ def do_echo_rpc(req, _)
+ response = @stub.echo(Echo::EchoRequest.new(request: req.request))
+ fail 'bad response' unless response.response == req.request
+ ClientControl::Void.new
+ end
+end
+
+def main
+ client_control_port = ''
+ server_port = ''
+ OptionParser.new do |opts|
+ opts.on('--client_control_port=P', String) do |p|
+ client_control_port = p
+ end
+ opts.on('--server_port=P', String) do |p|
+ server_port = p
+ end
+ end.parse!
+
+ # Allow a few seconds to be safe.
+ srv = new_rpc_server_for_testing
+ srv.add_http2_port("0.0.0.0:#{client_control_port}",
+ :this_port_is_insecure)
+ stub = Echo::EchoServer::Stub.new("localhost:#{server_port}",
+ :this_channel_is_insecure)
+ control_service = SigHandlingClientController.new(stub)
+ srv.handle(control_service)
+ server_thread = Thread.new do
+ srv.run_till_terminated_or_interrupted(['int'])
+ end
+ srv.wait_till_running
+ # send a first RPC to notify the parent process that we've started
+ stub.echo(Echo::EchoRequest.new(request: 'client/child started'))
+ server_thread.join
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_handling_driver.rb b/src/ruby/end2end/graceful_sig_handling_driver.rb
new file mode 100755
index 0000000000..e12ae28485
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_handling_driver.rb
@@ -0,0 +1,83 @@
+#!/usr/bin/env ruby
+
+# Copyright 2016 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# smoke test for a grpc-using app that receives and
+# handles process-ending signals
+
+require_relative './end2end_common'
+
+# A service that calls back it's received_rpc_callback
+# upon receiving an RPC. Used for synchronization/waiting
+# for child process to start.
+class ClientStartedService < Echo::EchoServer::Service
+ def initialize(received_rpc_callback)
+ @received_rpc_callback = received_rpc_callback
+ end
+
+ def echo(echo_req, _)
+ @received_rpc_callback.call unless @received_rpc_callback.nil?
+ @received_rpc_callback = nil
+ Echo::EchoReply.new(response: echo_req.request)
+ end
+end
+
+def main
+ STDERR.puts 'start server'
+ client_started = false
+ client_started_mu = Mutex.new
+ client_started_cv = ConditionVariable.new
+ received_rpc_callback = proc do
+ client_started_mu.synchronize do
+ client_started = true
+ client_started_cv.signal
+ end
+ end
+
+ client_started_service = ClientStartedService.new(received_rpc_callback)
+ server_runner = ServerRunner.new(client_started_service)
+ server_port = server_runner.run
+ STDERR.puts 'start client'
+ control_stub, client_pid = start_client('graceful_sig_handling_client.rb', server_port)
+
+ client_started_mu.synchronize do
+ client_started_cv.wait(client_started_mu) until client_started
+ end
+
+ control_stub.do_echo_rpc(
+ ClientControl::DoEchoRpcRequest.new(request: 'hello'))
+
+ STDERR.puts 'killing client'
+ Process.kill('SIGINT', client_pid)
+ Process.wait(client_pid)
+ client_exit_status = $CHILD_STATUS
+
+ if client_exit_status.exited?
+ if client_exit_status.exitstatus != 0
+ STDERR.puts 'Client did not close gracefully'
+ exit(1)
+ end
+ else
+ STDERR.puts 'Client did not close gracefully'
+ exit(1)
+ end
+
+ STDERR.puts 'Client ended gracefully'
+
+ # no need to call cleanup, client should already be dead
+ server_runner.stop
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_stop_client.rb b/src/ruby/end2end/graceful_sig_stop_client.rb
new file mode 100755
index 0000000000..b672dc3f2a
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_stop_client.rb
@@ -0,0 +1,78 @@
+#!/usr/bin/env ruby
+
+# Copyright 2015 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+require_relative './end2end_common'
+
+# Test client. Sends RPC's as normal but process also has signal handlers
+class SigHandlingClientController < ClientControl::ClientController::Service
+ def initialize(srv, stub)
+ @srv = srv
+ @stub = stub
+ end
+
+ def do_echo_rpc(req, _)
+ response = @stub.echo(Echo::EchoRequest.new(request: req.request))
+ fail 'bad response' unless response.response == req.request
+ ClientControl::Void.new
+ end
+
+ def shutdown(_, _)
+ # Spawn a new thread because RpcServer#stop is
+ # synchronous and blocks until either this RPC has finished,
+ # or the server's "poll_period" seconds have passed.
+ @shutdown_thread = Thread.new do
+ @srv.stop
+ end
+ ClientControl::Void.new
+ end
+
+ def join_shutdown_thread
+ @shutdown_thread.join
+ end
+end
+
+def main
+ client_control_port = ''
+ server_port = ''
+ OptionParser.new do |opts|
+ opts.on('--client_control_port=P', String) do |p|
+ client_control_port = p
+ end
+ opts.on('--server_port=P', String) do |p|
+ server_port = p
+ end
+ end.parse!
+
+ # The "shutdown" RPC should end very quickly.
+ # Allow a few seconds to be safe.
+ srv = new_rpc_server_for_testing(poll_period: 3)
+ srv.add_http2_port("0.0.0.0:#{client_control_port}",
+ :this_port_is_insecure)
+ stub = Echo::EchoServer::Stub.new("localhost:#{server_port}",
+ :this_channel_is_insecure)
+ control_service = SigHandlingClientController.new(srv, stub)
+ srv.handle(control_service)
+ server_thread = Thread.new do
+ srv.run_till_terminated_or_interrupted(['int'])
+ end
+ srv.wait_till_running
+ # send a first RPC to notify the parent process that we've started
+ stub.echo(Echo::EchoRequest.new(request: 'client/child started'))
+ server_thread.join
+ control_service.join_shutdown_thread
+end
+
+main
diff --git a/src/ruby/end2end/graceful_sig_stop_driver.rb b/src/ruby/end2end/graceful_sig_stop_driver.rb
new file mode 100755
index 0000000000..7a132403eb
--- /dev/null
+++ b/src/ruby/end2end/graceful_sig_stop_driver.rb
@@ -0,0 +1,62 @@
+#!/usr/bin/env ruby
+
+# Copyright 2016 gRPC authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# smoke test for a grpc-using app that receives and
+# handles process-ending signals
+
+require_relative './end2end_common'
+
+# A service that calls back it's received_rpc_callback
+# upon receiving an RPC. Used for synchronization/waiting
+# for child process to start.
+class ClientStartedService < Echo::EchoServer::Service
+ def initialize(received_rpc_callback)
+ @received_rpc_callback = received_rpc_callback
+ end
+
+ def echo(echo_req, _)
+ @received_rpc_callback.call unless @received_rpc_callback.nil?
+ @received_rpc_callback = nil
+ Echo::EchoReply.new(response: echo_req.request)
+ end
+end
+
+def main
+ STDERR.puts 'start server'
+ client_started = false
+ client_started_mu = Mutex.new
+ client_started_cv = ConditionVariable.new
+ received_rpc_callback = proc do
+ client_started_mu.synchronize do
+ client_started = true
+ client_started_cv.signal
+ end
+ end
+
+ client_started_service = ClientStartedService.new(received_rpc_callback)
+ server_runner = ServerRunner.new(client_started_service)
+ server_port = server_runner.run
+ STDERR.puts 'start client'
+ control_stub, client_pid = start_client('./graceful_sig_stop_client.rb', server_port)
+
+ client_started_mu.synchronize do
+ client_started_cv.wait(client_started_mu) until client_started
+ end
+
+ cleanup(control_stub, client_pid, server_runner)
+end
+
+main
diff --git a/src/ruby/lib/grpc/generic/rpc_server.rb b/src/ruby/lib/grpc/generic/rpc_server.rb
index 3b5a0ce27f..f0f73dc56e 100644
--- a/src/ruby/lib/grpc/generic/rpc_server.rb
+++ b/src/ruby/lib/grpc/generic/rpc_server.rb
@@ -240,6 +240,13 @@ module GRPC
# the call has no impact if the server is already stopped, otherwise
# server's current call loop is it's last.
def stop
+ # if called via run_till_terminated_or_interrupted,
+ # signal stop_server_thread and dont do anything
+ if @stop_server.nil? == false && @stop_server == false
+ @stop_server = true
+ @stop_server_cv.broadcast
+ return
+ end
@run_mutex.synchronize do
fail 'Cannot stop before starting' if @running_state == :not_started
return if @running_state != :running
@@ -354,6 +361,60 @@ module GRPC
alias_method :run_till_terminated, :run
+ # runs the server with signal handlers
+ # @param signals
+ # List of String, Integer or both representing signals that the user
+ # would like to send to the server for graceful shutdown
+ # @param wait_interval (optional)
+ # Integer seconds that user would like stop_server_thread to poll
+ # stop_server
+ def run_till_terminated_or_interrupted(signals, wait_interval = 60)
+ @stop_server = false
+ @stop_server_mu = Mutex.new
+ @stop_server_cv = ConditionVariable.new
+
+ @stop_server_thread = Thread.new do
+ loop do
+ break if @stop_server
+ @stop_server_mu.synchronize do
+ @stop_server_cv.wait(@stop_server_mu, wait_interval)
+ end
+ end
+
+ # stop is surrounded by mutex, should handle multiple calls to stop
+ # correctly
+ stop
+ end
+
+ valid_signals = Signal.list
+
+ # register signal handlers
+ signals.each do |sig|
+ # input validation
+ if sig.class == String
+ sig.upcase!
+ if sig.start_with?('SIG')
+ # cut out the SIG prefix to see if valid signal
+ sig = sig[3..-1]
+ end
+ end
+
+ # register signal traps for all valid signals
+ if valid_signals.value?(sig) || valid_signals.key?(sig)
+ Signal.trap(sig) do
+ @stop_server = true
+ @stop_server_cv.broadcast
+ end
+ else
+ fail "#{sig} not a valid signal"
+ end
+ end
+
+ run
+
+ @stop_server_thread.join
+ end
+
# Sends RESOURCE_EXHAUSTED if there are too many unprocessed jobs
def available?(an_rpc)
return an_rpc if @pool.ready_for_work?
diff --git a/src/ruby/lib/grpc/version.rb b/src/ruby/lib/grpc/version.rb
index a4ed052d85..3b7f62d9f5 100644
--- a/src/ruby/lib/grpc/version.rb
+++ b/src/ruby/lib/grpc/version.rb
@@ -14,5 +14,5 @@
# GRPC contains the General RPC module.
module GRPC
- VERSION = '1.18.0.dev'
+ VERSION = '1.19.0.dev'
end
diff --git a/src/ruby/pb/test/client.rb b/src/ruby/pb/test/client.rb
index b2303c6e14..03f3d9001a 100755
--- a/src/ruby/pb/test/client.rb
+++ b/src/ruby/pb/test/client.rb
@@ -111,10 +111,13 @@ def create_stub(opts)
if opts.secure
creds = ssl_creds(opts.use_test_ca)
stub_opts = {
- channel_args: {
- GRPC::Core::Channel::SSL_TARGET => opts.server_host_override
- }
+ channel_args: {}
}
+ unless opts.server_host_override.empty?
+ stub_opts[:channel_args].merge!({
+ GRPC::Core::Channel::SSL_TARGET => opts.server_host_override
+ })
+ end
# Add service account creds if specified
wants_creds = %w(all compute_engine_creds service_account_creds)
@@ -603,7 +606,7 @@ class NamedTests
if not op.metadata.has_key?(initial_metadata_key)
fail AssertionError, "Expected initial metadata. None received"
elsif op.metadata[initial_metadata_key] != metadata[initial_metadata_key]
- fail AssertionError,
+ fail AssertionError,
"Expected initial metadata: #{metadata[initial_metadata_key]}. "\
"Received: #{op.metadata[initial_metadata_key]}"
end
@@ -611,7 +614,7 @@ class NamedTests
fail AssertionError, "Expected trailing metadata. None received"
elsif op.trailing_metadata[trailing_metadata_key] !=
metadata[trailing_metadata_key]
- fail AssertionError,
+ fail AssertionError,
"Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\
"Received: #{op.trailing_metadata[trailing_metadata_key]}"
end
@@ -639,7 +642,7 @@ class NamedTests
fail AssertionError, "Expected trailing metadata. None received"
elsif duplex_op.trailing_metadata[trailing_metadata_key] !=
metadata[trailing_metadata_key]
- fail AssertionError,
+ fail AssertionError,
"Expected trailing metadata: #{metadata[trailing_metadata_key]}. "\
"Received: #{duplex_op.trailing_metadata[trailing_metadata_key]}"
end
@@ -710,7 +713,7 @@ Args = Struct.new(:default_service_account, :server_host, :server_host_override,
# validates the command line options, returning them as a Hash.
def parse_args
args = Args.new
- args.server_host_override = 'foo.test.google.fr'
+ args.server_host_override = ''
OptionParser.new do |opts|
opts.on('--oauth_scope scope',
'Scope for OAuth tokens') { |v| args['oauth_scope'] = v }
diff --git a/src/ruby/tools/version.rb b/src/ruby/tools/version.rb
index 389fb70684..2ad685a7eb 100644
--- a/src/ruby/tools/version.rb
+++ b/src/ruby/tools/version.rb
@@ -14,6 +14,6 @@
module GRPC
module Tools
- VERSION = '1.18.0.dev'
+ VERSION = '1.19.0.dev'
end
end