From c21393e553e7ab528edb5966c835bc0832ea4edf Mon Sep 17 00:00:00 2001 From: "tongxuan.ltx" Date: Thu, 8 Nov 2018 10:43:18 +0800 Subject: g_default_client_callbacks shouldn't be global variable In tensorflow, RPC client thread doesn't active release, rely on process to cleanup. If process have already cleanup the global variable(g_default_client_callbacks), after that client issue a RPC call which contains the ClientContext, then once ClientContext destructor called, pure virtual functions call error is reported. --- src/cpp/client/client_context.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'src/cpp') diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index 50da75f09c..c9ea3e5f83 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -41,9 +41,10 @@ class DefaultGlobalClientCallbacks final }; static internal::GrpcLibraryInitializer g_gli_initializer; -static DefaultGlobalClientCallbacks g_default_client_callbacks; +static DefaultGlobalClientCallbacks* g_default_client_callbacks = + new DefaultGlobalClientCallbacks(); static ClientContext::GlobalCallbacks* g_client_callbacks = - &g_default_client_callbacks; + g_default_client_callbacks; ClientContext::ClientContext() : initial_metadata_received_(false), @@ -139,9 +140,9 @@ grpc::string ClientContext::peer() const { } void ClientContext::SetGlobalCallbacks(GlobalCallbacks* client_callbacks) { - GPR_ASSERT(g_client_callbacks == &g_default_client_callbacks); + GPR_ASSERT(g_client_callbacks == g_default_client_callbacks); GPR_ASSERT(client_callbacks != nullptr); - GPR_ASSERT(client_callbacks != &g_default_client_callbacks); + GPR_ASSERT(client_callbacks != g_default_client_callbacks); g_client_callbacks = client_callbacks; } -- cgit v1.2.3 From 0e29d7b9bce4d67a451a6cd7bc5ca86f36a0f121 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Mon, 12 Nov 2018 23:16:54 -0800 Subject: Properly clear metadata and other structs when reusing ServerContext --- include/grpcpp/impl/codegen/metadata_map.h | 19 +++++++++++++++---- src/cpp/server/server_context.cc | 4 ++++ test/cpp/end2end/client_callback_end2end_test.cc | 17 ++++++++++++++--- 3 files changed, 33 insertions(+), 7 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/impl/codegen/metadata_map.h b/include/grpcpp/impl/codegen/metadata_map.h index 0bba3ed4e3..9cec54d9f0 100644 --- a/include/grpcpp/impl/codegen/metadata_map.h +++ b/include/grpcpp/impl/codegen/metadata_map.h @@ -32,11 +32,9 @@ const char kBinaryErrorDetailsKey[] = "grpc-status-details-bin"; class MetadataMap { public: - MetadataMap() { memset(&arr_, 0, sizeof(arr_)); } + MetadataMap() { Setup(); } - ~MetadataMap() { - g_core_codegen_interface->grpc_metadata_array_destroy(&arr_); - } + ~MetadataMap() { Destroy(); } grpc::string GetBinaryErrorDetails() { // if filled_, extract from the multimap for O(log(n)) @@ -71,11 +69,24 @@ class MetadataMap { } grpc_metadata_array* arr() { return &arr_; } + void Reset() { + filled_ = false; + map_.clear(); + Destroy(); + Setup(); + } + private: bool filled_ = false; grpc_metadata_array arr_; std::multimap map_; + void Destroy() { + g_core_codegen_interface->grpc_metadata_array_destroy(&arr_); + } + + void Setup() { memset(&arr_, 0, sizeof(arr_)); } + void FillMap() { if (filled_) return; filled_ = true; diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 396996e5bc..9c01f896e6 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -247,6 +247,10 @@ void ServerContext::BindDeadlineAndMetadata(gpr_timespec deadline, ServerContext::~ServerContext() { Clear(); } void ServerContext::Clear() { + auth_context_.reset(); + initial_metadata_.clear(); + trailing_metadata_.clear(); + client_metadata_.Reset(); if (call_) { grpc_call_unref(call_); } diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index 7ffc610ce2..a35991396a 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -34,6 +34,7 @@ #include "test/core/util/test_config.h" #include "test/cpp/end2end/test_service_impl.h" #include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" #include @@ -100,11 +101,13 @@ class ClientCallbackEnd2endTest test_string += "Hello world. "; request.set_message(test_string); - + grpc::string val; if (with_binary_metadata) { + request.mutable_param()->set_echo_metadata(true); char bytes[8] = {'\0', '\1', '\2', '\3', '\4', '\5', '\6', static_cast(i)}; - cli_ctx.AddMetadata("custom-bin", grpc::string(bytes, 8)); + val = grpc::string(bytes, 8); + cli_ctx.AddMetadata("custom-bin", val); } cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP); @@ -114,10 +117,18 @@ class ClientCallbackEnd2endTest bool done = false; stub_->experimental_async()->Echo( &cli_ctx, &request, &response, - [&request, &response, &done, &mu, &cv](Status s) { + [&cli_ctx, &request, &response, &done, &mu, &cv, val, + with_binary_metadata](Status s) { GPR_ASSERT(s.ok()); EXPECT_EQ(request.message(), response.message()); + if (with_binary_metadata) { + EXPECT_EQ( + 1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin")); + EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata() + .find("custom-bin") + ->second)); + } std::lock_guard l(mu); done = true; cv.notify_one(); -- cgit v1.2.3 From a9bee9b7edb4045efd52b0e238d07485a791162f Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Wed, 14 Nov 2018 17:48:35 -0800 Subject: Make Pluck use the changes made in FinalizeResult --- include/grpcpp/impl/codegen/completion_queue.h | 3 +-- src/cpp/server/server_cc.cc | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h index d603c7c700..fb38788f7d 100644 --- a/include/grpcpp/impl/codegen/completion_queue.h +++ b/include/grpcpp/impl/codegen/completion_queue.h @@ -307,8 +307,7 @@ class CompletionQueue : private GrpcLibraryCodegen { void* ignored = tag; if (tag->FinalizeResult(&ignored, &ok)) { GPR_CODEGEN_ASSERT(ignored == tag); - // Ignore mutations by FinalizeResult: Pluck returns the C API status - return ev.success != 0; + return ok; } } } diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7a98bce507..a703411c2a 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -84,10 +84,7 @@ class ShutdownTag : public internal::CompletionQueueTag { class DummyTag : public internal::CompletionQueueTag { public: - bool FinalizeResult(void** tag, bool* status) { - *status = true; - return true; - } + bool FinalizeResult(void** tag, bool* status) { return true; } }; class UnimplementedAsyncRequestContext { -- cgit v1.2.3 From 626f1c9d537de45c6604a1dc7b103933073c4f00 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 15 Nov 2018 16:49:43 -0800 Subject: Remove the std::unique_ptr, instead use move semantics everywhere --- include/grpcpp/channel.h | 8 +- include/grpcpp/create_channel.h | 4 +- include/grpcpp/security/credentials.h | 12 +-- include/grpcpp/server.h | 4 +- src/cpp/client/channel_cc.cc | 8 +- src/cpp/client/create_channel.cc | 36 +++---- src/cpp/client/create_channel_internal.cc | 4 +- src/cpp/client/create_channel_internal.h | 4 +- src/cpp/client/create_channel_posix.cc | 10 +- src/cpp/client/cronet_credentials.cc | 9 +- src/cpp/client/insecure_credentials.cc | 9 +- src/cpp/client/secure_credentials.cc | 9 +- src/cpp/client/secure_credentials.h | 4 +- src/cpp/server/server_cc.cc | 7 +- .../end2end/client_interceptors_end2end_test.cc | 108 ++++++++------------- test/cpp/end2end/interceptors_util.cc | 11 +-- test/cpp/end2end/interceptors_util.h | 3 +- 17 files changed, 118 insertions(+), 132 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/channel.h b/include/grpcpp/channel.h index 4502b94b17..ee83396069 100644 --- a/include/grpcpp/channel.h +++ b/include/grpcpp/channel.h @@ -65,13 +65,13 @@ class Channel final : public ChannelInterface, friend void experimental::ChannelResetConnectionBackoff(Channel* channel); friend std::shared_ptr CreateChannelInternal( const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); friend class internal::InterceptedChannel; Channel(const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); internal::Call CreateCall(const internal::RpcMethod& method, diff --git a/include/grpcpp/create_channel.h b/include/grpcpp/create_channel.h index 43188d09e7..e8a2a70581 100644 --- a/include/grpcpp/create_channel.h +++ b/include/grpcpp/create_channel.h @@ -70,8 +70,8 @@ std::shared_ptr CreateCustomChannelWithInterceptors( const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); } // namespace experimental } // namespace grpc diff --git a/include/grpcpp/security/credentials.h b/include/grpcpp/security/credentials.h index 8dfbdec3e6..d8c9e04d77 100644 --- a/include/grpcpp/security/credentials.h +++ b/include/grpcpp/security/credentials.h @@ -46,8 +46,8 @@ std::shared_ptr CreateCustomChannelWithInterceptors( const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); } // namespace experimental @@ -80,8 +80,8 @@ class ChannelCredentials : private GrpcLibraryCodegen { const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); virtual std::shared_ptr CreateChannel( @@ -91,8 +91,8 @@ class ChannelCredentials : private GrpcLibraryCodegen { // implemented as a virtual function so that it does not break API. virtual std::shared_ptr CreateChannelWithInterceptors( const grpc::string& target, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { return nullptr; }; diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index a14a4da578..cdcac186cb 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -111,8 +111,8 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { /// interceptors std::shared_ptr InProcessChannelWithInterceptors( const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); private: diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index 8e1cea0269..d1c55319f7 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -54,13 +54,11 @@ namespace grpc { static internal::GrpcLibraryInitializer g_gli_initializer; Channel::Channel( const grpc::string& host, grpc_channel* channel, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) : host_(host), c_channel_(channel) { - if (interceptor_creators != nullptr) { - interceptor_creators_ = std::move(*interceptor_creators); - } + interceptor_creators_ = std::move(interceptor_creators); g_gli_initializer.summon(); } diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc index efdff6c265..457daa674c 100644 --- a/src/cpp/client/create_channel.cc +++ b/src/cpp/client/create_channel.cc @@ -39,13 +39,14 @@ std::shared_ptr CreateCustomChannel( const std::shared_ptr& creds, const ChannelArguments& args) { GrpcLibraryCodegen init_lib; // We need to call init in case of a bad creds. - return creds - ? creds->CreateChannel(target, args) - : CreateChannelInternal("", - grpc_lame_client_channel_create( - nullptr, GRPC_STATUS_INVALID_ARGUMENT, - "Invalid credentials."), - nullptr); + return creds ? creds->CreateChannel(target, args) + : CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::vector>()); } namespace experimental { @@ -64,17 +65,18 @@ std::shared_ptr CreateCustomChannelWithInterceptors( const grpc::string& target, const std::shared_ptr& creds, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { - return creds - ? creds->CreateChannelWithInterceptors( - target, args, std::move(interceptor_creators)) - : CreateChannelInternal("", - grpc_lame_client_channel_create( - nullptr, GRPC_STATUS_INVALID_ARGUMENT, - "Invalid credentials."), - nullptr); + return creds ? creds->CreateChannelWithInterceptors( + target, args, std::move(interceptor_creators)) + : CreateChannelInternal( + "", + grpc_lame_client_channel_create( + nullptr, GRPC_STATUS_INVALID_ARGUMENT, + "Invalid credentials."), + std::vector>()); } } // namespace experimental diff --git a/src/cpp/client/create_channel_internal.cc b/src/cpp/client/create_channel_internal.cc index 313d682aae..a0efb97f7e 100644 --- a/src/cpp/client/create_channel_internal.cc +++ b/src/cpp/client/create_channel_internal.cc @@ -26,8 +26,8 @@ namespace grpc { std::shared_ptr CreateChannelInternal( const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { return std::shared_ptr( new Channel(host, c_channel, std::move(interceptor_creators))); diff --git a/src/cpp/client/create_channel_internal.h b/src/cpp/client/create_channel_internal.h index 512fc22866..a90c92c518 100644 --- a/src/cpp/client/create_channel_internal.h +++ b/src/cpp/client/create_channel_internal.h @@ -31,8 +31,8 @@ class Channel; std::shared_ptr CreateChannelInternal( const grpc::string& host, grpc_channel* c_channel, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators); } // namespace grpc diff --git a/src/cpp/client/create_channel_posix.cc b/src/cpp/client/create_channel_posix.cc index 8d775e7a87..3affc1ef39 100644 --- a/src/cpp/client/create_channel_posix.cc +++ b/src/cpp/client/create_channel_posix.cc @@ -34,7 +34,8 @@ std::shared_ptr CreateInsecureChannelFromFd(const grpc::string& target, init_lib.init(); return CreateChannelInternal( "", grpc_insecure_channel_create_from_fd(target.c_str(), fd, nullptr), - nullptr); + std::vector< + std::unique_ptr>()); } std::shared_ptr CreateCustomInsecureChannelFromFd( @@ -46,15 +47,16 @@ std::shared_ptr CreateCustomInsecureChannelFromFd( return CreateChannelInternal( "", grpc_insecure_channel_create_from_fd(target.c_str(), fd, &channel_args), - nullptr); + std::vector< + std::unique_ptr>()); } namespace experimental { std::shared_ptr CreateCustomInsecureChannelWithInterceptorsFromFd( const grpc::string& target, int fd, const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { internal::GrpcLibrary init_lib; init_lib.init(); diff --git a/src/cpp/client/cronet_credentials.cc b/src/cpp/client/cronet_credentials.cc index 09a76b428c..b2801764f2 100644 --- a/src/cpp/client/cronet_credentials.cc +++ b/src/cpp/client/cronet_credentials.cc @@ -31,7 +31,10 @@ class CronetChannelCredentialsImpl final : public ChannelCredentials { std::shared_ptr CreateChannel( const string& target, const grpc::ChannelArguments& args) override { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector>()); } SecureChannelCredentials* AsSecureCredentials() override { return nullptr; } @@ -39,8 +42,8 @@ class CronetChannelCredentialsImpl final : public ChannelCredentials { private: std::shared_ptr CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) override { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc index b816e0c59a..241ce91803 100644 --- a/src/cpp/client/insecure_credentials.cc +++ b/src/cpp/client/insecure_credentials.cc @@ -32,13 +32,16 @@ class InsecureChannelCredentialsImpl final : public ChannelCredentials { public: std::shared_ptr CreateChannel( const string& target, const grpc::ChannelArguments& args) override { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector>()); } std::shared_ptr CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) override { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index 7faaa20e78..d0abe441a6 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -36,14 +36,17 @@ SecureChannelCredentials::SecureChannelCredentials( std::shared_ptr SecureChannelCredentials::CreateChannel( const string& target, const grpc::ChannelArguments& args) { - return CreateChannelWithInterceptors(target, args, nullptr); + return CreateChannelWithInterceptors( + target, args, + std::vector< + std::unique_ptr>()); } std::shared_ptr SecureChannelCredentials::CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { grpc_channel_args channel_args; args.SetChannelArgs(&channel_args); diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index bfb6e17ee9..613f1d6dc2 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -42,8 +42,8 @@ class SecureChannelCredentials final : public ChannelCredentials { private: std::shared_ptr CreateChannelWithInterceptors( const string& target, const grpc::ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) override; grpc_channel_credentials* const c_creds_; }; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7a98bce507..72d005f23d 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -732,14 +732,15 @@ std::shared_ptr Server::InProcessChannel( grpc_channel_args channel_args = args.c_channel_args(); return CreateChannelInternal( "inproc", grpc_inproc_channel_create(server_, &channel_args, nullptr), - nullptr); + std::vector< + std::unique_ptr>()); } std::shared_ptr Server::experimental_type::InProcessChannelWithInterceptors( const ChannelArguments& args, - std::unique_ptr>> + std::vector< + std::unique_ptr> interceptor_creators) { grpc_channel_args channel_args = args.c_channel_args(); return CreateChannelInternal( diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 0b34ec93ae..60e8b051ab 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -361,15 +361,13 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test { TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -382,20 +380,18 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 20 dummy interceptors before hijacking interceptor for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new HijackingInterceptorFactory())); // Add 20 dummy interceptors after hijacking interceptor for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -408,13 +404,11 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { ChannelArguments args; - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new HijackingInterceptorFactory())); auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); @@ -426,21 +420,19 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingMakesAnotherCallTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 5 dummy interceptors before hijacking interceptor for (auto i = 0; i < 5; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } - creators->push_back( + creators.push_back( std::unique_ptr( new HijackingInterceptorMakesAnotherCallFactory())); // Add 7 dummy interceptors after hijacking interceptor for (auto i = 0; i < 7; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = server_->experimental().InProcessChannelWithInterceptors( @@ -456,15 +448,13 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTestWithCallback) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = server_->experimental().InProcessChannelWithInterceptors( @@ -496,15 +486,13 @@ class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test { TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -517,15 +505,13 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -538,15 +524,13 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back(std::unique_ptr( + std::vector> + creators; + creators.push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -583,13 +567,11 @@ TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) { experimental::RegisterGlobalClientInterceptorFactory(&global_factory); ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -610,13 +592,11 @@ TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) { experimental::RegisterGlobalClientInterceptorFactory(&global_factory); ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( @@ -637,13 +617,11 @@ TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) { experimental::RegisterGlobalClientInterceptorFactory(&global_factory); ChannelArguments args; DummyInterceptor::Reset(); - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index 602d1695a3..5d59c1a4b7 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -132,16 +132,13 @@ bool CheckMetadata(const std::multimap& map, return false; } -std::unique_ptr>> +std::vector> CreateDummyClientInterceptors() { - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); + std::vector> + creators; // Add 20 dummy interceptors before hijacking interceptor for (auto i = 0; i < 20; i++) { - creators->push_back(std::unique_ptr( + creators.push_back(std::unique_ptr( new DummyInterceptorFactory())); } return creators; diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index b4c4791fca..d886e32494 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -149,8 +149,7 @@ void MakeCallbackCall(const std::shared_ptr& channel); bool CheckMetadata(const std::multimap& map, const string& key, const string& value); -std::unique_ptr>> +std::vector> CreateDummyClientInterceptors(); inline void* tag(int i) { return (void*)static_cast(i); } -- cgit v1.2.3 From f3e4ae633ee4476b1a30c9216f736eb71ffc968c Mon Sep 17 00:00:00 2001 From: Muxi Yan Date: Mon, 19 Nov 2018 15:24:53 -0800 Subject: Regenerate projects --- CMakeLists.txt | 2 +- Makefile | 4 ++-- gRPC-C++.podspec | 4 ++-- gRPC-Core.podspec | 2 +- gRPC-ProtoRPC.podspec | 2 +- gRPC-RxLibrary.podspec | 2 +- gRPC.podspec | 2 +- package.xml | 4 ++-- src/core/lib/surface/version.cc | 2 +- src/cpp/common/version_cc.cc | 2 +- src/csharp/Grpc.Core/Version.csproj.include | 2 +- src/csharp/Grpc.Core/VersionInfo.cs | 4 ++-- src/csharp/build_packages_dotnetcli.bat | 2 +- src/csharp/build_unitypackage.bat | 2 +- src/objective-c/!ProtoCompiler-gRPCPlugin.podspec | 2 +- src/objective-c/GRPCClient/private/version.h | 2 +- src/objective-c/tests/version.h | 2 +- src/php/composer.json | 2 +- src/php/ext/grpc/version.h | 2 +- src/python/grpcio/grpc/_grpcio_metadata.py | 2 +- src/python/grpcio/grpc_version.py | 2 +- src/python/grpcio_health_checking/grpc_version.py | 2 +- src/python/grpcio_reflection/grpc_version.py | 2 +- src/python/grpcio_testing/grpc_version.py | 2 +- src/python/grpcio_tests/grpc_version.py | 2 +- src/ruby/lib/grpc/version.rb | 2 +- src/ruby/tools/version.rb | 2 +- tools/distrib/python/grpcio_tools/grpc_version.py | 2 +- tools/doxygen/Doxyfile.c++ | 2 +- tools/doxygen/Doxyfile.c++.internal | 2 +- 30 files changed, 34 insertions(+), 34 deletions(-) (limited to 'src/cpp') diff --git a/CMakeLists.txt b/CMakeLists.txt index 5c2ba2048c..bea450037d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ cmake_minimum_required(VERSION 2.8) set(PACKAGE_NAME "grpc") -set(PACKAGE_VERSION "1.17.0-dev") +set(PACKAGE_VERSION "1.18.0-dev") set(PACKAGE_STRING "${PACKAGE_NAME} ${PACKAGE_VERSION}") set(PACKAGE_TARNAME "${PACKAGE_NAME}-${PACKAGE_VERSION}") set(PACKAGE_BUGREPORT "https://github.com/grpc/grpc/issues/") diff --git a/Makefile b/Makefile index 12603f1fc8..a6b56bb4f7 100644 --- a/Makefile +++ b/Makefile @@ -438,8 +438,8 @@ Q = @ endif CORE_VERSION = 7.0.0-dev -CPP_VERSION = 1.17.0-dev -CSHARP_VERSION = 1.17.0-dev +CPP_VERSION = 1.18.0-dev +CSHARP_VERSION = 1.18.0-dev CPPFLAGS_NO_ARCH += $(addprefix -I, $(INCLUDES)) $(addprefix -D, $(DEFINES)) CPPFLAGS += $(CPPFLAGS_NO_ARCH) $(ARCH_FLAGS) diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index fae5ce4a6e..c59d8fe1a6 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -23,7 +23,7 @@ Pod::Spec.new do |s| s.name = 'gRPC-C++' # TODO (mxyan): use version that match gRPC version when pod is stabilized - # version = '1.17.0-dev' + # version = '1.18.0-dev' version = '0.0.4' s.version = version s.summary = 'gRPC C++ library' @@ -31,7 +31,7 @@ Pod::Spec.new do |s| s.license = 'Apache License, Version 2.0' s.authors = { 'The gRPC contributors' => 'grpc-packages@google.com' } - grpc_version = '1.17.0-dev' + grpc_version = '1.18.0-dev' s.source = { :git => 'https://github.com/grpc/grpc.git', diff --git a/gRPC-Core.podspec b/gRPC-Core.podspec index f0a715cb58..3f04177074 100644 --- a/gRPC-Core.podspec +++ b/gRPC-Core.podspec @@ -22,7 +22,7 @@ Pod::Spec.new do |s| s.name = 'gRPC-Core' - version = '1.17.0-dev' + version = '1.18.0-dev' s.version = version s.summary = 'Core cross-platform gRPC library, written in C' s.homepage = 'https://grpc.io' diff --git a/gRPC-ProtoRPC.podspec b/gRPC-ProtoRPC.podspec index 693b873d14..13fe3e0b9c 100644 --- a/gRPC-ProtoRPC.podspec +++ b/gRPC-ProtoRPC.podspec @@ -21,7 +21,7 @@ Pod::Spec.new do |s| s.name = 'gRPC-ProtoRPC' - version = '1.17.0-dev' + version = '1.18.0-dev' s.version = version s.summary = 'RPC library for Protocol Buffers, based on gRPC' s.homepage = 'https://grpc.io' diff --git a/gRPC-RxLibrary.podspec b/gRPC-RxLibrary.podspec index fd590023e1..e132ad41b4 100644 --- a/gRPC-RxLibrary.podspec +++ b/gRPC-RxLibrary.podspec @@ -21,7 +21,7 @@ Pod::Spec.new do |s| s.name = 'gRPC-RxLibrary' - version = '1.17.0-dev' + version = '1.18.0-dev' s.version = version s.summary = 'Reactive Extensions library for iOS/OSX.' s.homepage = 'https://grpc.io' diff --git a/gRPC.podspec b/gRPC.podspec index 5e513cb127..940a1ac621 100644 --- a/gRPC.podspec +++ b/gRPC.podspec @@ -20,7 +20,7 @@ Pod::Spec.new do |s| s.name = 'gRPC' - version = '1.17.0-dev' + version = '1.18.0-dev' s.version = version s.summary = 'gRPC client library for iOS/OSX' s.homepage = 'https://grpc.io' diff --git a/package.xml b/package.xml index 3044cbf862..c5046cd461 100644 --- a/package.xml +++ b/package.xml @@ -13,8 +13,8 @@ 2018-01-19 - 1.17.0dev - 1.17.0dev + 1.18.0dev + 1.18.0dev beta diff --git a/src/core/lib/surface/version.cc b/src/core/lib/surface/version.cc index 66890ce65a..4829cc80a5 100644 --- a/src/core/lib/surface/version.cc +++ b/src/core/lib/surface/version.cc @@ -25,4 +25,4 @@ const char* grpc_version_string(void) { return "7.0.0-dev"; } -const char* grpc_g_stands_for(void) { return "gizmo"; } +const char* grpc_g_stands_for(void) { return "goose"; } diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc index 8abd45efb7..55da89e6c8 100644 --- a/src/cpp/common/version_cc.cc +++ b/src/cpp/common/version_cc.cc @@ -22,5 +22,5 @@ #include namespace grpc { -grpc::string Version() { return "1.17.0-dev"; } +grpc::string Version() { return "1.18.0-dev"; } } // namespace grpc diff --git a/src/csharp/Grpc.Core/Version.csproj.include b/src/csharp/Grpc.Core/Version.csproj.include index ed0d884365..4fffe4f644 100755 --- a/src/csharp/Grpc.Core/Version.csproj.include +++ b/src/csharp/Grpc.Core/Version.csproj.include @@ -1,7 +1,7 @@ - 1.17.0-dev + 1.18.0-dev 3.6.1 diff --git a/src/csharp/Grpc.Core/VersionInfo.cs b/src/csharp/Grpc.Core/VersionInfo.cs index 14714c8c4a..633880189c 100644 --- a/src/csharp/Grpc.Core/VersionInfo.cs +++ b/src/csharp/Grpc.Core/VersionInfo.cs @@ -33,11 +33,11 @@ namespace Grpc.Core /// /// Current AssemblyFileVersion of gRPC C# assemblies /// - public const string CurrentAssemblyFileVersion = "1.17.0.0"; + public const string CurrentAssemblyFileVersion = "1.18.0.0"; /// /// Current version of gRPC C# /// - public const string CurrentVersion = "1.17.0-dev"; + public const string CurrentVersion = "1.18.0-dev"; } } diff --git a/src/csharp/build_packages_dotnetcli.bat b/src/csharp/build_packages_dotnetcli.bat index 27688360e9..76d4f14390 100755 --- a/src/csharp/build_packages_dotnetcli.bat +++ b/src/csharp/build_packages_dotnetcli.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.17.0-dev +set VERSION=1.18.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/csharp/build_unitypackage.bat b/src/csharp/build_unitypackage.bat index dd74de0491..3334d24c11 100644 --- a/src/csharp/build_unitypackage.bat +++ b/src/csharp/build_unitypackage.bat @@ -13,7 +13,7 @@ @rem limitations under the License. @rem Current package versions -set VERSION=1.17.0-dev +set VERSION=1.18.0-dev @rem Adjust the location of nuget.exe set NUGET=C:\nuget\nuget.exe diff --git a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec index a95a120d21..55ca6048bc 100644 --- a/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec +++ b/src/objective-c/!ProtoCompiler-gRPCPlugin.podspec @@ -42,7 +42,7 @@ Pod::Spec.new do |s| # exclamation mark ensures that other "regular" pods will be able to find it as it'll be installed # before them. s.name = '!ProtoCompiler-gRPCPlugin' - v = '1.17.0-dev' + v = '1.18.0-dev' s.version = v s.summary = 'The gRPC ProtoC plugin generates Objective-C files from .proto services.' s.description = <<-DESC diff --git a/src/objective-c/GRPCClient/private/version.h b/src/objective-c/GRPCClient/private/version.h index d5463c0b4c..0be0e3c9a0 100644 --- a/src/objective-c/GRPCClient/private/version.h +++ b/src/objective-c/GRPCClient/private/version.h @@ -22,4 +22,4 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.17.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" diff --git a/src/objective-c/tests/version.h b/src/objective-c/tests/version.h index ca27c03b3c..f2fd692070 100644 --- a/src/objective-c/tests/version.h +++ b/src/objective-c/tests/version.h @@ -22,5 +22,5 @@ // instead. This file can be regenerated from the template by running // `tools/buildgen/generate_projects.sh`. -#define GRPC_OBJC_VERSION_STRING @"1.17.0-dev" +#define GRPC_OBJC_VERSION_STRING @"1.18.0-dev" #define GRPC_C_VERSION_STRING @"7.0.0-dev" diff --git a/src/php/composer.json b/src/php/composer.json index d54db91b5f..9c298c0e85 100644 --- a/src/php/composer.json +++ b/src/php/composer.json @@ -2,7 +2,7 @@ "name": "grpc/grpc-dev", "description": "gRPC library for PHP - for Developement use only", "license": "Apache-2.0", - "version": "1.17.0", + "version": "1.18.0", "require": { "php": ">=5.5.0", "google/protobuf": "^v3.3.0" diff --git a/src/php/ext/grpc/version.h b/src/php/ext/grpc/version.h index 70f8bbbf40..1ddf90a667 100644 --- a/src/php/ext/grpc/version.h +++ b/src/php/ext/grpc/version.h @@ -20,6 +20,6 @@ #ifndef VERSION_H #define VERSION_H -#define PHP_GRPC_VERSION "1.17.0dev" +#define PHP_GRPC_VERSION "1.18.0dev" #endif /* VERSION_H */ diff --git a/src/python/grpcio/grpc/_grpcio_metadata.py b/src/python/grpcio/grpc/_grpcio_metadata.py index 42b3a1ad49..7a9f173947 100644 --- a/src/python/grpcio/grpc/_grpcio_metadata.py +++ b/src/python/grpcio/grpc/_grpcio_metadata.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc/_grpcio_metadata.py.template`!!! -__version__ = """1.17.0.dev0""" +__version__ = """1.18.0.dev0""" diff --git a/src/python/grpcio/grpc_version.py b/src/python/grpcio/grpc_version.py index 71113e68d9..2e91818d2c 100644 --- a/src/python/grpcio/grpc_version.py +++ b/src/python/grpcio/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_health_checking/grpc_version.py b/src/python/grpcio_health_checking/grpc_version.py index a30aac2e0b..85fa762f7e 100644 --- a/src/python/grpcio_health_checking/grpc_version.py +++ b/src/python/grpcio_health_checking/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_health_checking/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_reflection/grpc_version.py b/src/python/grpcio_reflection/grpc_version.py index aafea9fe76..e62ab169a2 100644 --- a/src/python/grpcio_reflection/grpc_version.py +++ b/src/python/grpcio_reflection/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_reflection/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_testing/grpc_version.py b/src/python/grpcio_testing/grpc_version.py index 876acd3142..7b4c1695fa 100644 --- a/src/python/grpcio_testing/grpc_version.py +++ b/src/python/grpcio_testing/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_testing/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/python/grpcio_tests/grpc_version.py b/src/python/grpcio_tests/grpc_version.py index cc9b41587c..2fcd1ad617 100644 --- a/src/python/grpcio_tests/grpc_version.py +++ b/src/python/grpcio_tests/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/src/python/grpcio_tests/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/src/ruby/lib/grpc/version.rb b/src/ruby/lib/grpc/version.rb index 243d566645..a4ed052d85 100644 --- a/src/ruby/lib/grpc/version.rb +++ b/src/ruby/lib/grpc/version.rb @@ -14,5 +14,5 @@ # GRPC contains the General RPC module. module GRPC - VERSION = '1.17.0.dev' + VERSION = '1.18.0.dev' end diff --git a/src/ruby/tools/version.rb b/src/ruby/tools/version.rb index 92e85eb882..389fb70684 100644 --- a/src/ruby/tools/version.rb +++ b/src/ruby/tools/version.rb @@ -14,6 +14,6 @@ module GRPC module Tools - VERSION = '1.17.0.dev' + VERSION = '1.18.0.dev' end end diff --git a/tools/distrib/python/grpcio_tools/grpc_version.py b/tools/distrib/python/grpcio_tools/grpc_version.py index 4b775e667e..29b2127960 100644 --- a/tools/distrib/python/grpcio_tools/grpc_version.py +++ b/tools/distrib/python/grpcio_tools/grpc_version.py @@ -14,4 +14,4 @@ # AUTO-GENERATED FROM `$REPO_ROOT/templates/tools/distrib/python/grpcio_tools/grpc_version.py.template`!!! -VERSION = '1.17.0.dev0' +VERSION = '1.18.0.dev0' diff --git a/tools/doxygen/Doxyfile.c++ b/tools/doxygen/Doxyfile.c++ index 392113c284..fbac37e8e7 100644 --- a/tools/doxygen/Doxyfile.c++ +++ b/tools/doxygen/Doxyfile.c++ @@ -40,7 +40,7 @@ PROJECT_NAME = "GRPC C++" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 1.17.0-dev +PROJECT_NUMBER = 1.18.0-dev # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index a96683883c..f7a9c79620 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -40,7 +40,7 @@ PROJECT_NAME = "GRPC C++" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 1.17.0-dev +PROJECT_NUMBER = 1.18.0-dev # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a -- cgit v1.2.3 From d7eb26648d33912199573b842557b0e92ec3071f Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Sun, 4 Nov 2018 23:37:00 -0800 Subject: Client callback streaming --- include/grpcpp/generic/generic_stub.h | 5 + include/grpcpp/impl/codegen/callback_common.h | 8 +- include/grpcpp/impl/codegen/channel_interface.h | 12 + include/grpcpp/impl/codegen/client_callback.h | 505 +++++++++++++++++++++++ include/grpcpp/impl/codegen/client_context.h | 12 + src/compiler/cpp_generator.cc | 84 +++- src/cpp/client/generic_stub.cc | 12 + test/cpp/codegen/compiler_test_golden | 7 + test/cpp/end2end/client_callback_end2end_test.cc | 228 ++++++++++ 9 files changed, 860 insertions(+), 13 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/generic/generic_stub.h b/include/grpcpp/generic/generic_stub.h index d509d9a520..ccbf8a0e55 100644 --- a/include/grpcpp/generic/generic_stub.h +++ b/include/grpcpp/generic/generic_stub.h @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace grpc { @@ -76,6 +77,10 @@ class GenericStub final { const ByteBuffer* request, ByteBuffer* response, std::function on_completion); + experimental::ClientCallbackReaderWriter* + PrepareBidiStreamingCall(ClientContext* context, const grpc::string& method, + experimental::ClientBidiReactor* reactor); + private: GenericStub* stub_; }; diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index 51367cf550..f7a24204dc 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -145,18 +145,19 @@ class CallbackWithSuccessTag // or on a tag that has been Set before unless the tag has been cleared. void Set(grpc_call* call, std::function f, CompletionQueueTag* ops) { + GPR_CODEGEN_ASSERT(call_ == nullptr); + g_core_codegen_interface->grpc_call_ref(call); call_ = call; func_ = std::move(f); ops_ = ops; - g_core_codegen_interface->grpc_call_ref(call); functor_run = &CallbackWithSuccessTag::StaticRun; } void Clear() { if (call_ != nullptr) { - func_ = nullptr; grpc_call* call = call_; call_ = nullptr; + func_ = nullptr; g_core_codegen_interface->grpc_call_unref(call); } } @@ -182,10 +183,9 @@ class CallbackWithSuccessTag } void Run(bool ok) { void* ignored = ops_; - bool new_ok = ok; // Allow a "false" return value from FinalizeResult to silence the // callback, just as it silences a CQ tag in the async cases - bool do_callback = ops_->FinalizeResult(&ignored, &new_ok); + bool do_callback = ops_->FinalizeResult(&ignored, &ok); GPR_CODEGEN_ASSERT(ignored == ops_); if (do_callback) { diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index 6ec1ffb8c7..728a7b9049 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -53,6 +53,12 @@ template class ClientAsyncReaderWriterFactory; template class ClientAsyncResponseReaderFactory; +template +class ClientCallbackReaderWriterFactory; +template +class ClientCallbackReaderFactory; +template +class ClientCallbackWriterFactory; class InterceptedChannel; } // namespace internal @@ -106,6 +112,12 @@ class ChannelInterface { friend class ::grpc::internal::ClientAsyncReaderWriterFactory; template friend class ::grpc::internal::ClientAsyncResponseReaderFactory; + template + friend class ::grpc::internal::ClientCallbackReaderWriterFactory; + template + friend class ::grpc::internal::ClientCallbackReaderFactory; + template + friend class ::grpc::internal::ClientCallbackWriterFactory; template friend class ::grpc::internal::BlockingUnaryCallImpl; template diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 4baa819091..01a28a9e6f 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -88,6 +89,510 @@ class CallbackUnaryCallImpl { call.PerformOps(ops); } }; +} // namespace internal + +namespace experimental { + +// The user must implement this reactor interface with reactions to each event +// type that gets called by the library. An empty reaction is provided by +// default + +class ClientBidiReactor { + public: + virtual ~ClientBidiReactor() {} + virtual void OnDone(Status s) {} + virtual void OnReadInitialMetadataDone(bool ok) {} + virtual void OnReadDone(bool ok) {} + virtual void OnWriteDone(bool ok) {} + virtual void OnWritesDoneDone(bool ok) {} +}; + +class ClientReadReactor { + public: + virtual ~ClientReadReactor() {} + virtual void OnDone(Status s) {} + virtual void OnReadInitialMetadataDone(bool ok) {} + virtual void OnReadDone(bool ok) {} +}; + +class ClientWriteReactor { + public: + virtual ~ClientWriteReactor() {} + virtual void OnDone(Status s) {} + virtual void OnReadInitialMetadataDone(bool ok) {} + virtual void OnWriteDone(bool ok) {} + virtual void OnWritesDoneDone(bool ok) {} +}; + +template +class ClientCallbackReaderWriter { + public: + virtual ~ClientCallbackReaderWriter() {} + virtual void StartCall() = 0; + void Write(const Request* req) { Write(req, WriteOptions()); } + virtual void Write(const Request* req, WriteOptions options) = 0; + void WriteLast(const Request* req, WriteOptions options) { + Write(req, options.set_last_message()); + } + virtual void WritesDone() = 0; + virtual void Read(Response* resp) = 0; +}; + +template +class ClientCallbackReader { + public: + virtual ~ClientCallbackReader() {} + virtual void StartCall() = 0; + virtual void Read(Response* resp) = 0; +}; + +template +class ClientCallbackWriter { + public: + virtual ~ClientCallbackWriter() {} + virtual void StartCall() = 0; + void Write(const Request* req) { Write(req, WriteOptions()); } + virtual void Write(const Request* req, WriteOptions options) = 0; + void WriteLast(const Request* req, WriteOptions options) { + Write(req, options.set_last_message()); + } + virtual void WritesDone() = 0; +}; + +} // namespace experimental + +namespace internal { + +// Forward declare factory classes for friendship +template +class ClientCallbackReaderWriterFactory; +template +class ClientCallbackReaderFactory; +template +class ClientCallbackWriterFactory; + +template +class ClientCallbackReaderWriterImpl + : public ::grpc::experimental::ClientCallbackReaderWriter { + public: + // always allocated against a call arena, no memory free required + static void operator delete(void* ptr, std::size_t size) { + assert(size == sizeof(ClientCallbackReaderWriterImpl)); + } + + // This operator should never be called as the memory should be freed as part + // of the arena destruction. It only exists to provide a matching operator + // delete to the operator new so that some compilers will not complain (see + // https://github.com/grpc/grpc/issues/11301) Note at the time of adding this + // there are no tests catching the compiler warning. + static void operator delete(void*, void*) { assert(0); } + + void MaybeFinish() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(std::move(finish_status_)); + auto* call = call_.call(); + this->~ClientCallbackReaderWriterImpl(); + g_core_codegen_interface->grpc_call_unref(call); + } + } + + void StartCall() override { + // This call initiates two batches + // 1. Send initial metadata (unless corked)/recv initial metadata + // 2. Recv trailing metadata, on_completion callback + callbacks_outstanding_ = 2; + + start_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); + start_corked_ = context_->initial_metadata_corked_; + if (!start_corked_) { + start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + } + start_ops_.RecvInitialMetadata(context_); + start_ops_.set_core_cq_tag(&start_tag_); + call_.PerformOps(&start_ops_); + + finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, + &finish_ops_); + finish_ops_.ClientRecvStatus(context_, &finish_status_); + finish_ops_.set_core_cq_tag(&finish_tag_); + call_.PerformOps(&finish_ops_); + + // Also set up the read and write tags so that they don't have to be set up + // each time + write_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeFinish(); + }, + &write_ops_); + write_ops_.set_core_cq_tag(&write_tag_); + + read_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeFinish(); + }, + &read_ops_); + read_ops_.set_core_cq_tag(&read_tag_); + } + + void Read(Response* msg) override { + read_ops_.RecvMessage(msg); + callbacks_outstanding_++; + call_.PerformOps(&read_ops_); + } + + void Write(const Request* msg, WriteOptions options) override { + if (start_corked_) { + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + start_corked_ = false; + } + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*msg).ok()); + + if (options.is_last_message()) { + options.set_buffer_hint(); + write_ops_.ClientSendClose(); + } + callbacks_outstanding_++; + call_.PerformOps(&write_ops_); + } + void WritesDone() override { + if (start_corked_) { + writes_done_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + start_corked_ = false; + } + writes_done_ops_.ClientSendClose(); + writes_done_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWritesDoneDone(ok); + MaybeFinish(); + }, + &writes_done_ops_); + writes_done_ops_.set_core_cq_tag(&writes_done_tag_); + callbacks_outstanding_++; + call_.PerformOps(&writes_done_ops_); + } + + private: + friend class ClientCallbackReaderWriterFactory; + + ClientCallbackReaderWriterImpl( + Call call, ClientContext* context, + ::grpc::experimental::ClientBidiReactor* reactor) + : context_(context), call_(call), reactor_(reactor) {} + + ClientContext* context_; + Call call_; + ::grpc::experimental::ClientBidiReactor* reactor_; + + CallOpSet start_ops_; + CallbackWithSuccessTag start_tag_; + bool start_corked_; + + CallOpSet finish_ops_; + CallbackWithSuccessTag finish_tag_; + Status finish_status_; + + CallOpSet + write_ops_; + CallbackWithSuccessTag write_tag_; + + CallOpSet writes_done_ops_; + CallbackWithSuccessTag writes_done_tag_; + + CallOpSet> read_ops_; + CallbackWithSuccessTag read_tag_; + + std::atomic_int callbacks_outstanding_; +}; + +template +class ClientCallbackReaderWriterFactory { + public: + static experimental::ClientCallbackReaderWriter* Create( + ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, + ClientContext* context, + ::grpc::experimental::ClientBidiReactor* reactor) { + Call call = channel->CreateCall(method, context, channel->CallbackCQ()); + + g_core_codegen_interface->grpc_call_ref(call.call()); + return new (g_core_codegen_interface->grpc_call_arena_alloc( + call.call(), sizeof(ClientCallbackReaderWriterImpl))) + ClientCallbackReaderWriterImpl(call, context, + reactor); + } +}; + +template +class ClientCallbackReaderImpl + : public ::grpc::experimental::ClientCallbackReader { + public: + // always allocated against a call arena, no memory free required + static void operator delete(void* ptr, std::size_t size) { + assert(size == sizeof(ClientCallbackReaderImpl)); + } + + // This operator should never be called as the memory should be freed as part + // of the arena destruction. It only exists to provide a matching operator + // delete to the operator new so that some compilers will not complain (see + // https://github.com/grpc/grpc/issues/11301) Note at the time of adding this + // there are no tests catching the compiler warning. + static void operator delete(void*, void*) { assert(0); } + + void MaybeFinish() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(std::move(finish_status_)); + auto* call = call_.call(); + this->~ClientCallbackReaderImpl(); + g_core_codegen_interface->grpc_call_unref(call); + } + } + + void StartCall() override { + // This call initiates two batches + // 1. Send initial metadata (unless corked)/recv initial metadata + // 2. Recv trailing metadata, on_completion callback + callbacks_outstanding_ = 2; + + start_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); + start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + start_ops_.RecvInitialMetadata(context_); + start_ops_.set_core_cq_tag(&start_tag_); + call_.PerformOps(&start_ops_); + + finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, + &finish_ops_); + finish_ops_.ClientRecvStatus(context_, &finish_status_); + finish_ops_.set_core_cq_tag(&finish_tag_); + call_.PerformOps(&finish_ops_); + + // Also set up the read tag so it doesn't have to be set up each time + read_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeFinish(); + }, + &read_ops_); + read_ops_.set_core_cq_tag(&read_tag_); + } + + void Read(Response* msg) override { + read_ops_.RecvMessage(msg); + callbacks_outstanding_++; + call_.PerformOps(&read_ops_); + } + + private: + friend class ClientCallbackReaderFactory; + + template + ClientCallbackReaderImpl(Call call, ClientContext* context, Request* request, + ::grpc::experimental::ClientReadReactor* reactor) + : context_(context), call_(call), reactor_(reactor) { + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(start_ops_.SendMessage(*request).ok()); + start_ops_.ClientSendClose(); + } + + ClientContext* context_; + Call call_; + ::grpc::experimental::ClientReadReactor* reactor_; + + CallOpSet + start_ops_; + CallbackWithSuccessTag start_tag_; + + CallOpSet finish_ops_; + CallbackWithSuccessTag finish_tag_; + Status finish_status_; + + CallOpSet> read_ops_; + CallbackWithSuccessTag read_tag_; + + std::atomic_int callbacks_outstanding_; +}; + +template +class ClientCallbackReaderFactory { + public: + template + static experimental::ClientCallbackReader* Create( + ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, + ClientContext* context, const Request* request, + ::grpc::experimental::ClientReadReactor* reactor) { + Call call = channel->CreateCall(method, context, channel->CallbackCQ()); + + g_core_codegen_interface->grpc_call_ref(call.call()); + return new (g_core_codegen_interface->grpc_call_arena_alloc( + call.call(), sizeof(ClientCallbackReaderImpl))) + ClientCallbackReaderImpl(call, context, request, reactor); + } +}; + +template +class ClientCallbackWriterImpl + : public ::grpc::experimental::ClientCallbackWriter { + public: + // always allocated against a call arena, no memory free required + static void operator delete(void* ptr, std::size_t size) { + assert(size == sizeof(ClientCallbackWriterImpl)); + } + + // This operator should never be called as the memory should be freed as part + // of the arena destruction. It only exists to provide a matching operator + // delete to the operator new so that some compilers will not complain (see + // https://github.com/grpc/grpc/issues/11301) Note at the time of adding this + // there are no tests catching the compiler warning. + static void operator delete(void*, void*) { assert(0); } + + void MaybeFinish() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(std::move(finish_status_)); + auto* call = call_.call(); + this->~ClientCallbackWriterImpl(); + g_core_codegen_interface->grpc_call_unref(call); + } + } + + void StartCall() override { + // This call initiates two batches + // 1. Send initial metadata (unless corked)/recv initial metadata + // 2. Recv message + trailing metadata, on_completion callback + callbacks_outstanding_ = 2; + + start_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); + start_corked_ = context_->initial_metadata_corked_; + if (!start_corked_) { + start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + } + start_ops_.RecvInitialMetadata(context_); + start_ops_.set_core_cq_tag(&start_tag_); + call_.PerformOps(&start_ops_); + + finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, + &finish_ops_); + finish_ops_.ClientRecvStatus(context_, &finish_status_); + finish_ops_.set_core_cq_tag(&finish_tag_); + call_.PerformOps(&finish_ops_); + + // Also set up the read and write tags so that they don't have to be set up + // each time + write_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeFinish(); + }, + &write_ops_); + write_ops_.set_core_cq_tag(&write_tag_); + } + + void Write(const Request* msg, WriteOptions options) override { + if (start_corked_) { + write_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + start_corked_ = false; + } + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*msg).ok()); + + if (options.is_last_message()) { + options.set_buffer_hint(); + write_ops_.ClientSendClose(); + } + callbacks_outstanding_++; + call_.PerformOps(&write_ops_); + } + void WritesDone() override { + if (start_corked_) { + writes_done_ops_.SendInitialMetadata(&context_->send_initial_metadata_, + context_->initial_metadata_flags()); + start_corked_ = false; + } + writes_done_ops_.ClientSendClose(); + writes_done_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWritesDoneDone(ok); + MaybeFinish(); + }, + &writes_done_ops_); + writes_done_ops_.set_core_cq_tag(&writes_done_tag_); + callbacks_outstanding_++; + call_.PerformOps(&writes_done_ops_); + } + + private: + friend class ClientCallbackWriterFactory; + + template + ClientCallbackWriterImpl(Call call, ClientContext* context, + Response* response, + ::grpc::experimental::ClientWriteReactor* reactor) + : context_(context), call_(call), reactor_(reactor) { + finish_ops_.RecvMessage(response); + finish_ops_.AllowNoMessage(); + } + + ClientContext* context_; + Call call_; + ::grpc::experimental::ClientWriteReactor* reactor_; + + CallOpSet start_ops_; + CallbackWithSuccessTag start_tag_; + bool start_corked_; + + CallOpSet finish_ops_; + CallbackWithSuccessTag finish_tag_; + Status finish_status_; + + CallOpSet + write_ops_; + CallbackWithSuccessTag write_tag_; + + CallOpSet writes_done_ops_; + CallbackWithSuccessTag writes_done_tag_; + + std::atomic_int callbacks_outstanding_; +}; + +template +class ClientCallbackWriterFactory { + public: + template + static experimental::ClientCallbackWriter* Create( + ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, + ClientContext* context, Response* response, + ::grpc::experimental::ClientWriteReactor* reactor) { + Call call = channel->CreateCall(method, context, channel->CallbackCQ()); + + g_core_codegen_interface->grpc_call_ref(call.call()); + return new (g_core_codegen_interface->grpc_call_arena_alloc( + call.call(), sizeof(ClientCallbackWriterImpl))) + ClientCallbackWriterImpl(call, context, response, reactor); + } +}; } // namespace internal } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 75b955e760..6059c3c58a 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -71,6 +71,12 @@ template class BlockingUnaryCallImpl; template class CallbackUnaryCallImpl; +template +class ClientCallbackReaderWriterImpl; +template +class ClientCallbackReaderImpl; +template +class ClientCallbackWriterImpl; } // namespace internal template @@ -394,6 +400,12 @@ class ClientContext { friend class ::grpc::internal::BlockingUnaryCallImpl; template friend class ::grpc::internal::CallbackUnaryCallImpl; + template + friend class ::grpc::internal::ClientCallbackReaderWriterImpl; + template + friend class ::grpc::internal::ClientCallbackReaderImpl; + template + friend class ::grpc::internal::ClientCallbackWriterImpl; // Used by friend class CallOpClientRecvStatus void set_debug_error_string(const grpc::string& debug_error_string) { diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 7986aca696..473e57166f 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -132,6 +132,7 @@ grpc::string GetHeaderIncludes(grpc_generator::File* file, "grpcpp/impl/codegen/async_generic_service.h", "grpcpp/impl/codegen/async_stream.h", "grpcpp/impl/codegen/async_unary_call.h", + "grpcpp/impl/codegen/client_callback.h", "grpcpp/impl/codegen/method_handler_impl.h", "grpcpp/impl/codegen/proto_utils.h", "grpcpp/impl/codegen/rpc_method.h", @@ -580,11 +581,23 @@ void PrintHeaderClientMethodCallbackInterfaces( "const $Request$* request, $Response$* response, " "std::function) = 0;\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add support for client-side streaming + printer->Print(*vars, + "virtual ::grpc::experimental::ClientCallbackWriter< " + "$Request$>* $Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor* reactor) = 0;\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add support for server-side streaming + printer->Print(*vars, + "virtual ::grpc::experimental::ClientCallbackReader< " + "$Response$>* $Method$(::grpc::ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor* reactor) = 0;\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add support for bidi streaming + printer->Print( + *vars, + "virtual ::grpc::experimental::ClientCallbackReaderWriter< $Request$, " + "$Response$>* $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor* reactor) = 0;\n"); } } @@ -631,11 +644,26 @@ void PrintHeaderClientMethodCallback(grpc_generator::Printer* printer, "const $Request$* request, $Response$* response, " "std::function) override;\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add support for client-side streaming + printer->Print( + *vars, + "::grpc::experimental::ClientCallbackWriter< $Request$>* " + "$Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor* reactor) override;\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add support for server-side streaming + printer->Print( + *vars, + "::grpc::experimental::ClientCallbackReader< $Response$>* " + "$Method$(::grpc::ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor* reactor) override;\n"); + } else if (method->BidiStreaming()) { - // TODO(vjpai): Add support for bidi streaming + printer->Print( + *vars, + "::grpc::experimental::ClientCallbackReaderWriter< $Request$, " + "$Response$>* $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor* reactor) override;\n"); } } @@ -1607,7 +1635,20 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context, response);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print( + *vars, + "::grpc::experimental::ClientCallbackWriter< $Request$>* " + "$ns$$Service$::" + "Stub::experimental_async::$Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor* reactor) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientCallbackWriterFactory< " + "$Request$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, response, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; @@ -1641,7 +1682,19 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context, request);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print(*vars, + "::grpc::experimental::ClientCallbackReader< $Response$>* " + "$ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor* reactor) {\n"); + printer->Print(*vars, + " return ::grpc::internal::ClientCallbackReaderFactory< " + "$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, request, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; @@ -1675,7 +1728,20 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context);\n" "}\n\n"); - // TODO(vjpai): Add callback version + printer->Print(*vars, + "::grpc::experimental::ClientCallbackReaderWriter< " + "$Request$,$Response$>* " + "$ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "::grpc::experimental::ClientBidiReactor* reactor) {\n"); + printer->Print( + *vars, + " return ::grpc::internal::ClientCallbackReaderWriterFactory< " + "$Request$,$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; diff --git a/src/cpp/client/generic_stub.cc b/src/cpp/client/generic_stub.cc index 87902b26f0..f029daec65 100644 --- a/src/cpp/client/generic_stub.cc +++ b/src/cpp/client/generic_stub.cc @@ -72,4 +72,16 @@ void GenericStub::experimental_type::UnaryCall( context, request, response, std::move(on_completion)); } +experimental::ClientCallbackReaderWriter* +GenericStub::experimental_type::PrepareBidiStreamingCall( + ClientContext* context, const grpc::string& method, + experimental::ClientBidiReactor* reactor) { + return internal::ClientCallbackReaderWriterFactory< + ByteBuffer, ByteBuffer>::Create(stub_->channel_.get(), + internal::RpcMethod( + method.c_str(), + internal::RpcMethod::BIDI_STREAMING), + context, reactor); +} + } // namespace grpc diff --git a/test/cpp/codegen/compiler_test_golden b/test/cpp/codegen/compiler_test_golden index fdc67969d9..7a25f51d10 100644 --- a/test/cpp/codegen/compiler_test_golden +++ b/test/cpp/codegen/compiler_test_golden @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -117,10 +118,13 @@ class ServiceA final { // // Method A2 leading comment 1 // Method A2 leading comment 2 + virtual ::grpc::experimental::ClientCallbackWriter< ::grpc::testing::Request>* MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::experimental::ClientWriteReactor* reactor) = 0; // MethodA2 trailing comment 1 // Method A3 leading comment 1 + virtual ::grpc::experimental::ClientCallbackReader< ::grpc::testing::Response>* MethodA3(::grpc::ClientContext* context, ::grpc::testing::Request* request, ::grpc::experimental::ClientReadReactor* reactor) = 0; // Method A3 trailing comment 1 // Method A4 leading comment 1 + virtual ::grpc::experimental::ClientCallbackReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4(::grpc::ClientContext* context, ::grpc::experimental::ClientBidiReactor* reactor) = 0; // Method A4 trailing comment 1 }; virtual class experimental_async_interface* experimental_async() { return nullptr; } @@ -178,6 +182,9 @@ class ServiceA final { public StubInterface::experimental_async_interface { public: void MethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, std::function) override; + ::grpc::experimental::ClientCallbackWriter< ::grpc::testing::Request>* MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::experimental::ClientWriteReactor* reactor) override; + ::grpc::experimental::ClientCallbackReader< ::grpc::testing::Response>* MethodA3(::grpc::ClientContext* context, ::grpc::testing::Request* request, ::grpc::experimental::ClientReadReactor* reactor) override; + ::grpc::experimental::ClientCallbackReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4(::grpc::ClientContext* context, ::grpc::experimental::ClientBidiReactor* reactor) override; private: friend class Stub; explicit experimental_async(Stub* stub): stub_(stub) { } diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index a35991396a..e39fc5aab2 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -182,6 +182,59 @@ class ClientCallbackEnd2endTest } } + void SendGenericEchoAsBidi(int num_rpcs) { + const grpc::string kMethodName("/grpc.testing.EchoTestService/Echo"); + grpc::string test_string(""); + for (int i = 0; i < num_rpcs; i++) { + test_string += "Hello world. "; + class Client : public grpc::experimental::ClientBidiReactor { + public: + Client(ClientCallbackEnd2endTest* test, const grpc::string& method_name, + const grpc::string& test_str) { + stream_ = + test->generic_stub_->experimental().PrepareBidiStreamingCall( + &cli_ctx_, method_name, this); + stream_->StartCall(); + request_.set_message(test_str); + send_buf_ = SerializeToByteBuffer(&request_); + stream_->Read(&recv_buf_); + stream_->Write(send_buf_.get()); + } + void OnWriteDone(bool ok) override { stream_->WritesDone(); } + void OnReadDone(bool ok) override { + EchoResponse response; + EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); + EXPECT_EQ(request_.message(), response.message()); + }; + void OnDone(Status s) override { + // The stream is invalid once OnDone is called + stream_ = nullptr; + EXPECT_TRUE(s.ok()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + EchoRequest request_; + std::unique_ptr send_buf_; + ByteBuffer recv_buf_; + ClientContext cli_ctx_; + experimental::ClientCallbackReaderWriter* + stream_; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } rpc{this, kMethodName, test_string}; + + rpc.Await(); + } + } bool is_server_started_; std::shared_ptr channel_; std::unique_ptr stub_; @@ -211,6 +264,11 @@ TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) { SendRpcsGeneric(10, false); } +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) { + ResetStub(); + SendGenericEchoAsBidi(10); +} + #if GRPC_ALLOW_EXCEPTIONS TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) { ResetStub(); @@ -267,6 +325,176 @@ TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { } } +TEST_P(ClientCallbackEnd2endTest, RequestStream) { + // TODO(vjpai): test with callback server once supported + if (GetParam().callback_server) { + return; + } + + ResetStub(); + class Client : public grpc::experimental::ClientWriteReactor { + public: + explicit Client(grpc::testing::EchoTestService::Stub* stub) { + context_.set_initial_metadata_corked(true); + stream_ = stub->experimental_async()->RequestStream(&context_, &response_, + this); + stream_->StartCall(); + request_.set_message("Hello server."); + stream_->Write(&request_); + } + void OnWriteDone(bool ok) override { + writes_left_--; + if (writes_left_ > 1) { + stream_->Write(&request_); + } else if (writes_left_ == 1) { + stream_->WriteLast(&request_, WriteOptions()); + } + } + void OnDone(Status s) override { + stream_ = nullptr; + EXPECT_TRUE(s.ok()); + EXPECT_EQ(response_.message(), "Hello server.Hello server.Hello server."); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + ::grpc::experimental::ClientCallbackWriter* stream_; + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + int writes_left_{3}; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } test{stub_.get()}; + + test.Await(); +} + +TEST_P(ClientCallbackEnd2endTest, ResponseStream) { + // TODO(vjpai): test with callback server once supported + if (GetParam().callback_server) { + return; + } + + ResetStub(); + class Client : public grpc::experimental::ClientReadReactor { + public: + explicit Client(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello client "); + stream_ = stub->experimental_async()->ResponseStream(&context_, &request_, + this); + stream_->StartCall(); + stream_->Read(&response_); + } + void OnReadDone(bool ok) override { + // Note that != is the boolean XOR operator + EXPECT_NE(ok, reads_complete_ == kServerDefaultResponseStreamsToSend); + if (ok) { + EXPECT_EQ(response_.message(), + request_.message() + grpc::to_string(reads_complete_)); + reads_complete_++; + stream_->Read(&response_); + } + } + void OnDone(Status s) override { + stream_ = nullptr; + EXPECT_TRUE(s.ok()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + ::grpc::experimental::ClientCallbackReader* stream_; + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + int reads_complete_{0}; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } test{stub_.get()}; + + test.Await(); +} + +TEST_P(ClientCallbackEnd2endTest, BidiStream) { + // TODO(vjpai): test with callback server once supported + if (GetParam().callback_server) { + return; + } + ResetStub(); + class Client : public grpc::experimental::ClientBidiReactor { + public: + explicit Client(grpc::testing::EchoTestService::Stub* stub) { + request_.set_message("Hello fren "); + stream_ = stub->experimental_async()->BidiStream(&context_, this); + stream_->StartCall(); + stream_->Read(&response_); + stream_->Write(&request_); + } + void OnReadDone(bool ok) override { + // Note that != is the boolean XOR operator + EXPECT_NE(ok, reads_complete_ == kServerDefaultResponseStreamsToSend); + if (ok) { + EXPECT_EQ(response_.message(), request_.message()); + reads_complete_++; + stream_->Read(&response_); + } + } + void OnWriteDone(bool ok) override { + EXPECT_TRUE(ok); + if (++writes_complete_ == kServerDefaultResponseStreamsToSend) { + stream_->WritesDone(); + } else { + stream_->Write(&request_); + } + } + void OnDone(Status s) override { + stream_ = nullptr; + EXPECT_TRUE(s.ok()); + std::unique_lock l(mu_); + done_ = true; + cv_.notify_one(); + } + void Await() { + std::unique_lock l(mu_); + while (!done_) { + cv_.wait(l); + } + } + + private: + ::grpc::experimental::ClientCallbackReaderWriter* + stream_; + EchoRequest request_; + EchoResponse response_; + ClientContext context_; + int reads_complete_{0}; + int writes_complete_{0}; + std::mutex mu_; + std::condition_variable cv_; + bool done_ = false; + } test{stub_.get()}; + + test.Await(); +} + TestScenario scenarios[] = {TestScenario{false}, TestScenario{true}}; INSTANTIATE_TEST_CASE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest, -- cgit v1.2.3 From ea1156da3fb3d6fb8660d078e70cf5486fc71a65 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Fri, 30 Nov 2018 02:16:08 -0800 Subject: Stop exposing streaming object class --- include/grpcpp/generic/generic_stub.h | 6 +- include/grpcpp/impl/codegen/client_callback.h | 302 +++--- src/compiler/cpp_generator.cc | 96 +- src/cpp/client/generic_stub.cc | 15 +- test/cpp/codegen/compiler_test_golden | 1105 +++++++++++++++------- test/cpp/end2end/client_callback_end2end_test.cc | 63 +- 6 files changed, 1053 insertions(+), 534 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/generic/generic_stub.h b/include/grpcpp/generic/generic_stub.h index ccbf8a0e55..eb014184e4 100644 --- a/include/grpcpp/generic/generic_stub.h +++ b/include/grpcpp/generic/generic_stub.h @@ -77,9 +77,9 @@ class GenericStub final { const ByteBuffer* request, ByteBuffer* response, std::function on_completion); - experimental::ClientCallbackReaderWriter* - PrepareBidiStreamingCall(ClientContext* context, const grpc::string& method, - experimental::ClientBidiReactor* reactor); + void PrepareBidiStreamingCall( + ClientContext* context, const grpc::string& method, + experimental::ClientBidiReactor* reactor); private: GenericStub* stub_; diff --git a/include/grpcpp/impl/codegen/client_callback.h b/include/grpcpp/impl/codegen/client_callback.h index 92a588b3c3..999c1c8a3e 100644 --- a/include/grpcpp/impl/codegen/client_callback.h +++ b/include/grpcpp/impl/codegen/client_callback.h @@ -93,9 +93,67 @@ class CallbackUnaryCallImpl { namespace experimental { +// Forward declarations +template +class ClientBidiReactor; +template +class ClientReadReactor; +template +class ClientWriteReactor; + +// NOTE: The streaming objects are not actually implemented in the public API. +// These interfaces are provided for mocking only. Typical applications +// will interact exclusively with the reactors that they define. +template +class ClientCallbackReaderWriter { + public: + virtual ~ClientCallbackReaderWriter() {} + virtual void StartCall() = 0; + virtual void Write(const Request* req, WriteOptions options) = 0; + virtual void WritesDone() = 0; + virtual void Read(Response* resp) = 0; + + protected: + void BindReactor(ClientBidiReactor* reactor) { + reactor->BindStream(this); + } +}; + +template +class ClientCallbackReader { + public: + virtual ~ClientCallbackReader() {} + virtual void StartCall() = 0; + virtual void Read(Response* resp) = 0; + + protected: + void BindReactor(ClientReadReactor* reactor) { + reactor->BindReader(this); + } +}; + +template +class ClientCallbackWriter { + public: + virtual ~ClientCallbackWriter() {} + virtual void StartCall() = 0; + void Write(const Request* req) { Write(req, WriteOptions()); } + virtual void Write(const Request* req, WriteOptions options) = 0; + void WriteLast(const Request* req, WriteOptions options) { + Write(req, options.set_last_message()); + } + virtual void WritesDone() = 0; + + protected: + void BindReactor(ClientWriteReactor* reactor) { + reactor->BindWriter(this); + } +}; + // The user must implement this reactor interface with reactions to each event // type that gets called by the library. An empty reaction is provided by // default +template class ClientBidiReactor { public: virtual ~ClientBidiReactor() {} @@ -104,16 +162,44 @@ class ClientBidiReactor { virtual void OnReadDone(bool ok) {} virtual void OnWriteDone(bool ok) {} virtual void OnWritesDoneDone(bool ok) {} + + void StartCall() { stream_->StartCall(); } + void StartRead(Response* resp) { stream_->Read(resp); } + void StartWrite(const Request* req) { StartWrite(req, WriteOptions()); } + void StartWrite(const Request* req, WriteOptions options) { + stream_->Write(req, std::move(options)); + } + void StartWriteLast(const Request* req, WriteOptions options) { + StartWrite(req, std::move(options.set_last_message())); + } + void StartWritesDone() { stream_->WritesDone(); } + + private: + friend class ClientCallbackReaderWriter; + void BindStream(ClientCallbackReaderWriter* stream) { + stream_ = stream; + } + ClientCallbackReaderWriter* stream_; }; +template class ClientReadReactor { public: virtual ~ClientReadReactor() {} virtual void OnDone(Status s) {} virtual void OnReadInitialMetadataDone(bool ok) {} virtual void OnReadDone(bool ok) {} + + void StartCall() { reader_->StartCall(); } + void StartRead(Response* resp) { reader_->Read(resp); } + + private: + friend class ClientCallbackReader; + void BindReader(ClientCallbackReader* reader) { reader_ = reader; } + ClientCallbackReader* reader_; }; +template class ClientWriteReactor { public: virtual ~ClientWriteReactor() {} @@ -121,41 +207,21 @@ class ClientWriteReactor { virtual void OnReadInitialMetadataDone(bool ok) {} virtual void OnWriteDone(bool ok) {} virtual void OnWritesDoneDone(bool ok) {} -}; -template -class ClientCallbackReaderWriter { - public: - virtual ~ClientCallbackReaderWriter() {} - virtual void StartCall() = 0; - void Write(const Request* req) { Write(req, WriteOptions()); } - virtual void Write(const Request* req, WriteOptions options) = 0; - void WriteLast(const Request* req, WriteOptions options) { - Write(req, options.set_last_message()); + void StartCall() { writer_->StartCall(); } + void StartWrite(const Request* req) { StartWrite(req, WriteOptions()); } + void StartWrite(const Request* req, WriteOptions options) { + writer_->Write(req, std::move(options)); } - virtual void WritesDone() = 0; - virtual void Read(Response* resp) = 0; -}; - -template -class ClientCallbackReader { - public: - virtual ~ClientCallbackReader() {} - virtual void StartCall() = 0; - virtual void Read(Response* resp) = 0; -}; - -template -class ClientCallbackWriter { - public: - virtual ~ClientCallbackWriter() {} - virtual void StartCall() = 0; - void Write(const Request* req) { Write(req, WriteOptions()); } - virtual void Write(const Request* req, WriteOptions options) = 0; - void WriteLast(const Request* req, WriteOptions options) { - Write(req, options.set_last_message()); + void StartWriteLast(const Request* req, WriteOptions options) { + StartWrite(req, std::move(options.set_last_message())); } - virtual void WritesDone() = 0; + void StartWritesDone() { writer_->WritesDone(); } + + private: + friend class ClientCallbackWriter; + void BindWriter(ClientCallbackWriter* writer) { writer_ = writer; } + ClientCallbackWriter* writer_; }; } // namespace experimental @@ -204,12 +270,13 @@ class ClientCallbackReaderWriterImpl // 4. Any write backlog started_ = true; - start_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnReadInitialMetadataDone(ok); - MaybeFinish(); - }, - &start_ops_); + start_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); if (!start_corked_) { start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); @@ -220,27 +287,29 @@ class ClientCallbackReaderWriterImpl // Also set up the read and write tags so that they don't have to be set up // each time - write_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnWriteDone(ok); - MaybeFinish(); - }, - &write_ops_); + write_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeFinish(); + }, + &write_ops_); write_ops_.set_core_cq_tag(&write_tag_); - read_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnReadDone(ok); - MaybeFinish(); - }, - &read_ops_); + read_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeFinish(); + }, + &read_ops_); read_ops_.set_core_cq_tag(&read_tag_); if (read_ops_at_start_) { call_.PerformOps(&read_ops_); } - finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, - &finish_ops_); + finish_tag_.Set( + call_.call(), [this](bool ok) { MaybeFinish(); }, &finish_ops_); finish_ops_.ClientRecvStatus(context_, &finish_status_); finish_ops_.set_core_cq_tag(&finish_tag_); call_.PerformOps(&finish_ops_); @@ -291,12 +360,13 @@ class ClientCallbackReaderWriterImpl start_corked_ = false; } writes_done_ops_.ClientSendClose(); - writes_done_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnWritesDoneDone(ok); - MaybeFinish(); - }, - &writes_done_ops_); + writes_done_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnWritesDoneDone(ok); + MaybeFinish(); + }, + &writes_done_ops_); writes_done_ops_.set_core_cq_tag(&writes_done_tag_); callbacks_outstanding_++; if (started_) { @@ -311,15 +381,17 @@ class ClientCallbackReaderWriterImpl ClientCallbackReaderWriterImpl( Call call, ClientContext* context, - ::grpc::experimental::ClientBidiReactor* reactor) + ::grpc::experimental::ClientBidiReactor* reactor) : context_(context), call_(call), reactor_(reactor), - start_corked_(context_->initial_metadata_corked_) {} + start_corked_(context_->initial_metadata_corked_) { + this->BindReactor(reactor); + } ClientContext* context_; Call call_; - ::grpc::experimental::ClientBidiReactor* reactor_; + ::grpc::experimental::ClientBidiReactor* reactor_; CallOpSet start_ops_; CallbackWithSuccessTag start_tag_; @@ -350,14 +422,14 @@ class ClientCallbackReaderWriterImpl template class ClientCallbackReaderWriterFactory { public: - static experimental::ClientCallbackReaderWriter* Create( + static void Create( ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, ClientContext* context, - ::grpc::experimental::ClientBidiReactor* reactor) { + ::grpc::experimental::ClientBidiReactor* reactor) { Call call = channel->CreateCall(method, context, channel->CallbackCQ()); g_core_codegen_interface->grpc_call_ref(call.call()); - return new (g_core_codegen_interface->grpc_call_arena_alloc( + new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientCallbackReaderWriterImpl))) ClientCallbackReaderWriterImpl(call, context, reactor); @@ -396,12 +468,13 @@ class ClientCallbackReaderImpl // 3. Recv trailing metadata, on_completion callback started_ = true; - start_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnReadInitialMetadataDone(ok); - MaybeFinish(); - }, - &start_ops_); + start_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); start_ops_.RecvInitialMetadata(context_); @@ -409,19 +482,20 @@ class ClientCallbackReaderImpl call_.PerformOps(&start_ops_); // Also set up the read tag so it doesn't have to be set up each time - read_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnReadDone(ok); - MaybeFinish(); - }, - &read_ops_); + read_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeFinish(); + }, + &read_ops_); read_ops_.set_core_cq_tag(&read_tag_); if (read_ops_at_start_) { call_.PerformOps(&read_ops_); } - finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, - &finish_ops_); + finish_tag_.Set( + call_.call(), [this](bool ok) { MaybeFinish(); }, &finish_ops_); finish_ops_.ClientRecvStatus(context_, &finish_status_); finish_ops_.set_core_cq_tag(&finish_tag_); call_.PerformOps(&finish_ops_); @@ -441,9 +515,11 @@ class ClientCallbackReaderImpl friend class ClientCallbackReaderFactory; template - ClientCallbackReaderImpl(Call call, ClientContext* context, Request* request, - ::grpc::experimental::ClientReadReactor* reactor) + ClientCallbackReaderImpl( + Call call, ClientContext* context, Request* request, + ::grpc::experimental::ClientReadReactor* reactor) : context_(context), call_(call), reactor_(reactor) { + this->BindReactor(reactor); // TODO(vjpai): don't assert GPR_CODEGEN_ASSERT(start_ops_.SendMessage(*request).ok()); start_ops_.ClientSendClose(); @@ -451,7 +527,7 @@ class ClientCallbackReaderImpl ClientContext* context_; Call call_; - ::grpc::experimental::ClientReadReactor* reactor_; + ::grpc::experimental::ClientReadReactor* reactor_; CallOpSet @@ -475,14 +551,14 @@ template class ClientCallbackReaderFactory { public: template - static experimental::ClientCallbackReader* Create( + static void Create( ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, ClientContext* context, const Request* request, - ::grpc::experimental::ClientReadReactor* reactor) { + ::grpc::experimental::ClientReadReactor* reactor) { Call call = channel->CreateCall(method, context, channel->CallbackCQ()); g_core_codegen_interface->grpc_call_ref(call.call()); - return new (g_core_codegen_interface->grpc_call_arena_alloc( + new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientCallbackReaderImpl))) ClientCallbackReaderImpl(call, context, request, reactor); } @@ -520,12 +596,13 @@ class ClientCallbackWriterImpl // 3. Any backlog started_ = true; - start_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnReadInitialMetadataDone(ok); - MaybeFinish(); - }, - &start_ops_); + start_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnReadInitialMetadataDone(ok); + MaybeFinish(); + }, + &start_ops_); if (!start_corked_) { start_ops_.SendInitialMetadata(&context_->send_initial_metadata_, context_->initial_metadata_flags()); @@ -536,16 +613,17 @@ class ClientCallbackWriterImpl // Also set up the read and write tags so that they don't have to be set up // each time - write_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnWriteDone(ok); - MaybeFinish(); - }, - &write_ops_); + write_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeFinish(); + }, + &write_ops_); write_ops_.set_core_cq_tag(&write_tag_); - finish_tag_.Set(call_.call(), [this](bool ok) { MaybeFinish(); }, - &finish_ops_); + finish_tag_.Set( + call_.call(), [this](bool ok) { MaybeFinish(); }, &finish_ops_); finish_ops_.ClientRecvStatus(context_, &finish_status_); finish_ops_.set_core_cq_tag(&finish_tag_); call_.PerformOps(&finish_ops_); @@ -586,12 +664,13 @@ class ClientCallbackWriterImpl start_corked_ = false; } writes_done_ops_.ClientSendClose(); - writes_done_tag_.Set(call_.call(), - [this](bool ok) { - reactor_->OnWritesDoneDone(ok); - MaybeFinish(); - }, - &writes_done_ops_); + writes_done_tag_.Set( + call_.call(), + [this](bool ok) { + reactor_->OnWritesDoneDone(ok); + MaybeFinish(); + }, + &writes_done_ops_); writes_done_ops_.set_core_cq_tag(&writes_done_tag_); callbacks_outstanding_++; if (started_) { @@ -605,20 +684,21 @@ class ClientCallbackWriterImpl friend class ClientCallbackWriterFactory; template - ClientCallbackWriterImpl(Call call, ClientContext* context, - Response* response, - ::grpc::experimental::ClientWriteReactor* reactor) + ClientCallbackWriterImpl( + Call call, ClientContext* context, Response* response, + ::grpc::experimental::ClientWriteReactor* reactor) : context_(context), call_(call), reactor_(reactor), start_corked_(context_->initial_metadata_corked_) { + this->BindReactor(reactor); finish_ops_.RecvMessage(response); finish_ops_.AllowNoMessage(); } ClientContext* context_; Call call_; - ::grpc::experimental::ClientWriteReactor* reactor_; + ::grpc::experimental::ClientWriteReactor* reactor_; CallOpSet start_ops_; CallbackWithSuccessTag start_tag_; @@ -646,14 +726,14 @@ template class ClientCallbackWriterFactory { public: template - static experimental::ClientCallbackWriter* Create( + static void Create( ChannelInterface* channel, const ::grpc::internal::RpcMethod& method, ClientContext* context, Response* response, - ::grpc::experimental::ClientWriteReactor* reactor) { + ::grpc::experimental::ClientWriteReactor* reactor) { Call call = channel->CreateCall(method, context, channel->CallbackCQ()); g_core_codegen_interface->grpc_call_ref(call.call()); - return new (g_core_codegen_interface->grpc_call_arena_alloc( + new (g_core_codegen_interface->grpc_call_arena_alloc( call.call(), sizeof(ClientCallbackWriterImpl))) ClientCallbackWriterImpl(call, context, response, reactor); } diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 473e57166f..a368b47f01 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -582,22 +582,21 @@ void PrintHeaderClientMethodCallbackInterfaces( "std::function) = 0;\n"); } else if (ClientOnlyStreaming(method)) { printer->Print(*vars, - "virtual ::grpc::experimental::ClientCallbackWriter< " - "$Request$>* $Method$(::grpc::ClientContext* context, " + "virtual void $Method$(::grpc::ClientContext* context, " "$Response$* response, " - "::grpc::experimental::ClientWriteReactor* reactor) = 0;\n"); + "::grpc::experimental::ClientWriteReactor< $Request$>* " + "reactor) = 0;\n"); } else if (ServerOnlyStreaming(method)) { printer->Print(*vars, - "virtual ::grpc::experimental::ClientCallbackReader< " - "$Response$>* $Method$(::grpc::ClientContext* context, " + "virtual void $Method$(::grpc::ClientContext* context, " "$Request$* request, " - "::grpc::experimental::ClientReadReactor* reactor) = 0;\n"); + "::grpc::experimental::ClientReadReactor< $Response$>* " + "reactor) = 0;\n"); } else if (method->BidiStreaming()) { - printer->Print( - *vars, - "virtual ::grpc::experimental::ClientCallbackReaderWriter< $Request$, " - "$Response$>* $Method$(::grpc::ClientContext* context, " - "::grpc::experimental::ClientBidiReactor* reactor) = 0;\n"); + printer->Print(*vars, + "virtual void $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< " + "$Request$,$Response$>* reactor) = 0;\n"); } } @@ -644,26 +643,23 @@ void PrintHeaderClientMethodCallback(grpc_generator::Printer* printer, "const $Request$* request, $Response$* response, " "std::function) override;\n"); } else if (ClientOnlyStreaming(method)) { - printer->Print( - *vars, - "::grpc::experimental::ClientCallbackWriter< $Request$>* " - "$Method$(::grpc::ClientContext* context, " - "$Response$* response, " - "::grpc::experimental::ClientWriteReactor* reactor) override;\n"); + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "$Response$* response, " + "::grpc::experimental::ClientWriteReactor< $Request$>* " + "reactor) override;\n"); } else if (ServerOnlyStreaming(method)) { - printer->Print( - *vars, - "::grpc::experimental::ClientCallbackReader< $Response$>* " - "$Method$(::grpc::ClientContext* context, " - "$Request$* request, " - "::grpc::experimental::ClientReadReactor* reactor) override;\n"); + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor< $Response$>* " + "reactor) override;\n"); } else if (method->BidiStreaming()) { - printer->Print( - *vars, - "::grpc::experimental::ClientCallbackReaderWriter< $Request$, " - "$Response$>* $Method$(::grpc::ClientContext* context, " - "::grpc::experimental::ClientBidiReactor* reactor) override;\n"); + printer->Print(*vars, + "void $Method$(::grpc::ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< " + "$Request$,$Response$>* reactor) override;\n"); } } @@ -1637,13 +1633,12 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, printer->Print( *vars, - "::grpc::experimental::ClientCallbackWriter< $Request$>* " - "$ns$$Service$::" + "void $ns$$Service$::" "Stub::experimental_async::$Method$(::grpc::ClientContext* context, " "$Response$* response, " - "::grpc::experimental::ClientWriteReactor* reactor) {\n"); + "::grpc::experimental::ClientWriteReactor< $Request$>* reactor) {\n"); printer->Print(*vars, - " return ::grpc::internal::ClientCallbackWriterFactory< " + " ::grpc::internal::ClientCallbackWriterFactory< " "$Request$>::Create(" "stub_->channel_.get(), " "stub_->rpcmethod_$Method$_, " @@ -1682,14 +1677,14 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context, request);\n" "}\n\n"); + printer->Print( + *vars, + "void $ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "$Request$* request, " + "::grpc::experimental::ClientReadReactor< $Response$>* reactor) {\n"); printer->Print(*vars, - "::grpc::experimental::ClientCallbackReader< $Response$>* " - "$ns$$Service$::Stub::experimental_async::$Method$(::grpc::" - "ClientContext* context, " - "$Request$* request, " - "::grpc::experimental::ClientReadReactor* reactor) {\n"); - printer->Print(*vars, - " return ::grpc::internal::ClientCallbackReaderFactory< " + " ::grpc::internal::ClientCallbackReaderFactory< " "$Response$>::Create(" "stub_->channel_.get(), " "stub_->rpcmethod_$Method$_, " @@ -1728,20 +1723,19 @@ void PrintSourceClientMethod(grpc_generator::Printer* printer, "context);\n" "}\n\n"); - printer->Print(*vars, - "::grpc::experimental::ClientCallbackReaderWriter< " - "$Request$,$Response$>* " - "$ns$$Service$::Stub::experimental_async::$Method$(::grpc::" - "ClientContext* context, " - "::grpc::experimental::ClientBidiReactor* reactor) {\n"); printer->Print( *vars, - " return ::grpc::internal::ClientCallbackReaderWriterFactory< " - "$Request$,$Response$>::Create(" - "stub_->channel_.get(), " - "stub_->rpcmethod_$Method$_, " - "context, reactor);\n" - "}\n\n"); + "void $ns$$Service$::Stub::experimental_async::$Method$(::grpc::" + "ClientContext* context, " + "::grpc::experimental::ClientBidiReactor< $Request$,$Response$>* " + "reactor) {\n"); + printer->Print(*vars, + " ::grpc::internal::ClientCallbackReaderWriterFactory< " + "$Request$,$Response$>::Create(" + "stub_->channel_.get(), " + "stub_->rpcmethod_$Method$_, " + "context, reactor);\n" + "}\n\n"); for (auto async_prefix : async_prefixes) { (*vars)["AsyncPrefix"] = async_prefix.prefix; diff --git a/src/cpp/client/generic_stub.cc b/src/cpp/client/generic_stub.cc index f029daec65..f61c1b5317 100644 --- a/src/cpp/client/generic_stub.cc +++ b/src/cpp/client/generic_stub.cc @@ -72,16 +72,13 @@ void GenericStub::experimental_type::UnaryCall( context, request, response, std::move(on_completion)); } -experimental::ClientCallbackReaderWriter* -GenericStub::experimental_type::PrepareBidiStreamingCall( +void GenericStub::experimental_type::PrepareBidiStreamingCall( ClientContext* context, const grpc::string& method, - experimental::ClientBidiReactor* reactor) { - return internal::ClientCallbackReaderWriterFactory< - ByteBuffer, ByteBuffer>::Create(stub_->channel_.get(), - internal::RpcMethod( - method.c_str(), - internal::RpcMethod::BIDI_STREAMING), - context, reactor); + experimental::ClientBidiReactor* reactor) { + internal::ClientCallbackReaderWriterFactory::Create( + stub_->channel_.get(), + internal::RpcMethod(method.c_str(), internal::RpcMethod::BIDI_STREAMING), + context, reactor); } } // namespace grpc diff --git a/test/cpp/codegen/compiler_test_golden b/test/cpp/codegen/compiler_test_golden index 7a25f51d10..1a5fe27932 100644 --- a/test/cpp/codegen/compiler_test_golden +++ b/test/cpp/codegen/compiler_test_golden @@ -26,7 +26,6 @@ #include "src/proto/grpc/testing/compiler_test.pb.h" -#include #include #include #include @@ -39,6 +38,7 @@ #include #include #include +#include namespace grpc { class CompletionQueue; @@ -64,294 +64,556 @@ class ServiceA final { public: virtual ~StubInterface() {} // MethodA1 leading comment 1 - virtual ::grpc::Status MethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::testing::Response* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>> AsyncMethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>>(AsyncMethodA1Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>> PrepareAsyncMethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>>(PrepareAsyncMethodA1Raw(context, request, cq)); + virtual ::grpc::Status MethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::testing::Response* response) = 0; + std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface<::grpc::testing::Response>> + AsyncMethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>>(AsyncMethodA1Raw(context, request, cq)); + } + std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface<::grpc::testing::Response>> + PrepareAsyncMethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>>( + PrepareAsyncMethodA1Raw(context, request, cq)); } // MethodA1 trailing comment 1 // MethodA2 detached leading comment 1 // // Method A2 leading comment 1 // Method A2 leading comment 2 - std::unique_ptr< ::grpc::ClientWriterInterface< ::grpc::testing::Request>> MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response) { - return std::unique_ptr< ::grpc::ClientWriterInterface< ::grpc::testing::Request>>(MethodA2Raw(context, response)); - } - std::unique_ptr< ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>> AsyncMethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>>(AsyncMethodA2Raw(context, response, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>> PrepareAsyncMethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>>(PrepareAsyncMethodA2Raw(context, response, cq)); + std::unique_ptr<::grpc::ClientWriterInterface<::grpc::testing::Request>> + MethodA2(::grpc::ClientContext* context, + ::grpc::testing::Response* response) { + return std::unique_ptr< + ::grpc::ClientWriterInterface<::grpc::testing::Request>>( + MethodA2Raw(context, response)); + } + std::unique_ptr< + ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>> + AsyncMethodA2(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq, void* tag) { + return std::unique_ptr< + ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>>( + AsyncMethodA2Raw(context, response, cq, tag)); + } + std::unique_ptr< + ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>> + PrepareAsyncMethodA2(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>>( + PrepareAsyncMethodA2Raw(context, response, cq)); } // MethodA2 trailing comment 1 // Method A3 leading comment 1 - std::unique_ptr< ::grpc::ClientReaderInterface< ::grpc::testing::Response>> MethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request) { - return std::unique_ptr< ::grpc::ClientReaderInterface< ::grpc::testing::Response>>(MethodA3Raw(context, request)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>> AsyncMethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>>(AsyncMethodA3Raw(context, request, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>> PrepareAsyncMethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>>(PrepareAsyncMethodA3Raw(context, request, cq)); + std::unique_ptr<::grpc::ClientReaderInterface<::grpc::testing::Response>> + MethodA3(::grpc::ClientContext* context, + const ::grpc::testing::Request& request) { + return std::unique_ptr< + ::grpc::ClientReaderInterface<::grpc::testing::Response>>( + MethodA3Raw(context, request)); + } + std::unique_ptr< + ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>> + AsyncMethodA3(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq, void* tag) { + return std::unique_ptr< + ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>>( + AsyncMethodA3Raw(context, request, cq, tag)); + } + std::unique_ptr< + ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>> + PrepareAsyncMethodA3(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>>( + PrepareAsyncMethodA3Raw(context, request, cq)); } // Method A3 trailing comment 1 // Method A4 leading comment 1 - std::unique_ptr< ::grpc::ClientReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>> MethodA4(::grpc::ClientContext* context) { - return std::unique_ptr< ::grpc::ClientReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>>(MethodA4Raw(context)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>> AsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>>(AsyncMethodA4Raw(context, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>> PrepareAsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>>(PrepareAsyncMethodA4Raw(context, cq)); + std::unique_ptr<::grpc::ClientReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>> + MethodA4(::grpc::ClientContext* context) { + return std::unique_ptr<::grpc::ClientReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>>( + MethodA4Raw(context)); + } + std::unique_ptr<::grpc::ClientAsyncReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>> + AsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, + void* tag) { + return std::unique_ptr<::grpc::ClientAsyncReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>>( + AsyncMethodA4Raw(context, cq, tag)); + } + std::unique_ptr<::grpc::ClientAsyncReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>> + PrepareAsyncMethodA4(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncReaderWriterInterface< + ::grpc::testing::Request, ::grpc::testing::Response>>( + PrepareAsyncMethodA4Raw(context, cq)); } // Method A4 trailing comment 1 class experimental_async_interface { public: virtual ~experimental_async_interface() {} // MethodA1 leading comment 1 - virtual void MethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, std::function) = 0; + virtual void MethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + std::function) = 0; // MethodA1 trailing comment 1 // MethodA2 detached leading comment 1 // // Method A2 leading comment 1 // Method A2 leading comment 2 - virtual ::grpc::experimental::ClientCallbackWriter< ::grpc::testing::Request>* MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::experimental::ClientWriteReactor* reactor) = 0; + virtual void MethodA2( + ::grpc::ClientContext* context, ::grpc::testing::Response* response, + ::grpc::experimental::ClientWriteReactor<::grpc::testing::Request>* + reactor) = 0; // MethodA2 trailing comment 1 // Method A3 leading comment 1 - virtual ::grpc::experimental::ClientCallbackReader< ::grpc::testing::Response>* MethodA3(::grpc::ClientContext* context, ::grpc::testing::Request* request, ::grpc::experimental::ClientReadReactor* reactor) = 0; + virtual void MethodA3( + ::grpc::ClientContext* context, ::grpc::testing::Request* request, + ::grpc::experimental::ClientReadReactor<::grpc::testing::Response>* + reactor) = 0; // Method A3 trailing comment 1 // Method A4 leading comment 1 - virtual ::grpc::experimental::ClientCallbackReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4(::grpc::ClientContext* context, ::grpc::experimental::ClientBidiReactor* reactor) = 0; + virtual void MethodA4( + ::grpc::ClientContext* context, + ::grpc::experimental::ClientBidiReactor<::grpc::testing::Request, + ::grpc::testing::Response>* + reactor) = 0; // Method A4 trailing comment 1 }; - virtual class experimental_async_interface* experimental_async() { return nullptr; } - private: - virtual ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>* AsyncMethodA1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>* PrepareAsyncMethodA1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientWriterInterface< ::grpc::testing::Request>* MethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response) = 0; - virtual ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>* AsyncMethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq, void* tag) = 0; - virtual ::grpc::ClientAsyncWriterInterface< ::grpc::testing::Request>* PrepareAsyncMethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientReaderInterface< ::grpc::testing::Response>* MethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request) = 0; - virtual ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>* AsyncMethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq, void* tag) = 0; - virtual ::grpc::ClientAsyncReaderInterface< ::grpc::testing::Response>* PrepareAsyncMethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4Raw(::grpc::ClientContext* context) = 0; - virtual ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>* AsyncMethodA4Raw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, void* tag) = 0; - virtual ::grpc::ClientAsyncReaderWriterInterface< ::grpc::testing::Request, ::grpc::testing::Response>* PrepareAsyncMethodA4Raw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq) = 0; + virtual class experimental_async_interface* experimental_async() { + return nullptr; + } + + private: + virtual ::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>* + AsyncMethodA1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>* + PrepareAsyncMethodA1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientWriterInterface<::grpc::testing::Request>* + MethodA2Raw(::grpc::ClientContext* context, + ::grpc::testing::Response* response) = 0; + virtual ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>* + AsyncMethodA2Raw(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq, void* tag) = 0; + virtual ::grpc::ClientAsyncWriterInterface<::grpc::testing::Request>* + PrepareAsyncMethodA2Raw(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientReaderInterface<::grpc::testing::Response>* + MethodA3Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request) = 0; + virtual ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>* + AsyncMethodA3Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq, void* tag) = 0; + virtual ::grpc::ClientAsyncReaderInterface<::grpc::testing::Response>* + PrepareAsyncMethodA3Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientReaderWriterInterface<::grpc::testing::Request, + ::grpc::testing::Response>* + MethodA4Raw(::grpc::ClientContext* context) = 0; + virtual ::grpc::ClientAsyncReaderWriterInterface<::grpc::testing::Request, + ::grpc::testing::Response>* + AsyncMethodA4Raw(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq, void* tag) = 0; + virtual ::grpc::ClientAsyncReaderWriterInterface<::grpc::testing::Request, + ::grpc::testing::Response>* + PrepareAsyncMethodA4Raw(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq) = 0; }; class Stub final : public StubInterface { public: - Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status MethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::testing::Response* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>> AsyncMethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>>(AsyncMethodA1Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>> PrepareAsyncMethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>>(PrepareAsyncMethodA1Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientWriter< ::grpc::testing::Request>> MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response) { - return std::unique_ptr< ::grpc::ClientWriter< ::grpc::testing::Request>>(MethodA2Raw(context, response)); - } - std::unique_ptr< ::grpc::ClientAsyncWriter< ::grpc::testing::Request>> AsyncMethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncWriter< ::grpc::testing::Request>>(AsyncMethodA2Raw(context, response, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncWriter< ::grpc::testing::Request>> PrepareAsyncMethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncWriter< ::grpc::testing::Request>>(PrepareAsyncMethodA2Raw(context, response, cq)); - } - std::unique_ptr< ::grpc::ClientReader< ::grpc::testing::Response>> MethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request) { - return std::unique_ptr< ::grpc::ClientReader< ::grpc::testing::Response>>(MethodA3Raw(context, request)); - } - std::unique_ptr< ::grpc::ClientAsyncReader< ::grpc::testing::Response>> AsyncMethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncReader< ::grpc::testing::Response>>(AsyncMethodA3Raw(context, request, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncReader< ::grpc::testing::Response>> PrepareAsyncMethodA3(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncReader< ::grpc::testing::Response>>(PrepareAsyncMethodA3Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>> MethodA4(::grpc::ClientContext* context) { - return std::unique_ptr< ::grpc::ClientReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>>(MethodA4Raw(context)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>> AsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, void* tag) { - return std::unique_ptr< ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>>(AsyncMethodA4Raw(context, cq, tag)); - } - std::unique_ptr< ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>> PrepareAsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>>(PrepareAsyncMethodA4Raw(context, cq)); - } - class experimental_async final : - public StubInterface::experimental_async_interface { + Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); + ::grpc::Status MethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::testing::Response* response) override; + std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>> + AsyncMethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>>( + AsyncMethodA1Raw(context, request, cq)); + } + std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>> + PrepareAsyncMethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>>( + PrepareAsyncMethodA1Raw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientWriter<::grpc::testing::Request>> MethodA2( + ::grpc::ClientContext* context, ::grpc::testing::Response* response) { + return std::unique_ptr<::grpc::ClientWriter<::grpc::testing::Request>>( + MethodA2Raw(context, response)); + } + std::unique_ptr<::grpc::ClientAsyncWriter<::grpc::testing::Request>> + AsyncMethodA2(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq, void* tag) { + return std::unique_ptr< + ::grpc::ClientAsyncWriter<::grpc::testing::Request>>( + AsyncMethodA2Raw(context, response, cq, tag)); + } + std::unique_ptr<::grpc::ClientAsyncWriter<::grpc::testing::Request>> + PrepareAsyncMethodA2(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncWriter<::grpc::testing::Request>>( + PrepareAsyncMethodA2Raw(context, response, cq)); + } + std::unique_ptr<::grpc::ClientReader<::grpc::testing::Response>> MethodA3( + ::grpc::ClientContext* context, + const ::grpc::testing::Request& request) { + return std::unique_ptr<::grpc::ClientReader<::grpc::testing::Response>>( + MethodA3Raw(context, request)); + } + std::unique_ptr<::grpc::ClientAsyncReader<::grpc::testing::Response>> + AsyncMethodA3(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq, void* tag) { + return std::unique_ptr< + ::grpc::ClientAsyncReader<::grpc::testing::Response>>( + AsyncMethodA3Raw(context, request, cq, tag)); + } + std::unique_ptr<::grpc::ClientAsyncReader<::grpc::testing::Response>> + PrepareAsyncMethodA3(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncReader<::grpc::testing::Response>>( + PrepareAsyncMethodA3Raw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>> + MethodA4(::grpc::ClientContext* context) { + return std::unique_ptr<::grpc::ClientReaderWriter< + ::grpc::testing::Request, ::grpc::testing::Response>>( + MethodA4Raw(context)); + } + std::unique_ptr<::grpc::ClientAsyncReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>> + AsyncMethodA4(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, + void* tag) { + return std::unique_ptr<::grpc::ClientAsyncReaderWriter< + ::grpc::testing::Request, ::grpc::testing::Response>>( + AsyncMethodA4Raw(context, cq, tag)); + } + std::unique_ptr<::grpc::ClientAsyncReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>> + PrepareAsyncMethodA4(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncReaderWriter< + ::grpc::testing::Request, ::grpc::testing::Response>>( + PrepareAsyncMethodA4Raw(context, cq)); + } + class experimental_async final + : public StubInterface::experimental_async_interface { public: - void MethodA1(::grpc::ClientContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, std::function) override; - ::grpc::experimental::ClientCallbackWriter< ::grpc::testing::Request>* MethodA2(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::experimental::ClientWriteReactor* reactor) override; - ::grpc::experimental::ClientCallbackReader< ::grpc::testing::Response>* MethodA3(::grpc::ClientContext* context, ::grpc::testing::Request* request, ::grpc::experimental::ClientReadReactor* reactor) override; - ::grpc::experimental::ClientCallbackReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4(::grpc::ClientContext* context, ::grpc::experimental::ClientBidiReactor* reactor) override; + void MethodA1(::grpc::ClientContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + std::function) override; + void MethodA2( + ::grpc::ClientContext* context, ::grpc::testing::Response* response, + ::grpc::experimental::ClientWriteReactor<::grpc::testing::Request>* + reactor) override; + void MethodA3( + ::grpc::ClientContext* context, ::grpc::testing::Request* request, + ::grpc::experimental::ClientReadReactor<::grpc::testing::Response>* + reactor) override; + void MethodA4(::grpc::ClientContext* context, + ::grpc::experimental::ClientBidiReactor< + ::grpc::testing::Request, ::grpc::testing::Response>* + reactor) override; + private: friend class Stub; - explicit experimental_async(Stub* stub): stub_(stub) { } + explicit experimental_async(Stub* stub) : stub_(stub) {} Stub* stub() { return stub_; } Stub* stub_; }; - class experimental_async_interface* experimental_async() override { return &async_stub_; } - - private: - std::shared_ptr< ::grpc::ChannelInterface> channel_; - class experimental_async async_stub_{this}; - ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>* AsyncMethodA1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>* PrepareAsyncMethodA1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientWriter< ::grpc::testing::Request>* MethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response) override; - ::grpc::ClientAsyncWriter< ::grpc::testing::Request>* AsyncMethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq, void* tag) override; - ::grpc::ClientAsyncWriter< ::grpc::testing::Request>* PrepareAsyncMethodA2Raw(::grpc::ClientContext* context, ::grpc::testing::Response* response, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientReader< ::grpc::testing::Response>* MethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request) override; - ::grpc::ClientAsyncReader< ::grpc::testing::Response>* AsyncMethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq, void* tag) override; - ::grpc::ClientAsyncReader< ::grpc::testing::Response>* PrepareAsyncMethodA3Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4Raw(::grpc::ClientContext* context) override; - ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* AsyncMethodA4Raw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq, void* tag) override; - ::grpc::ClientAsyncReaderWriter< ::grpc::testing::Request, ::grpc::testing::Response>* PrepareAsyncMethodA4Raw(::grpc::ClientContext* context, ::grpc::CompletionQueue* cq) override; + class experimental_async_interface* experimental_async() override { + return &async_stub_; + } + + private: + std::shared_ptr<::grpc::ChannelInterface> channel_; + class experimental_async async_stub_ { + this + }; + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>* + AsyncMethodA1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>* + PrepareAsyncMethodA1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientWriter<::grpc::testing::Request>* MethodA2Raw( + ::grpc::ClientContext* context, + ::grpc::testing::Response* response) override; + ::grpc::ClientAsyncWriter<::grpc::testing::Request>* AsyncMethodA2Raw( + ::grpc::ClientContext* context, ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq, void* tag) override; + ::grpc::ClientAsyncWriter<::grpc::testing::Request>* + PrepareAsyncMethodA2Raw(::grpc::ClientContext* context, + ::grpc::testing::Response* response, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientReader<::grpc::testing::Response>* MethodA3Raw( + ::grpc::ClientContext* context, + const ::grpc::testing::Request& request) override; + ::grpc::ClientAsyncReader<::grpc::testing::Response>* AsyncMethodA3Raw( + ::grpc::ClientContext* context, const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq, void* tag) override; + ::grpc::ClientAsyncReader<::grpc::testing::Response>* + PrepareAsyncMethodA3Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>* + MethodA4Raw(::grpc::ClientContext* context) override; + ::grpc::ClientAsyncReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>* + AsyncMethodA4Raw(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq, void* tag) override; + ::grpc::ClientAsyncReaderWriter<::grpc::testing::Request, + ::grpc::testing::Response>* + PrepareAsyncMethodA4Raw(::grpc::ClientContext* context, + ::grpc::CompletionQueue* cq) override; const ::grpc::internal::RpcMethod rpcmethod_MethodA1_; const ::grpc::internal::RpcMethod rpcmethod_MethodA2_; const ::grpc::internal::RpcMethod rpcmethod_MethodA3_; const ::grpc::internal::RpcMethod rpcmethod_MethodA4_; }; - static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); + static std::unique_ptr NewStub( + const std::shared_ptr<::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); class Service : public ::grpc::Service { public: Service(); virtual ~Service(); // MethodA1 leading comment 1 - virtual ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response); + virtual ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response); // MethodA1 trailing comment 1 // MethodA2 detached leading comment 1 // // Method A2 leading comment 1 // Method A2 leading comment 2 - virtual ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response); + virtual ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response); // MethodA2 trailing comment 1 // Method A3 leading comment 1 - virtual ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer); + virtual ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer); // Method A3 trailing comment 1 // Method A4 leading comment 1 - virtual ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream); + virtual ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream); // Method A4 trailing comment 1 }; template class WithAsyncMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_MethodA1() { - ::grpc::Service::MarkMethodAsync(0); - } + WithAsyncMethod_MethodA1() { ::grpc::Service::MarkMethodAsync(0); } ~WithAsyncMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA1(::grpc::ServerContext* context, ::grpc::testing::Request* request, ::grpc::ServerAsyncResponseWriter< ::grpc::testing::Response>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestMethodA1( + ::grpc::ServerContext* context, ::grpc::testing::Request* request, + ::grpc::ServerAsyncResponseWriter<::grpc::testing::Response>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_MethodA2 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_MethodA2() { - ::grpc::Service::MarkMethodAsync(1); - } + WithAsyncMethod_MethodA2() { ::grpc::Service::MarkMethodAsync(1); } ~WithAsyncMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA2(::grpc::ServerContext* context, ::grpc::ServerAsyncReader< ::grpc::testing::Response, ::grpc::testing::Request>* reader, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncClientStreaming(1, context, reader, new_call_cq, notification_cq, tag); + void RequestMethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerAsyncReader<::grpc::testing::Response, + ::grpc::testing::Request>* reader, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncClientStreaming( + 1, context, reader, new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_MethodA3() { - ::grpc::Service::MarkMethodAsync(2); - } + WithAsyncMethod_MethodA3() { ::grpc::Service::MarkMethodAsync(2); } ~WithAsyncMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA3(::grpc::ServerContext* context, ::grpc::testing::Request* request, ::grpc::ServerAsyncWriter< ::grpc::testing::Response>* writer, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncServerStreaming(2, context, request, writer, new_call_cq, notification_cq, tag); + void RequestMethodA3( + ::grpc::ServerContext* context, ::grpc::testing::Request* request, + ::grpc::ServerAsyncWriter<::grpc::testing::Response>* writer, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncServerStreaming( + 2, context, request, writer, new_call_cq, notification_cq, tag); } }; template class WithAsyncMethod_MethodA4 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_MethodA4() { - ::grpc::Service::MarkMethodAsync(3); - } + WithAsyncMethod_MethodA4() { ::grpc::Service::MarkMethodAsync(3); } ~WithAsyncMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA4(::grpc::ServerContext* context, ::grpc::ServerAsyncReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncBidiStreaming(3, context, stream, new_call_cq, notification_cq, tag); + void RequestMethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerAsyncReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncBidiStreaming( + 3, context, stream, new_call_cq, notification_cq, tag); } }; - typedef WithAsyncMethod_MethodA1 > > > AsyncService; + typedef WithAsyncMethod_MethodA1>>> + AsyncService; template class ExperimentalWithCallbackMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: ExperimentalWithCallbackMethod_MethodA1() { - ::grpc::Service::experimental().MarkMethodCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodA1, ::grpc::testing::Request, ::grpc::testing::Response>( - [this](::grpc::ServerContext* context, - const ::grpc::testing::Request* request, - ::grpc::testing::Response* response, - ::grpc::experimental::ServerCallbackRpcController* controller) { + ::grpc::Service::experimental().MarkMethodCallback( + 0, new ::grpc::internal::CallbackUnaryHandler< + ExperimentalWithCallbackMethod_MethodA1, + ::grpc::testing::Request, ::grpc::testing::Response>( + [this](::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* + controller) { this->MethodA1(context, request, response, controller); - }, this)); + }, + this)); } ~ExperimentalWithCallbackMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - virtual void MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + virtual void MethodA1( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); + } }; template class ExperimentalWithCallbackMethod_MethodA2 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithCallbackMethod_MethodA2() { - } + ExperimentalWithCallbackMethod_MethodA2() {} ~ExperimentalWithCallbackMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -359,15 +621,17 @@ class ServiceA final { template class ExperimentalWithCallbackMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithCallbackMethod_MethodA3() { - } + ExperimentalWithCallbackMethod_MethodA3() {} ~ExperimentalWithCallbackMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -375,33 +639,41 @@ class ServiceA final { template class ExperimentalWithCallbackMethod_MethodA4 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithCallbackMethod_MethodA4() { - } + ExperimentalWithCallbackMethod_MethodA4() {} ~ExperimentalWithCallbackMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } }; - typedef ExperimentalWithCallbackMethod_MethodA1 > > > ExperimentalCallbackService; + typedef ExperimentalWithCallbackMethod_MethodA1< + ExperimentalWithCallbackMethod_MethodA2< + ExperimentalWithCallbackMethod_MethodA3< + ExperimentalWithCallbackMethod_MethodA4>>> + ExperimentalCallbackService; template class WithGenericMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_MethodA1() { - ::grpc::Service::MarkMethodGeneric(0); - } + WithGenericMethod_MethodA1() { ::grpc::Service::MarkMethodGeneric(0); } ~WithGenericMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -409,16 +681,18 @@ class ServiceA final { template class WithGenericMethod_MethodA2 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_MethodA2() { - ::grpc::Service::MarkMethodGeneric(1); - } + WithGenericMethod_MethodA2() { ::grpc::Service::MarkMethodGeneric(1); } ~WithGenericMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -426,16 +700,17 @@ class ServiceA final { template class WithGenericMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_MethodA3() { - ::grpc::Service::MarkMethodGeneric(2); - } + WithGenericMethod_MethodA3() { ::grpc::Service::MarkMethodGeneric(2); } ~WithGenericMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -443,16 +718,18 @@ class ServiceA final { template class WithGenericMethod_MethodA4 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_MethodA4() { - ::grpc::Service::MarkMethodGeneric(3); - } + WithGenericMethod_MethodA4() { ::grpc::Service::MarkMethodGeneric(3); } ~WithGenericMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -460,120 +737,164 @@ class ServiceA final { template class WithRawMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithRawMethod_MethodA1() { - ::grpc::Service::MarkMethodRaw(0); - } + WithRawMethod_MethodA1() { ::grpc::Service::MarkMethodRaw(0); } ~WithRawMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA1(::grpc::ServerContext* context, ::grpc::ByteBuffer* request, ::grpc::ServerAsyncResponseWriter< ::grpc::ByteBuffer>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestMethodA1( + ::grpc::ServerContext* context, ::grpc::ByteBuffer* request, + ::grpc::ServerAsyncResponseWriter<::grpc::ByteBuffer>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; template class WithRawMethod_MethodA2 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithRawMethod_MethodA2() { - ::grpc::Service::MarkMethodRaw(1); - } + WithRawMethod_MethodA2() { ::grpc::Service::MarkMethodRaw(1); } ~WithRawMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA2(::grpc::ServerContext* context, ::grpc::ServerAsyncReader< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* reader, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncClientStreaming(1, context, reader, new_call_cq, notification_cq, tag); + void RequestMethodA2(::grpc::ServerContext* context, + ::grpc::ServerAsyncReader<::grpc::ByteBuffer, + ::grpc::ByteBuffer>* reader, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, + void* tag) { + ::grpc::Service::RequestAsyncClientStreaming( + 1, context, reader, new_call_cq, notification_cq, tag); } }; template class WithRawMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithRawMethod_MethodA3() { - ::grpc::Service::MarkMethodRaw(2); - } + WithRawMethod_MethodA3() { ::grpc::Service::MarkMethodRaw(2); } ~WithRawMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA3(::grpc::ServerContext* context, ::grpc::ByteBuffer* request, ::grpc::ServerAsyncWriter< ::grpc::ByteBuffer>* writer, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncServerStreaming(2, context, request, writer, new_call_cq, notification_cq, tag); + void RequestMethodA3(::grpc::ServerContext* context, + ::grpc::ByteBuffer* request, + ::grpc::ServerAsyncWriter<::grpc::ByteBuffer>* writer, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, + void* tag) { + ::grpc::Service::RequestAsyncServerStreaming( + 2, context, request, writer, new_call_cq, notification_cq, tag); } }; template class WithRawMethod_MethodA4 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithRawMethod_MethodA4() { - ::grpc::Service::MarkMethodRaw(3); - } + WithRawMethod_MethodA4() { ::grpc::Service::MarkMethodRaw(3); } ~WithRawMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodA4(::grpc::ServerContext* context, ::grpc::ServerAsyncReaderWriter< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* stream, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncBidiStreaming(3, context, stream, new_call_cq, notification_cq, tag); + void RequestMethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerAsyncReaderWriter<::grpc::ByteBuffer, ::grpc::ByteBuffer>* + stream, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncBidiStreaming( + 3, context, stream, new_call_cq, notification_cq, tag); } }; template class ExperimentalWithRawCallbackMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: ExperimentalWithRawCallbackMethod_MethodA1() { - ::grpc::Service::experimental().MarkMethodRawCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodA1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( - [this](::grpc::ServerContext* context, - const ::grpc::ByteBuffer* request, - ::grpc::ByteBuffer* response, - ::grpc::experimental::ServerCallbackRpcController* controller) { + ::grpc::Service::experimental().MarkMethodRawCallback( + 0, new ::grpc::internal::CallbackUnaryHandler< + ExperimentalWithRawCallbackMethod_MethodA1, + ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this](::grpc::ServerContext* context, + const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* + controller) { this->MethodA1(context, request, response, controller); - }, this)); + }, + this)); } ~ExperimentalWithRawCallbackMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - virtual void MethodA1(::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + virtual void MethodA1( + ::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); + } }; template class ExperimentalWithRawCallbackMethod_MethodA2 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithRawCallbackMethod_MethodA2() { - } + ExperimentalWithRawCallbackMethod_MethodA2() {} ~ExperimentalWithRawCallbackMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA2( + ::grpc::ServerContext* context, + ::grpc::ServerReader<::grpc::testing::Request>* reader, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -581,15 +902,17 @@ class ServiceA final { template class ExperimentalWithRawCallbackMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithRawCallbackMethod_MethodA3() { - } + ExperimentalWithRawCallbackMethod_MethodA3() {} ~ExperimentalWithRawCallbackMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -597,15 +920,18 @@ class ServiceA final { template class ExperimentalWithRawCallbackMethod_MethodA4 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - ExperimentalWithRawCallbackMethod_MethodA4() { - } + ExperimentalWithRawCallbackMethod_MethodA4() {} ~ExperimentalWithRawCallbackMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + ::grpc::Status MethodA4( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter<::grpc::testing::Response, + ::grpc::testing::Request>* stream) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -613,46 +939,69 @@ class ServiceA final { template class WithStreamedUnaryMethod_MethodA1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_MethodA1() { - ::grpc::Service::MarkMethodStreamed(0, - new ::grpc::internal::StreamedUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(std::bind(&WithStreamedUnaryMethod_MethodA1::StreamedMethodA1, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 0, new ::grpc::internal::StreamedUnaryHandler< + ::grpc::testing::Request, ::grpc::testing::Response>(std::bind( + &WithStreamedUnaryMethod_MethodA1::StreamedMethodA1, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodA1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedMethodA1(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< ::grpc::testing::Request,::grpc::testing::Response>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedMethodA1( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer<::grpc::testing::Request, + ::grpc::testing::Response>* + server_unary_streamer) = 0; }; - typedef WithStreamedUnaryMethod_MethodA1 StreamedUnaryService; + typedef WithStreamedUnaryMethod_MethodA1 StreamedUnaryService; template class WithSplitStreamingMethod_MethodA3 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithSplitStreamingMethod_MethodA3() { - ::grpc::Service::MarkMethodStreamed(2, - new ::grpc::internal::SplitServerStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>(std::bind(&WithSplitStreamingMethod_MethodA3::StreamedMethodA3, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 2, + new ::grpc::internal::SplitServerStreamingHandler< + ::grpc::testing::Request, ::grpc::testing::Response>(std::bind( + &WithSplitStreamingMethod_MethodA3::StreamedMethodA3, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithSplitStreamingMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + ::grpc::Status MethodA3( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::ServerWriter<::grpc::testing::Response>* writer) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with split streamed - virtual ::grpc::Status StreamedMethodA3(::grpc::ServerContext* context, ::grpc::ServerSplitStreamer< ::grpc::testing::Request,::grpc::testing::Response>* server_split_streamer) = 0; + virtual ::grpc::Status StreamedMethodA3( + ::grpc::ServerContext* context, + ::grpc::ServerSplitStreamer<::grpc::testing::Request, + ::grpc::testing::Response>* + server_split_streamer) = 0; }; - typedef WithSplitStreamingMethod_MethodA3 SplitStreamedService; - typedef WithStreamedUnaryMethod_MethodA1 > StreamedService; + typedef WithSplitStreamingMethod_MethodA3 SplitStreamedService; + typedef WithStreamedUnaryMethod_MethodA1< + WithSplitStreamingMethod_MethodA3> + StreamedService; }; // ServiceB leading comment 1 @@ -665,125 +1014,204 @@ class ServiceB final { public: virtual ~StubInterface() {} // MethodB1 leading comment 1 - virtual ::grpc::Status MethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::testing::Response* response) = 0; - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>> AsyncMethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>>(AsyncMethodB1Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>> PrepareAsyncMethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>>(PrepareAsyncMethodB1Raw(context, request, cq)); + virtual ::grpc::Status MethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::testing::Response* response) = 0; + std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface<::grpc::testing::Response>> + AsyncMethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>>(AsyncMethodB1Raw(context, request, cq)); + } + std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface<::grpc::testing::Response>> + PrepareAsyncMethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>>( + PrepareAsyncMethodB1Raw(context, request, cq)); } // MethodB1 trailing comment 1 class experimental_async_interface { public: virtual ~experimental_async_interface() {} // MethodB1 leading comment 1 - virtual void MethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, std::function) = 0; + virtual void MethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + std::function) = 0; // MethodB1 trailing comment 1 }; - virtual class experimental_async_interface* experimental_async() { return nullptr; } - private: - virtual ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>* AsyncMethodB1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< ::grpc::testing::Response>* PrepareAsyncMethodB1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) = 0; + virtual class experimental_async_interface* experimental_async() { + return nullptr; + } + + private: + virtual ::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>* + AsyncMethodB1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface< + ::grpc::testing::Response>* + PrepareAsyncMethodB1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) = 0; }; class Stub final : public StubInterface { public: - Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); - ::grpc::Status MethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::testing::Response* response) override; - std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>> AsyncMethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>>(AsyncMethodB1Raw(context, request, cq)); - } - std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>> PrepareAsyncMethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) { - return std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>>(PrepareAsyncMethodB1Raw(context, request, cq)); - } - class experimental_async final : - public StubInterface::experimental_async_interface { + Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); + ::grpc::Status MethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::testing::Response* response) override; + std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>> + AsyncMethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>>( + AsyncMethodB1Raw(context, request, cq)); + } + std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>> + PrepareAsyncMethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>>( + PrepareAsyncMethodB1Raw(context, request, cq)); + } + class experimental_async final + : public StubInterface::experimental_async_interface { public: - void MethodB1(::grpc::ClientContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, std::function) override; + void MethodB1(::grpc::ClientContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + std::function) override; + private: friend class Stub; - explicit experimental_async(Stub* stub): stub_(stub) { } + explicit experimental_async(Stub* stub) : stub_(stub) {} Stub* stub() { return stub_; } Stub* stub_; }; - class experimental_async_interface* experimental_async() override { return &async_stub_; } + class experimental_async_interface* experimental_async() override { + return &async_stub_; + } private: - std::shared_ptr< ::grpc::ChannelInterface> channel_; - class experimental_async async_stub_{this}; - ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>* AsyncMethodB1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) override; - ::grpc::ClientAsyncResponseReader< ::grpc::testing::Response>* PrepareAsyncMethodB1Raw(::grpc::ClientContext* context, const ::grpc::testing::Request& request, ::grpc::CompletionQueue* cq) override; + std::shared_ptr<::grpc::ChannelInterface> channel_; + class experimental_async async_stub_ { + this + }; + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>* + AsyncMethodB1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader<::grpc::testing::Response>* + PrepareAsyncMethodB1Raw(::grpc::ClientContext* context, + const ::grpc::testing::Request& request, + ::grpc::CompletionQueue* cq) override; const ::grpc::internal::RpcMethod rpcmethod_MethodB1_; }; - static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); + static std::unique_ptr NewStub( + const std::shared_ptr<::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); class Service : public ::grpc::Service { public: Service(); virtual ~Service(); // MethodB1 leading comment 1 - virtual ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response); + virtual ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response); // MethodB1 trailing comment 1 }; template class WithAsyncMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithAsyncMethod_MethodB1() { - ::grpc::Service::MarkMethodAsync(0); - } + WithAsyncMethod_MethodB1() { ::grpc::Service::MarkMethodAsync(0); } ~WithAsyncMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodB1(::grpc::ServerContext* context, ::grpc::testing::Request* request, ::grpc::ServerAsyncResponseWriter< ::grpc::testing::Response>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestMethodB1( + ::grpc::ServerContext* context, ::grpc::testing::Request* request, + ::grpc::ServerAsyncResponseWriter<::grpc::testing::Response>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; - typedef WithAsyncMethod_MethodB1 AsyncService; + typedef WithAsyncMethod_MethodB1 AsyncService; template class ExperimentalWithCallbackMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: ExperimentalWithCallbackMethod_MethodB1() { - ::grpc::Service::experimental().MarkMethodCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodB1, ::grpc::testing::Request, ::grpc::testing::Response>( - [this](::grpc::ServerContext* context, - const ::grpc::testing::Request* request, - ::grpc::testing::Response* response, - ::grpc::experimental::ServerCallbackRpcController* controller) { + ::grpc::Service::experimental().MarkMethodCallback( + 0, new ::grpc::internal::CallbackUnaryHandler< + ExperimentalWithCallbackMethod_MethodB1, + ::grpc::testing::Request, ::grpc::testing::Response>( + [this](::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* + controller) { this->MethodB1(context, request, response, controller); - }, this)); + }, + this)); } ~ExperimentalWithCallbackMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - virtual void MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + virtual void MethodB1( + ::grpc::ServerContext* context, const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); + } }; - typedef ExperimentalWithCallbackMethod_MethodB1 ExperimentalCallbackService; + typedef ExperimentalWithCallbackMethod_MethodB1 + ExperimentalCallbackService; template class WithGenericMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithGenericMethod_MethodB1() { - ::grpc::Service::MarkMethodGeneric(0); - } + WithGenericMethod_MethodB1() { ::grpc::Service::MarkMethodGeneric(0); } ~WithGenericMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } @@ -791,76 +1219,103 @@ class ServiceB final { template class WithRawMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: - WithRawMethod_MethodB1() { - ::grpc::Service::MarkMethodRaw(0); - } + WithRawMethod_MethodB1() { ::grpc::Service::MarkMethodRaw(0); } ~WithRawMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - void RequestMethodB1(::grpc::ServerContext* context, ::grpc::ByteBuffer* request, ::grpc::ServerAsyncResponseWriter< ::grpc::ByteBuffer>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + void RequestMethodB1( + ::grpc::ServerContext* context, ::grpc::ByteBuffer* request, + ::grpc::ServerAsyncResponseWriter<::grpc::ByteBuffer>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); } }; template class ExperimentalWithRawCallbackMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: ExperimentalWithRawCallbackMethod_MethodB1() { - ::grpc::Service::experimental().MarkMethodRawCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodB1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( - [this](::grpc::ServerContext* context, - const ::grpc::ByteBuffer* request, - ::grpc::ByteBuffer* response, - ::grpc::experimental::ServerCallbackRpcController* controller) { + ::grpc::Service::experimental().MarkMethodRawCallback( + 0, new ::grpc::internal::CallbackUnaryHandler< + ExperimentalWithRawCallbackMethod_MethodB1, + ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this](::grpc::ServerContext* context, + const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* + controller) { this->MethodB1(context, request, response, controller); - }, this)); + }, + this)); } ~ExperimentalWithRawCallbackMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable synchronous version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } - virtual void MethodB1(::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + virtual void MethodB1( + ::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); + } }; template class WithStreamedUnaryMethod_MethodB1 : public BaseClass { private: - void BaseClassMustBeDerivedFromService(const Service *service) {} + void BaseClassMustBeDerivedFromService(const Service* service) {} + public: WithStreamedUnaryMethod_MethodB1() { - ::grpc::Service::MarkMethodStreamed(0, - new ::grpc::internal::StreamedUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>(std::bind(&WithStreamedUnaryMethod_MethodB1::StreamedMethodB1, this, std::placeholders::_1, std::placeholders::_2))); + ::grpc::Service::MarkMethodStreamed( + 0, new ::grpc::internal::StreamedUnaryHandler< + ::grpc::testing::Request, ::grpc::testing::Response>(std::bind( + &WithStreamedUnaryMethod_MethodB1::StreamedMethodB1, + this, std::placeholders::_1, std::placeholders::_2))); } ~WithStreamedUnaryMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); } // disable regular version of this method - ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + ::grpc::Status MethodB1(::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response) override { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } // replace default version of method with streamed unary - virtual ::grpc::Status StreamedMethodB1(::grpc::ServerContext* context, ::grpc::ServerUnaryStreamer< ::grpc::testing::Request,::grpc::testing::Response>* server_unary_streamer) = 0; + virtual ::grpc::Status StreamedMethodB1( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer<::grpc::testing::Request, + ::grpc::testing::Response>* + server_unary_streamer) = 0; }; - typedef WithStreamedUnaryMethod_MethodB1 StreamedUnaryService; + typedef WithStreamedUnaryMethod_MethodB1 StreamedUnaryService; typedef Service SplitStreamedService; - typedef WithStreamedUnaryMethod_MethodB1 StreamedService; + typedef WithStreamedUnaryMethod_MethodB1 StreamedService; }; // ServiceB trailing comment 1 } // namespace testing } // namespace grpc - #endif // GRPC_src_2fproto_2fgrpc_2ftesting_2fcompiler_5ftest_2eproto__INCLUDED diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index 57c3cb87f2..a1fe199b54 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -187,20 +187,20 @@ class ClientCallbackEnd2endTest grpc::string test_string(""); for (int i = 0; i < num_rpcs; i++) { test_string += "Hello world. "; - class Client : public grpc::experimental::ClientBidiReactor { + class Client : public grpc::experimental::ClientBidiReactor { public: Client(ClientCallbackEnd2endTest* test, const grpc::string& method_name, const grpc::string& test_str) { - stream_ = - test->generic_stub_->experimental().PrepareBidiStreamingCall( - &cli_ctx_, method_name, this); + test->generic_stub_->experimental().PrepareBidiStreamingCall( + &cli_ctx_, method_name, this); request_.set_message(test_str); send_buf_ = SerializeToByteBuffer(&request_); - stream_->Write(send_buf_.get()); - stream_->Read(&recv_buf_); - stream_->StartCall(); + StartWrite(send_buf_.get()); + StartRead(&recv_buf_); + StartCall(); } - void OnWriteDone(bool ok) override { stream_->WritesDone(); } + void OnWriteDone(bool ok) override { StartWritesDone(); } void OnReadDone(bool ok) override { EchoResponse response; EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response)); @@ -223,8 +223,6 @@ class ClientCallbackEnd2endTest std::unique_ptr send_buf_; ByteBuffer recv_buf_; ClientContext cli_ctx_; - experimental::ClientCallbackReaderWriter* - stream_; std::mutex mu_; std::condition_variable cv_; bool done_ = false; @@ -330,22 +328,21 @@ TEST_P(ClientCallbackEnd2endTest, RequestStream) { } ResetStub(); - class Client : public grpc::experimental::ClientWriteReactor { + class Client : public grpc::experimental::ClientWriteReactor { public: explicit Client(grpc::testing::EchoTestService::Stub* stub) { context_.set_initial_metadata_corked(true); - stream_ = stub->experimental_async()->RequestStream(&context_, &response_, - this); - stream_->StartCall(); + stub->experimental_async()->RequestStream(&context_, &response_, this); + StartCall(); request_.set_message("Hello server."); - stream_->Write(&request_); + StartWrite(&request_); } void OnWriteDone(bool ok) override { writes_left_--; if (writes_left_ > 1) { - stream_->Write(&request_); + StartWrite(&request_); } else if (writes_left_ == 1) { - stream_->WriteLast(&request_, WriteOptions()); + StartWriteLast(&request_, WriteOptions()); } } void OnDone(Status s) override { @@ -363,7 +360,6 @@ TEST_P(ClientCallbackEnd2endTest, RequestStream) { } private: - ::grpc::experimental::ClientCallbackWriter* stream_; EchoRequest request_; EchoResponse response_; ClientContext context_; @@ -383,14 +379,13 @@ TEST_P(ClientCallbackEnd2endTest, ResponseStream) { } ResetStub(); - class Client : public grpc::experimental::ClientReadReactor { + class Client : public grpc::experimental::ClientReadReactor { public: explicit Client(grpc::testing::EchoTestService::Stub* stub) { request_.set_message("Hello client "); - stream_ = stub->experimental_async()->ResponseStream(&context_, &request_, - this); - stream_->StartCall(); - stream_->Read(&response_); + stub->experimental_async()->ResponseStream(&context_, &request_, this); + StartCall(); + StartRead(&response_); } void OnReadDone(bool ok) override { if (!ok) { @@ -400,7 +395,7 @@ TEST_P(ClientCallbackEnd2endTest, ResponseStream) { EXPECT_EQ(response_.message(), request_.message() + grpc::to_string(reads_complete_)); reads_complete_++; - stream_->Read(&response_); + StartRead(&response_); } } void OnDone(Status s) override { @@ -417,7 +412,6 @@ TEST_P(ClientCallbackEnd2endTest, ResponseStream) { } private: - ::grpc::experimental::ClientCallbackReader* stream_; EchoRequest request_; EchoResponse response_; ClientContext context_; @@ -436,14 +430,15 @@ TEST_P(ClientCallbackEnd2endTest, BidiStream) { return; } ResetStub(); - class Client : public grpc::experimental::ClientBidiReactor { + class Client : public grpc::experimental::ClientBidiReactor { public: explicit Client(grpc::testing::EchoTestService::Stub* stub) { request_.set_message("Hello fren "); - stream_ = stub->experimental_async()->BidiStream(&context_, this); - stream_->StartCall(); - stream_->Read(&response_); - stream_->Write(&request_); + stub->experimental_async()->BidiStream(&context_, this); + StartCall(); + StartRead(&response_); + StartWrite(&request_); } void OnReadDone(bool ok) override { if (!ok) { @@ -452,15 +447,15 @@ TEST_P(ClientCallbackEnd2endTest, BidiStream) { EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend); EXPECT_EQ(response_.message(), request_.message()); reads_complete_++; - stream_->Read(&response_); + StartRead(&response_); } } void OnWriteDone(bool ok) override { EXPECT_TRUE(ok); if (++writes_complete_ == kServerDefaultResponseStreamsToSend) { - stream_->WritesDone(); + StartWritesDone(); } else { - stream_->Write(&request_); + StartWrite(&request_); } } void OnDone(Status s) override { @@ -477,8 +472,6 @@ TEST_P(ClientCallbackEnd2endTest, BidiStream) { } private: - ::grpc::experimental::ClientCallbackReaderWriter* - stream_; EchoRequest request_; EchoResponse response_; ClientContext context_; -- cgit v1.2.3 From 2a0c0d7ad6a1256ef6c0398e1900eca2be077a51 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Mon, 5 Nov 2018 14:39:44 -0800 Subject: Streaming API for callback servers --- include/grpcpp/impl/codegen/byte_buffer.h | 8 +- include/grpcpp/impl/codegen/callback_common.h | 19 +- include/grpcpp/impl/codegen/server_callback.h | 774 ++++++++++++++++++++++++-- include/grpcpp/impl/codegen/server_context.h | 21 +- src/compiler/cpp_generator.cc | 69 ++- src/cpp/server/server_cc.cc | 7 +- src/cpp/server/server_context.cc | 47 +- test/cpp/codegen/compiler_test_golden | 56 +- test/cpp/end2end/end2end_test.cc | 90 ++- test/cpp/end2end/test_service_impl.cc | 314 +++++++++-- test/cpp/end2end/test_service_impl.h | 21 +- 11 files changed, 1261 insertions(+), 165 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index abba5549b8..53ecb53371 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -45,8 +45,10 @@ template class RpcMethodHandler; template class ServerStreamingHandler; -template +template class CallbackUnaryHandler; +template +class CallbackServerStreamingHandler; template class ErrorMethodHandler; template @@ -156,8 +158,10 @@ class ByteBuffer final { friend class internal::RpcMethodHandler; template friend class internal::ServerStreamingHandler; - template + template friend class internal::CallbackUnaryHandler; + template + friend class ::grpc::internal::CallbackServerStreamingHandler; template friend class internal::ErrorMethodHandler; template diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index f7a24204dc..a3c8c41246 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -32,6 +32,8 @@ namespace grpc { namespace internal { /// An exception-safe way of invoking a user-specified callback function +// TODO(vjpai): decide whether it is better for this to take a const lvalue +// parameter or an rvalue parameter, or if it even matters template void CatchingCallback(Func&& func, Args&&... args) { #if GRPC_ALLOW_EXCEPTIONS @@ -45,6 +47,20 @@ void CatchingCallback(Func&& func, Args&&... args) { #endif // GRPC_ALLOW_EXCEPTIONS } +template +ReturnType* CatchingReactorCreator(Func&& func, Args&&... args) { +#if GRPC_ALLOW_EXCEPTIONS + try { + return func(std::forward(args)...); + } catch (...) { + // fail the RPC, don't crash the library + return nullptr; + } +#else // GRPC_ALLOW_EXCEPTIONS + return func(std::forward(args)...); +#endif // GRPC_ALLOW_EXCEPTIONS +} + // The contract on these tags is that they are single-shot. They must be // constructed and then fired at exactly one point. There is no expectation // that they can be reused without reconstruction. @@ -185,8 +201,9 @@ class CallbackWithSuccessTag void* ignored = ops_; // Allow a "false" return value from FinalizeResult to silence the // callback, just as it silences a CQ tag in the async cases + auto* ops = ops_; bool do_callback = ops_->FinalizeResult(&ignored, &ok); - GPR_CODEGEN_ASSERT(ignored == ops_); + GPR_CODEGEN_ASSERT(ignored == ops); if (do_callback) { CatchingCallback(func_, ok); diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h index b866fc16dc..1854f6ef2f 100644 --- a/include/grpcpp/impl/codegen/server_callback.h +++ b/include/grpcpp/impl/codegen/server_callback.h @@ -19,7 +19,9 @@ #ifndef GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H #define GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H +#include #include +#include #include #include @@ -32,19 +34,33 @@ namespace grpc { -// forward declarations +// Declare base class of all reactors as internal namespace internal { -template -class CallbackUnaryHandler; + +class ServerReactor { + public: + virtual ~ServerReactor() = default; + virtual void OnDone() {} + virtual void OnCancel() {} +}; + } // namespace internal namespace experimental { +// Forward declarations +template +class ServerReadReactor; +template +class ServerWriteReactor; +template +class ServerBidiReactor; + // For unary RPCs, the exposed controller class is only an interface // and the actual implementation is an internal class. class ServerCallbackRpcController { public: - virtual ~ServerCallbackRpcController() {} + virtual ~ServerCallbackRpcController() = default; // The method handler must call this function when it is done so that // the library knows to free its resources @@ -55,18 +71,193 @@ class ServerCallbackRpcController { virtual void SendInitialMetadata(std::function) = 0; }; +// NOTE: The actual streaming object classes are provided +// as API only to support mocking. There are no implementations of +// these class interfaces in the API. +template +class ServerCallbackReader { + public: + virtual ~ServerCallbackReader() {} + virtual void Finish(Status s) = 0; + virtual void SendInitialMetadata() = 0; + virtual void Read(Request* msg) = 0; + + protected: + template + void BindReactor(ServerReadReactor* reactor) { + reactor->BindReader(this); + } +}; + +template +class ServerCallbackWriter { + public: + virtual ~ServerCallbackWriter() {} + + virtual void Finish(Status s) = 0; + virtual void SendInitialMetadata() = 0; + virtual void Write(const Response* msg, WriteOptions options) = 0; + virtual void WriteAndFinish(const Response* msg, WriteOptions options, + Status s) { + // Default implementation that can/should be overridden + Write(msg, std::move(options)); + Finish(std::move(s)); + }; + + protected: + template + void BindReactor(ServerWriteReactor* reactor) { + reactor->BindWriter(this); + } +}; + +template +class ServerCallbackReaderWriter { + public: + virtual ~ServerCallbackReaderWriter() {} + + virtual void Finish(Status s) = 0; + virtual void SendInitialMetadata() = 0; + virtual void Read(Request* msg) = 0; + virtual void Write(const Response* msg, WriteOptions options) = 0; + virtual void WriteAndFinish(const Response* msg, WriteOptions options, + Status s) { + // Default implementation that can/should be overridden + Write(msg, std::move(options)); + Finish(std::move(s)); + }; + + protected: + void BindReactor(ServerBidiReactor* reactor) { + reactor->BindStream(this); + } +}; + +// The following classes are reactors that are to be implemented +// by the user, returned as the result of the method handler for +// a callback method, and activated by the call to OnStarted +template +class ServerBidiReactor : public internal::ServerReactor { + public: + ~ServerBidiReactor() = default; + virtual void OnStarted(ServerContext*) {} + virtual void OnSendInitialMetadataDone(bool ok) {} + virtual void OnReadDone(bool ok) {} + virtual void OnWriteDone(bool ok) {} + + void StartSendInitialMetadata() { stream_->SendInitialMetadata(); } + void StartRead(Request* msg) { stream_->Read(msg); } + void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); } + void StartWrite(const Response* msg, WriteOptions options) { + stream_->Write(msg, std::move(options)); + } + void StartWriteAndFinish(const Response* msg, WriteOptions options, + Status s) { + stream_->WriteAndFinish(msg, std::move(options), std::move(s)); + } + void StartWriteLast(const Response* msg, WriteOptions options) { + StartWrite(msg, std::move(options.set_last_message())); + } + void Finish(Status s) { stream_->Finish(std::move(s)); } + + private: + friend class ServerCallbackReaderWriter; + void BindStream(ServerCallbackReaderWriter* stream) { + stream_ = stream; + } + + ServerCallbackReaderWriter* stream_; +}; + +template +class ServerReadReactor : public internal::ServerReactor { + public: + ~ServerReadReactor() = default; + virtual void OnStarted(ServerContext*, Response* resp) {} + virtual void OnSendInitialMetadataDone(bool ok) {} + virtual void OnReadDone(bool ok) {} + + void StartSendInitialMetadata() { reader_->SendInitialMetadata(); } + void StartRead(Request* msg) { reader_->Read(msg); } + void Finish(Status s) { reader_->Finish(std::move(s)); } + + private: + friend class ServerCallbackReader; + void BindReader(ServerCallbackReader* reader) { reader_ = reader; } + + ServerCallbackReader* reader_; +}; + +template +class ServerWriteReactor : public internal::ServerReactor { + public: + ~ServerWriteReactor() = default; + virtual void OnStarted(ServerContext*, const Request* req) {} + virtual void OnSendInitialMetadataDone(bool ok) {} + virtual void OnWriteDone(bool ok) {} + + void StartSendInitialMetadata() { writer_->SendInitialMetadata(); } + void StartWrite(const Response* msg) { StartWrite(msg, WriteOptions()); } + void StartWrite(const Response* msg, WriteOptions options) { + writer_->Write(msg, std::move(options)); + } + void StartWriteAndFinish(const Response* msg, WriteOptions options, + Status s) { + writer_->WriteAndFinish(msg, std::move(options), std::move(s)); + } + void StartWriteLast(const Response* msg, WriteOptions options) { + StartWrite(msg, std::move(options.set_last_message())); + } + void Finish(Status s) { writer_->Finish(std::move(s)); } + + private: + friend class ServerCallbackWriter; + void BindWriter(ServerCallbackWriter* writer) { writer_ = writer; } + + ServerCallbackWriter* writer_; +}; + } // namespace experimental namespace internal { -template +template +class UnimplementedReadReactor + : public experimental::ServerReadReactor { + public: + void OnDone() override { delete this; } + void OnStarted(ServerContext*, Response*) override { + this->Finish(Status(StatusCode::UNIMPLEMENTED, "")); + } +}; + +template +class UnimplementedWriteReactor + : public experimental::ServerWriteReactor { + public: + void OnDone() override { delete this; } + void OnStarted(ServerContext*, const Request*) override { + this->Finish(Status(StatusCode::UNIMPLEMENTED, "")); + } +}; + +template +class UnimplementedBidiReactor + : public experimental::ServerBidiReactor { + public: + void OnDone() override { delete this; } + void OnStarted(ServerContext*) override { + this->Finish(Status(StatusCode::UNIMPLEMENTED, "")); + } +}; + +template class CallbackUnaryHandler : public MethodHandler { public: CallbackUnaryHandler( std::function - func, - ServiceType* service) + func) : func_(func) {} void RunHandler(const HandlerParameter& param) final { // Arena allocate a controller structure (that includes request/response) @@ -81,9 +272,8 @@ class CallbackUnaryHandler : public MethodHandler { if (status.ok()) { // Call the actual function handler and expect the user to call finish - CatchingCallback(std::move(func_), param.server_context, - controller->request(), controller->response(), - controller); + CatchingCallback(func_, param.server_context, controller->request(), + controller->response(), controller); } else { // if deserialization failed, we need to fail the call controller->Finish(status); @@ -117,79 +307,579 @@ class CallbackUnaryHandler : public MethodHandler { : public experimental::ServerCallbackRpcController { public: void Finish(Status s) override { - finish_tag_.Set( - call_.call(), - [this](bool) { - grpc_call* call = call_.call(); - auto call_requester = std::move(call_requester_); - this->~ServerCallbackRpcControllerImpl(); // explicitly call - // destructor - g_core_codegen_interface->grpc_call_unref(call); - call_requester(); - }, - &finish_buf_); + finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); }, + &finish_ops_); if (!ctx_->sent_initial_metadata_) { - finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { - finish_buf_.set_compression_level(ctx_->compression_level()); + finish_ops_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; } // The response is dropped if the status is not OK. if (s.ok()) { - finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, - finish_buf_.SendMessage(resp_)); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, + finish_ops_.SendMessage(resp_)); } else { - finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, s); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s); } - finish_buf_.set_core_cq_tag(&finish_tag_); - call_.PerformOps(&finish_buf_); + finish_ops_.set_core_cq_tag(&finish_tag_); + call_.PerformOps(&finish_ops_); } void SendInitialMetadata(std::function f) override { GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); - - meta_tag_.Set(call_.call(), std::move(f), &meta_buf_); - meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_, + callbacks_outstanding_++; + // TODO(vjpai): Consider taking f as a move-capture if we adopt C++14 + // and if performance of this operation matters + meta_tag_.Set(call_.call(), + [this, f](bool ok) { + f(ok); + MaybeDone(); + }, + &meta_ops_); + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, ctx_->initial_metadata_flags()); if (ctx_->compression_level_set()) { - meta_buf_.set_compression_level(ctx_->compression_level()); + meta_ops_.set_compression_level(ctx_->compression_level()); } ctx_->sent_initial_metadata_ = true; - meta_buf_.set_core_cq_tag(&meta_tag_); - call_.PerformOps(&meta_buf_); + meta_ops_.set_core_cq_tag(&meta_tag_); + call_.PerformOps(&meta_ops_); } private: - template - friend class CallbackUnaryHandler; + friend class CallbackUnaryHandler; ServerCallbackRpcControllerImpl(ServerContext* ctx, Call* call, - RequestType* req, + const RequestType* req, std::function call_requester) : ctx_(ctx), call_(*call), req_(req), - call_requester_(std::move(call_requester)) {} + call_requester_(std::move(call_requester)) { + ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, nullptr); + } ~ServerCallbackRpcControllerImpl() { req_->~RequestType(); } - RequestType* request() { return req_; } + const RequestType* request() { return req_; } ResponseType* response() { return &resp_; } - CallOpSet meta_buf_; + void MaybeDone() { + if (--callbacks_outstanding_ == 0) { + grpc_call* call = call_.call(); + auto call_requester = std::move(call_requester_); + this->~ServerCallbackRpcControllerImpl(); // explicitly call destructor + g_core_codegen_interface->grpc_call_unref(call); + call_requester(); + } + } + + CallOpSet meta_ops_; CallbackWithSuccessTag meta_tag_; CallOpSet - finish_buf_; + finish_ops_; CallbackWithSuccessTag finish_tag_; ServerContext* ctx_; Call call_; - RequestType* req_; + const RequestType* req_; ResponseType resp_; std::function call_requester_; + std::atomic_int callbacks_outstanding_{ + 2}; // reserve for Finish and CompletionOp + }; +}; + +template +class CallbackClientStreamingHandler : public MethodHandler { + public: + CallbackClientStreamingHandler( + std::function< + experimental::ServerReadReactor*()> + func) + : func_(std::move(func)) {} + void RunHandler(const HandlerParameter& param) final { + // Arena allocate a reader structure (that includes response) + g_core_codegen_interface->grpc_call_ref(param.call->call()); + + experimental::ServerReadReactor* reactor = + param.status.ok() + ? CatchingReactorCreator< + experimental::ServerReadReactor>( + func_) + : nullptr; + + if (reactor == nullptr) { + // if deserialization or reactor creator failed, we need to fail the call + reactor = new UnimplementedReadReactor; + } + + auto* reader = new (g_core_codegen_interface->grpc_call_arena_alloc( + param.call->call(), sizeof(ServerCallbackReaderImpl))) + ServerCallbackReaderImpl(param.server_context, param.call, + std::move(param.call_requester), reactor); + + reader->BindReactor(reactor); + reactor->OnStarted(param.server_context, reader->response()); + reader->MaybeDone(); + } + + private: + std::function*()> + func_; + + class ServerCallbackReaderImpl + : public experimental::ServerCallbackReader { + public: + void Finish(Status s) override { + finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); }, + &finish_ops_); + if (!ctx_->sent_initial_metadata_) { + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + finish_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + // The response is dropped if the status is not OK. + if (s.ok()) { + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, + finish_ops_.SendMessage(resp_)); + } else { + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s); + } + finish_ops_.set_core_cq_tag(&finish_tag_); + call_.PerformOps(&finish_ops_); + } + + void SendInitialMetadata() override { + GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); + callbacks_outstanding_++; + meta_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnSendInitialMetadataDone(ok); + MaybeDone(); + }, + &meta_ops_); + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + meta_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + meta_ops_.set_core_cq_tag(&meta_tag_); + call_.PerformOps(&meta_ops_); + } + + void Read(RequestType* req) override { + callbacks_outstanding_++; + read_ops_.RecvMessage(req); + call_.PerformOps(&read_ops_); + } + + private: + friend class CallbackClientStreamingHandler; + + ServerCallbackReaderImpl( + ServerContext* ctx, Call* call, std::function call_requester, + experimental::ServerReadReactor* reactor) + : ctx_(ctx), + call_(*call), + call_requester_(std::move(call_requester)), + reactor_(reactor) { + ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor); + read_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeDone(); + }, + &read_ops_); + read_ops_.set_core_cq_tag(&read_tag_); + } + + ~ServerCallbackReaderImpl() {} + + ResponseType* response() { return &resp_; } + + void MaybeDone() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(); + grpc_call* call = call_.call(); + auto call_requester = std::move(call_requester_); + this->~ServerCallbackReaderImpl(); // explicitly call destructor + g_core_codegen_interface->grpc_call_unref(call); + call_requester(); + } + } + + CallOpSet meta_ops_; + CallbackWithSuccessTag meta_tag_; + CallOpSet + finish_ops_; + CallbackWithSuccessTag finish_tag_; + CallOpSet> read_ops_; + CallbackWithSuccessTag read_tag_; + + ServerContext* ctx_; + Call call_; + ResponseType resp_; + std::function call_requester_; + experimental::ServerReadReactor* reactor_; + std::atomic_int callbacks_outstanding_{ + 3}; // reserve for OnStarted, Finish, and CompletionOp + }; +}; + +template +class CallbackServerStreamingHandler : public MethodHandler { + public: + CallbackServerStreamingHandler( + std::function< + experimental::ServerWriteReactor*()> + func) + : func_(std::move(func)) {} + void RunHandler(const HandlerParameter& param) final { + // Arena allocate a writer structure + g_core_codegen_interface->grpc_call_ref(param.call->call()); + + experimental::ServerWriteReactor* reactor = + param.status.ok() + ? CatchingReactorCreator< + experimental::ServerWriteReactor>( + func_) + : nullptr; + + if (reactor == nullptr) { + // if deserialization or reactor creator failed, we need to fail the call + reactor = new UnimplementedWriteReactor; + } + + auto* writer = new (g_core_codegen_interface->grpc_call_arena_alloc( + param.call->call(), sizeof(ServerCallbackWriterImpl))) + ServerCallbackWriterImpl(param.server_context, param.call, + static_cast(param.request), + std::move(param.call_requester), reactor); + writer->BindReactor(reactor); + reactor->OnStarted(param.server_context, writer->request()); + writer->MaybeDone(); + } + + void* Deserialize(grpc_call* call, grpc_byte_buffer* req, + Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new (g_core_codegen_interface->grpc_call_arena_alloc( + call, sizeof(RequestType))) RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + request->~RequestType(); + return nullptr; + } + + private: + std::function*()> + func_; + + class ServerCallbackWriterImpl + : public experimental::ServerCallbackWriter { + public: + void Finish(Status s) override { + finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); }, + &finish_ops_); + finish_ops_.set_core_cq_tag(&finish_tag_); + + if (!ctx_->sent_initial_metadata_) { + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + finish_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s); + call_.PerformOps(&finish_ops_); + } + + void SendInitialMetadata() override { + GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); + callbacks_outstanding_++; + meta_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnSendInitialMetadataDone(ok); + MaybeDone(); + }, + &meta_ops_); + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + meta_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + meta_ops_.set_core_cq_tag(&meta_tag_); + call_.PerformOps(&meta_ops_); + } + + void Write(const ResponseType* resp, WriteOptions options) override { + callbacks_outstanding_++; + if (options.is_last_message()) { + options.set_buffer_hint(); + } + if (!ctx_->sent_initial_metadata_) { + write_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + write_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok()); + call_.PerformOps(&write_ops_); + } + + void WriteAndFinish(const ResponseType* resp, WriteOptions options, + Status s) override { + // This combines the write into the finish callback + // Don't send any message if the status is bad + if (s.ok()) { + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok()); + } + Finish(std::move(s)); + } + + private: + friend class CallbackServerStreamingHandler; + + ServerCallbackWriterImpl( + ServerContext* ctx, Call* call, const RequestType* req, + std::function call_requester, + experimental::ServerWriteReactor* reactor) + : ctx_(ctx), + call_(*call), + req_(req), + call_requester_(std::move(call_requester)), + reactor_(reactor) { + ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor); + write_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeDone(); + }, + &write_ops_); + write_ops_.set_core_cq_tag(&write_tag_); + } + ~ServerCallbackWriterImpl() { req_->~RequestType(); } + + const RequestType* request() { return req_; } + + void MaybeDone() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(); + grpc_call* call = call_.call(); + auto call_requester = std::move(call_requester_); + this->~ServerCallbackWriterImpl(); // explicitly call destructor + g_core_codegen_interface->grpc_call_unref(call); + call_requester(); + } + } + + CallOpSet meta_ops_; + CallbackWithSuccessTag meta_tag_; + CallOpSet + finish_ops_; + CallbackWithSuccessTag finish_tag_; + CallOpSet write_ops_; + CallbackWithSuccessTag write_tag_; + + ServerContext* ctx_; + Call call_; + const RequestType* req_; + std::function call_requester_; + experimental::ServerWriteReactor* reactor_; + std::atomic_int callbacks_outstanding_{ + 3}; // reserve for OnStarted, Finish, and CompletionOp + }; +}; + +template +class CallbackBidiHandler : public MethodHandler { + public: + CallbackBidiHandler( + std::function< + experimental::ServerBidiReactor*()> + func) + : func_(std::move(func)) {} + void RunHandler(const HandlerParameter& param) final { + g_core_codegen_interface->grpc_call_ref(param.call->call()); + + experimental::ServerBidiReactor* reactor = + param.status.ok() + ? CatchingReactorCreator< + experimental::ServerBidiReactor>( + func_) + : nullptr; + + if (reactor == nullptr) { + // if deserialization or reactor creator failed, we need to fail the call + reactor = new UnimplementedBidiReactor; + } + + auto* stream = new (g_core_codegen_interface->grpc_call_arena_alloc( + param.call->call(), sizeof(ServerCallbackReaderWriterImpl))) + ServerCallbackReaderWriterImpl(param.server_context, param.call, + std::move(param.call_requester), + reactor); + + stream->BindReactor(reactor); + reactor->OnStarted(param.server_context); + stream->MaybeDone(); + } + + private: + std::function*()> + func_; + + class ServerCallbackReaderWriterImpl + : public experimental::ServerCallbackReaderWriter { + public: + void Finish(Status s) override { + finish_tag_.Set(call_.call(), [this](bool) { MaybeDone(); }, + &finish_ops_); + finish_ops_.set_core_cq_tag(&finish_tag_); + + if (!ctx_->sent_initial_metadata_) { + finish_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + finish_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s); + call_.PerformOps(&finish_ops_); + } + + void SendInitialMetadata() override { + GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); + callbacks_outstanding_++; + meta_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnSendInitialMetadataDone(ok); + MaybeDone(); + }, + &meta_ops_); + meta_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + meta_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + meta_ops_.set_core_cq_tag(&meta_tag_); + call_.PerformOps(&meta_ops_); + } + + void Write(const ResponseType* resp, WriteOptions options) override { + callbacks_outstanding_++; + if (options.is_last_message()) { + options.set_buffer_hint(); + } + if (!ctx_->sent_initial_metadata_) { + write_ops_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + write_ops_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok()); + call_.PerformOps(&write_ops_); + } + + void WriteAndFinish(const ResponseType* resp, WriteOptions options, + Status s) override { + // Don't send any message if the status is bad + if (s.ok()) { + // TODO(vjpai): don't assert + GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok()); + } + Finish(std::move(s)); + } + + void Read(RequestType* req) override { + callbacks_outstanding_++; + read_ops_.RecvMessage(req); + call_.PerformOps(&read_ops_); + } + + private: + friend class CallbackBidiHandler; + + ServerCallbackReaderWriterImpl( + ServerContext* ctx, Call* call, std::function call_requester, + experimental::ServerBidiReactor* reactor) + : ctx_(ctx), + call_(*call), + call_requester_(std::move(call_requester)), + reactor_(reactor) { + ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, reactor); + write_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnWriteDone(ok); + MaybeDone(); + }, + &write_ops_); + write_ops_.set_core_cq_tag(&write_tag_); + read_tag_.Set(call_.call(), + [this](bool ok) { + reactor_->OnReadDone(ok); + MaybeDone(); + }, + &read_ops_); + read_ops_.set_core_cq_tag(&read_tag_); + } + ~ServerCallbackReaderWriterImpl() {} + + void MaybeDone() { + if (--callbacks_outstanding_ == 0) { + reactor_->OnDone(); + grpc_call* call = call_.call(); + auto call_requester = std::move(call_requester_); + this->~ServerCallbackReaderWriterImpl(); // explicitly call destructor + g_core_codegen_interface->grpc_call_unref(call); + call_requester(); + } + } + + CallOpSet meta_ops_; + CallbackWithSuccessTag meta_tag_; + CallOpSet + finish_ops_; + CallbackWithSuccessTag finish_tag_; + CallOpSet write_ops_; + CallbackWithSuccessTag write_tag_; + CallOpSet> read_ops_; + CallbackWithSuccessTag read_tag_; + + ServerContext* ctx_; + Call call_; + std::function call_requester_; + experimental::ServerBidiReactor* reactor_; + std::atomic_int callbacks_outstanding_{ + 3}; // reserve for OnStarted, Finish, and CompletionOp }; }; diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index 82ee862f61..4a5f9e2dd9 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -66,13 +66,20 @@ template class ServerStreamingHandler; template class BidiStreamingHandler; -template +template class CallbackUnaryHandler; +template +class CallbackClientStreamingHandler; +template +class CallbackServerStreamingHandler; +template +class CallbackBidiHandler; template class TemplatedBidiStreamingHandler; template class ErrorMethodHandler; class Call; +class ServerReactor; } // namespace internal class CompletionQueue; @@ -270,8 +277,14 @@ class ServerContext { friend class ::grpc::internal::ServerStreamingHandler; template friend class ::grpc::internal::TemplatedBidiStreamingHandler; - template + template friend class ::grpc::internal::CallbackUnaryHandler; + template + friend class ::grpc::internal::CallbackClientStreamingHandler; + template + friend class ::grpc::internal::CallbackServerStreamingHandler; + template + friend class ::grpc::internal::CallbackBidiHandler; template friend class internal::ErrorMethodHandler; friend class ::grpc::ClientContext; @@ -282,7 +295,9 @@ class ServerContext { class CompletionOp; - void BeginCompletionOp(internal::Call* call, bool callback); + void BeginCompletionOp(internal::Call* call, + std::function callback, + internal::ServerReactor* reactor); /// Return the tag queued by BeginCompletionOp() internal::CompletionQueueTag* GetCompletionOpTag(); diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index a368b47f01..b004687250 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -889,6 +889,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerReadReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedReadReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } else if (ServerOnlyStreaming(method)) { printer->Print( *vars, @@ -900,6 +905,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerWriteReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedWriteReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } else if (method->BidiStreaming()) { printer->Print( *vars, @@ -911,6 +921,11 @@ void PrintHeaderServerCallbackMethodsHelper( " abort();\n" " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" "}\n"); + printer->Print(*vars, + "virtual ::grpc::experimental::ServerBidiReactor< " + "$RealRequest$, $RealResponse$>* $Method$() {\n" + " return new ::grpc::internal::UnimplementedBidiReactor<\n" + " $RealRequest$, $RealResponse$>;}\n"); } } @@ -939,22 +954,36 @@ void PrintHeaderServerMethodCallback( *vars, " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" " new ::grpc::internal::CallbackUnaryHandler< " - "ExperimentalWithCallbackMethod_$Method$, $RealRequest$, " - "$RealResponse$>(\n" + "$RealRequest$, $RealResponse$>(\n" " [this](::grpc::ServerContext* context,\n" " const $RealRequest$* request,\n" " $RealResponse$* response,\n" " ::grpc::experimental::ServerCallbackRpcController* " "controller) {\n" - " this->$" + " return this->$" "Method$(context, request, response, controller);\n" - " }, this));\n"); + " }));\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } printer->Print(*vars, "}\n"); printer->Print(*vars, @@ -991,8 +1020,7 @@ void PrintHeaderServerMethodRawCallback( *vars, " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" " new ::grpc::internal::CallbackUnaryHandler< " - "ExperimentalWithRawCallbackMethod_$Method$, $RealRequest$, " - "$RealResponse$>(\n" + "$RealRequest$, $RealResponse$>(\n" " [this](::grpc::ServerContext* context,\n" " const $RealRequest$* request,\n" " $RealResponse$* response,\n" @@ -1000,13 +1028,28 @@ void PrintHeaderServerMethodRawCallback( "controller) {\n" " this->$" "Method$(context, request, response, controller);\n" - " }, this));\n"); + " }));\n"); } else if (ClientOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackClientStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (ServerOnlyStreaming(method)) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackServerStreamingHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } else if (method->BidiStreaming()) { - // TODO(vjpai): Add in code generation for all streaming methods + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackBidiHandler< " + "$RealRequest$, $RealResponse$>(\n" + " [this] { return this->$Method$(); }));\n"); } printer->Print(*vars, "}\n"); printer->Print(*vars, diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 0a51cf5626..69af43a656 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -291,7 +291,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void ContinueRunAfterInterception() { { - ctx_.BeginCompletionOp(&call_, false); + ctx_.BeginCompletionOp(&call_, nullptr, nullptr); global_callbacks_->PreSynchronousRequest(&ctx_); auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get(); @@ -456,7 +456,6 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { } } void ContinueRunAfterInterception() { - req_->ctx_.BeginCompletionOp(call_, true); req_->method_->handler()->RunHandler( internal::MethodHandler::HandlerParameter( call_, &req_->ctx_, req_->request_, req_->request_status_, @@ -1018,7 +1017,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, } } if (*status && call_) { - context_->BeginCompletionOp(&call_wrapper_, false); + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); } *tag = tag_; if (delete_on_finalize_) { @@ -1029,7 +1028,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, void ServerInterface::BaseAsyncRequest:: ContinueFinalizeResultAfterInterception() { - context_->BeginCompletionOp(&call_wrapper_, false); + context_->BeginCompletionOp(&call_wrapper_, nullptr, nullptr); // Queue a tag which will be returned immediately grpc_core::ExecCtx exec_ctx; grpc_cq_begin_op(notification_cq_->cq(), this); diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 9c01f896e6..1b524bc3e8 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -17,6 +17,7 @@ */ #include +#include #include #include @@ -41,8 +42,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { public: // initial refs: one in the server context, one in the cq // must ref the call before calling constructor and after deleting this - CompletionOp(internal::Call* call) + CompletionOp(internal::Call* call, internal::ServerReactor* reactor) : call_(*call), + reactor_(reactor), has_tag_(false), tag_(nullptr), core_cq_tag_(this), @@ -124,9 +126,9 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { return; } /* Start a dummy op so that we can return the tag */ - GPR_CODEGEN_ASSERT(GRPC_CALL_OK == - g_core_codegen_interface->grpc_call_start_batch( - call_.call(), nullptr, 0, this, nullptr)); + GPR_CODEGEN_ASSERT( + GRPC_CALL_OK == + grpc_call_start_batch(call_.call(), nullptr, 0, core_cq_tag_, nullptr)); } private: @@ -136,13 +138,14 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { } internal::Call call_; + internal::ServerReactor* reactor_; bool has_tag_; void* tag_; void* core_cq_tag_; std::mutex mu_; int refs_; bool finalized_; - int cancelled_; + int cancelled_; // This is an int (not bool) because it is passed to core bool done_intercepting_; internal::InterceptorBatchMethodsImpl interceptor_methods_; }; @@ -190,7 +193,16 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { } finalized_ = true; - if (!*status) cancelled_ = 1; + // If for some reason the incoming status is false, mark that as a + // cancellation. + // TODO(vjpai): does this ever happen? + if (!*status) { + cancelled_ = 1; + } + + if (cancelled_ && (reactor_ != nullptr)) { + reactor_->OnCancel(); + } /* Release the lock since we are going to be running through interceptors now */ lock.unlock(); @@ -251,21 +263,25 @@ void ServerContext::Clear() { initial_metadata_.clear(); trailing_metadata_.clear(); client_metadata_.Reset(); - if (call_) { - grpc_call_unref(call_); - } if (completion_op_) { completion_op_->Unref(); + completion_op_ = nullptr; completion_tag_.Clear(); } if (rpc_info_) { rpc_info_->Unref(); + rpc_info_ = nullptr; + } + if (call_) { + auto* call = call_; + call_ = nullptr; + grpc_call_unref(call); } - // Don't need to clear out call_, completion_op_, or rpc_info_ because this is - // either called from destructor or just before Setup } -void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) { +void ServerContext::BeginCompletionOp(internal::Call* call, + std::function callback, + internal::ServerReactor* reactor) { GPR_ASSERT(!completion_op_); if (rpc_info_) { rpc_info_->Ref(); @@ -273,10 +289,11 @@ void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) { grpc_call_ref(call->call()); completion_op_ = new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) - CompletionOp(call); - if (callback) { - completion_tag_.Set(call->call(), nullptr, completion_op_); + CompletionOp(call, reactor); + if (callback != nullptr) { + completion_tag_.Set(call->call(), std::move(callback), completion_op_); completion_op_->set_core_cq_tag(&completion_tag_); + completion_op_->set_tag(completion_op_); } else if (has_notify_when_done_tag_) { completion_op_->set_tag(async_notify_when_done_tag_); } diff --git a/test/cpp/codegen/compiler_test_golden b/test/cpp/codegen/compiler_test_golden index 5f0eb6c35c..1871e1375e 100644 --- a/test/cpp/codegen/compiler_test_golden +++ b/test/cpp/codegen/compiler_test_golden @@ -322,13 +322,13 @@ class ServiceA final { public: ExperimentalWithCallbackMethod_MethodA1() { ::grpc::Service::experimental().MarkMethodCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodA1, ::grpc::testing::Request, ::grpc::testing::Response>( + new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>( [this](::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { - this->MethodA1(context, request, response, controller); - }, this)); + return this->MethodA1(context, request, response, controller); + })); } ~ExperimentalWithCallbackMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); @@ -346,6 +346,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithCallbackMethod_MethodA2() { + ::grpc::Service::experimental().MarkMethodCallback(1, + new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>( + [this] { return this->MethodA2(); })); } ~ExperimentalWithCallbackMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); @@ -355,6 +358,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerReadReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA2() { + return new ::grpc::internal::UnimplementedReadReactor< + ::grpc::testing::Request, ::grpc::testing::Response>;} }; template class ExperimentalWithCallbackMethod_MethodA3 : public BaseClass { @@ -362,6 +368,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithCallbackMethod_MethodA3() { + ::grpc::Service::experimental().MarkMethodCallback(2, + new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::testing::Request, ::grpc::testing::Response>( + [this] { return this->MethodA3(); })); } ~ExperimentalWithCallbackMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); @@ -371,6 +380,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerWriteReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA3() { + return new ::grpc::internal::UnimplementedWriteReactor< + ::grpc::testing::Request, ::grpc::testing::Response>;} }; template class ExperimentalWithCallbackMethod_MethodA4 : public BaseClass { @@ -378,6 +390,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithCallbackMethod_MethodA4() { + ::grpc::Service::experimental().MarkMethodCallback(3, + new ::grpc::internal::CallbackBidiHandler< ::grpc::testing::Request, ::grpc::testing::Response>( + [this] { return this->MethodA4(); })); } ~ExperimentalWithCallbackMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); @@ -387,6 +402,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerBidiReactor< ::grpc::testing::Request, ::grpc::testing::Response>* MethodA4() { + return new ::grpc::internal::UnimplementedBidiReactor< + ::grpc::testing::Request, ::grpc::testing::Response>;} }; typedef ExperimentalWithCallbackMethod_MethodA1 > > > ExperimentalCallbackService; template @@ -544,13 +562,13 @@ class ServiceA final { public: ExperimentalWithRawCallbackMethod_MethodA1() { ::grpc::Service::experimental().MarkMethodRawCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodA1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( [this](::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { this->MethodA1(context, request, response, controller); - }, this)); + })); } ~ExperimentalWithRawCallbackMethod_MethodA1() override { BaseClassMustBeDerivedFromService(this); @@ -568,6 +586,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithRawCallbackMethod_MethodA2() { + ::grpc::Service::experimental().MarkMethodRawCallback(1, + new ::grpc::internal::CallbackClientStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this] { return this->MethodA2(); })); } ~ExperimentalWithRawCallbackMethod_MethodA2() override { BaseClassMustBeDerivedFromService(this); @@ -577,6 +598,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerReadReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA2() { + return new ::grpc::internal::UnimplementedReadReactor< + ::grpc::ByteBuffer, ::grpc::ByteBuffer>;} }; template class ExperimentalWithRawCallbackMethod_MethodA3 : public BaseClass { @@ -584,6 +608,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithRawCallbackMethod_MethodA3() { + ::grpc::Service::experimental().MarkMethodRawCallback(2, + new ::grpc::internal::CallbackServerStreamingHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this] { return this->MethodA3(); })); } ~ExperimentalWithRawCallbackMethod_MethodA3() override { BaseClassMustBeDerivedFromService(this); @@ -593,6 +620,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerWriteReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA3() { + return new ::grpc::internal::UnimplementedWriteReactor< + ::grpc::ByteBuffer, ::grpc::ByteBuffer>;} }; template class ExperimentalWithRawCallbackMethod_MethodA4 : public BaseClass { @@ -600,6 +630,9 @@ class ServiceA final { void BaseClassMustBeDerivedFromService(const Service *service) {} public: ExperimentalWithRawCallbackMethod_MethodA4() { + ::grpc::Service::experimental().MarkMethodRawCallback(3, + new ::grpc::internal::CallbackBidiHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this] { return this->MethodA4(); })); } ~ExperimentalWithRawCallbackMethod_MethodA4() override { BaseClassMustBeDerivedFromService(this); @@ -609,6 +642,9 @@ class ServiceA final { abort(); return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); } + virtual ::grpc::experimental::ServerBidiReactor< ::grpc::ByteBuffer, ::grpc::ByteBuffer>* MethodA4() { + return new ::grpc::internal::UnimplementedBidiReactor< + ::grpc::ByteBuffer, ::grpc::ByteBuffer>;} }; template class WithStreamedUnaryMethod_MethodA1 : public BaseClass { @@ -752,13 +788,13 @@ class ServiceB final { public: ExperimentalWithCallbackMethod_MethodB1() { ::grpc::Service::experimental().MarkMethodCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodB1, ::grpc::testing::Request, ::grpc::testing::Response>( + new ::grpc::internal::CallbackUnaryHandler< ::grpc::testing::Request, ::grpc::testing::Response>( [this](::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { - this->MethodB1(context, request, response, controller); - }, this)); + return this->MethodB1(context, request, response, controller); + })); } ~ExperimentalWithCallbackMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); @@ -815,13 +851,13 @@ class ServiceB final { public: ExperimentalWithRawCallbackMethod_MethodB1() { ::grpc::Service::experimental().MarkMethodRawCallback(0, - new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodB1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + new ::grpc::internal::CallbackUnaryHandler< ::grpc::ByteBuffer, ::grpc::ByteBuffer>( [this](::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { this->MethodB1(context, request, response, controller); - }, this)); + })); } ~ExperimentalWithRawCallbackMethod_MethodB1() override { BaseClassMustBeDerivedFromService(this); diff --git a/test/cpp/end2end/end2end_test.cc b/test/cpp/end2end/end2end_test.cc index 03291e1785..4d37bae217 100644 --- a/test/cpp/end2end/end2end_test.cc +++ b/test/cpp/end2end/end2end_test.cc @@ -196,16 +196,18 @@ class TestServiceImplDupPkg class TestScenario { public: TestScenario(bool interceptors, bool proxy, bool inproc_stub, - const grpc::string& creds_type) + const grpc::string& creds_type, bool use_callback_server) : use_interceptors(interceptors), use_proxy(proxy), inproc(inproc_stub), - credentials_type(creds_type) {} + credentials_type(creds_type), + callback_server(use_callback_server) {} void Log() const; bool use_interceptors; bool use_proxy; bool inproc; const grpc::string credentials_type; + bool callback_server; }; static std::ostream& operator<<(std::ostream& out, @@ -214,6 +216,8 @@ static std::ostream& operator<<(std::ostream& out, << (scenario.use_interceptors ? "true" : "false") << ", use_proxy=" << (scenario.use_proxy ? "true" : "false") << ", inproc=" << (scenario.inproc ? "true" : "false") + << ", server_type=" + << (scenario.callback_server ? "callback" : "sync") << ", credentials='" << scenario.credentials_type << "'}"; } @@ -280,7 +284,11 @@ class End2endTest : public ::testing::TestWithParam { builder.experimental().SetInterceptorCreators(std::move(creators)); } builder.AddListeningPort(server_address_.str(), server_creds); - builder.RegisterService(&service_); + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } builder.RegisterService("foo.test.youtube.com", &special_service_); builder.RegisterService(&dup_pkg_service_); @@ -362,6 +370,7 @@ class End2endTest : public ::testing::TestWithParam { std::ostringstream server_address_; const int kMaxMessageSize_; TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; TestServiceImpl special_service_; TestServiceImplDupPkg dup_pkg_service_; grpc::string user_agent_prefix_; @@ -1016,7 +1025,8 @@ TEST_P(End2endTest, DiffPackageServices) { EXPECT_TRUE(s.ok()); } -void CancelRpc(ClientContext* context, int delay_us, TestServiceImpl* service) { +template +void CancelRpc(ClientContext* context, int delay_us, ServiceType* service) { gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_micros(delay_us, GPR_TIMESPAN))); while (!service->signal_client()) { @@ -1446,7 +1456,24 @@ TEST_P(ProxyEnd2endTest, ClientCancelsRpc) { request.mutable_param()->set_client_cancel_after_us(kCancelDelayUs); ClientContext context; - std::thread cancel_thread(CancelRpc, &context, kCancelDelayUs, &service_); + std::thread cancel_thread; + if (!GetParam().callback_server) { + cancel_thread = std::thread( + [&context, this](int delay) { CancelRpc(&context, delay, &service_); }, + kCancelDelayUs); + // Note: the unusual pattern above (and below) is caused by a conflict + // between two sets of compiler expectations. clang allows const to be + // captured without mention, so there is no need to capture kCancelDelayUs + // (and indeed clang-tidy complains if you do so). OTOH, a Windows compiler + // in our tests requires an explicit capture even for const. We square this + // circle by passing the const value in as an argument to the lambda. + } else { + cancel_thread = std::thread( + [&context, this](int delay) { + CancelRpc(&context, delay, &callback_service_); + }, + kCancelDelayUs); + } Status s = stub_->Echo(&context, request, &response); cancel_thread.join(); EXPECT_EQ(StatusCode::CANCELLED, s.error_code()); @@ -1838,10 +1865,12 @@ TEST_P(ResourceQuotaEnd2endTest, SimpleRequest) { EXPECT_TRUE(s.ok()); } +// TODO(vjpai): refactor arguments into a struct if it makes sense std::vector CreateTestScenarios(bool use_proxy, bool test_insecure, bool test_secure, - bool test_inproc) { + bool test_inproc, + bool test_callback_server) { std::vector scenarios; std::vector credentials_types; if (test_secure) { @@ -1857,41 +1886,48 @@ std::vector CreateTestScenarios(bool use_proxy, if (test_insecure && insec_ok()) { credentials_types.push_back(kInsecureCredentialsType); } + + // For now test callback server only with inproc GPR_ASSERT(!credentials_types.empty()); for (const auto& cred : credentials_types) { - scenarios.emplace_back(false, false, false, cred); - scenarios.emplace_back(true, false, false, cred); + scenarios.emplace_back(false, false, false, cred, false); + scenarios.emplace_back(true, false, false, cred, false); if (use_proxy) { - scenarios.emplace_back(false, true, false, cred); - scenarios.emplace_back(true, true, false, cred); + scenarios.emplace_back(false, true, false, cred, false); + scenarios.emplace_back(true, true, false, cred, false); } } if (test_inproc && insec_ok()) { - scenarios.emplace_back(false, false, true, kInsecureCredentialsType); - scenarios.emplace_back(true, false, true, kInsecureCredentialsType); + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, false); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, false); + if (test_callback_server) { + scenarios.emplace_back(false, false, true, kInsecureCredentialsType, + true); + scenarios.emplace_back(true, false, true, kInsecureCredentialsType, true); + } } return scenarios; } -INSTANTIATE_TEST_CASE_P(End2end, End2endTest, - ::testing::ValuesIn(CreateTestScenarios(false, true, - true, true))); +INSTANTIATE_TEST_CASE_P( + End2end, End2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); -INSTANTIATE_TEST_CASE_P(End2endServerTryCancel, End2endServerTryCancelTest, - ::testing::ValuesIn(CreateTestScenarios(false, true, - true, true))); +INSTANTIATE_TEST_CASE_P( + End2endServerTryCancel, End2endServerTryCancelTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, true))); -INSTANTIATE_TEST_CASE_P(ProxyEnd2end, ProxyEnd2endTest, - ::testing::ValuesIn(CreateTestScenarios(true, true, - true, true))); +INSTANTIATE_TEST_CASE_P( + ProxyEnd2end, ProxyEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(true, true, true, true, false))); -INSTANTIATE_TEST_CASE_P(SecureEnd2end, SecureEnd2endTest, - ::testing::ValuesIn(CreateTestScenarios(false, false, - true, false))); +INSTANTIATE_TEST_CASE_P( + SecureEnd2end, SecureEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, false, true, false, true))); -INSTANTIATE_TEST_CASE_P(ResourceQuotaEnd2end, ResourceQuotaEnd2endTest, - ::testing::ValuesIn(CreateTestScenarios(false, true, - true, true))); +INSTANTIATE_TEST_CASE_P( + ResourceQuotaEnd2end, ResourceQuotaEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios(false, true, true, true, false))); } // namespace } // namespace testing diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc index 1726e87ea6..9d9c01cade 100644 --- a/test/cpp/end2end/test_service_impl.cc +++ b/test/cpp/end2end/test_service_impl.cc @@ -71,6 +71,46 @@ void CheckServerAuthContext( } } // namespace +namespace { +int GetIntValueFromMetadataHelper( + const char* key, + const std::multimap& metadata, + int default_value) { + if (metadata.find(key) != metadata.end()) { + std::istringstream iss(ToString(metadata.find(key)->second)); + iss >> default_value; + gpr_log(GPR_INFO, "%s : %d", key, default_value); + } + + return default_value; +} + +int GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} + +void ServerTryCancel(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + // Now wait until it's really canceled + while (!context->IsCancelled()) { + gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000, GPR_TIMESPAN))); + } +} + +void ServerTryCancelNonblocking(ServerContext* context) { + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); +} + +} // namespace + Status TestServiceImpl::Echo(ServerContext* context, const EchoRequest* request, EchoResponse* response) { // A bit of sleep to make sure that short deadline tests fail @@ -195,6 +235,7 @@ void CallbackTestServiceImpl::EchoNonDelayed( controller->Finish(Status(static_cast(error.code()), error.error_message(), error.binary_error_details())); + return; } int server_try_cancel = GetIntValueFromMetadata( kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); @@ -254,7 +295,7 @@ void CallbackTestServiceImpl::EchoNonDelayed( alarm_.experimental().Set( gpr_time_add( gpr_now(GPR_CLOCK_REALTIME), - gpr_time_from_micros(request->param().client_cancel_after_us(), + gpr_time_from_micros(request->param().server_cancel_after_us(), GPR_TIMESPAN)), [controller](bool) { controller->Finish(Status::CANCELLED); }); return; @@ -279,6 +320,7 @@ void CallbackTestServiceImpl::EchoNonDelayed( request->param().debug_info().SerializeAsString(); context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info); controller->Finish(Status::CANCELLED); + return; } } if (request->has_param() && @@ -325,7 +367,7 @@ Status TestServiceImpl::RequestStream(ServerContext* context, std::thread* server_try_cancel_thd = nullptr; if (server_try_cancel == CANCEL_DURING_PROCESSING) { server_try_cancel_thd = - new std::thread(&TestServiceImpl::ServerTryCancel, this, context); + new std::thread([context] { ServerTryCancel(context); }); } int num_msgs_read = 0; @@ -380,7 +422,7 @@ Status TestServiceImpl::ResponseStream(ServerContext* context, std::thread* server_try_cancel_thd = nullptr; if (server_try_cancel == CANCEL_DURING_PROCESSING) { server_try_cancel_thd = - new std::thread(&TestServiceImpl::ServerTryCancel, this, context); + new std::thread([context] { ServerTryCancel(context); }); } for (int i = 0; i < server_responses_to_send; i++) { @@ -431,7 +473,7 @@ Status TestServiceImpl::BidiStream( std::thread* server_try_cancel_thd = nullptr; if (server_try_cancel == CANCEL_DURING_PROCESSING) { server_try_cancel_thd = - new std::thread(&TestServiceImpl::ServerTryCancel, this, context); + new std::thread([context] { ServerTryCancel(context); }); } // kServerFinishAfterNReads suggests after how many reads, the server should @@ -465,44 +507,244 @@ Status TestServiceImpl::BidiStream( return Status::OK; } -namespace { -int GetIntValueFromMetadataHelper( - const char* key, - const std::multimap& metadata, - int default_value) { - if (metadata.find(key) != metadata.end()) { - std::istringstream iss(ToString(metadata.find(key)->second)); - iss >> default_value; - gpr_log(GPR_INFO, "%s : %d", key, default_value); - } +experimental::ServerReadReactor* +CallbackTestServiceImpl::RequestStream() { + class Reactor : public ::grpc::experimental::ServerReadReactor { + public: + Reactor() {} + void OnStarted(ServerContext* context, EchoResponse* response) override { + ctx_ = context; + response_ = response; + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + server_try_cancel_ = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + + response_->set_message(""); + + if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + return; + } - return default_value; -} -}; // namespace + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx_->TryCancel(); + // Don't wait for it here + } -int TestServiceImpl::GetIntValueFromMetadata( - const char* key, - const std::multimap& metadata, - int default_value) { - return GetIntValueFromMetadataHelper(key, metadata, default_value); + StartRead(&request_); + } + void OnDone() override { delete this; } + void OnCancel() override { FinishOnce(Status::CANCELLED); } + void OnReadDone(bool ok) override { + if (ok) { + response_->mutable_message()->append(request_.message()); + num_msgs_read_++; + StartRead(&request_); + } else { + gpr_log(GPR_INFO, "Read: %d messages", num_msgs_read_); + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + return; + } + if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + return; + } + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + ServerContext* ctx_; + EchoResponse* response_; + EchoRequest request_; + int num_msgs_read_{0}; + int server_try_cancel_; + std::mutex finish_mu_; + bool finished_{false}; + }; + + return new Reactor; } -int CallbackTestServiceImpl::GetIntValueFromMetadata( - const char* key, - const std::multimap& metadata, - int default_value) { - return GetIntValueFromMetadataHelper(key, metadata, default_value); +// Return 'kNumResponseStreamMsgs' messages. +// TODO(yangg) make it generic by adding a parameter into EchoRequest +experimental::ServerWriteReactor* +CallbackTestServiceImpl::ResponseStream() { + class Reactor + : public ::grpc::experimental::ServerWriteReactor { + public: + Reactor() {} + void OnStarted(ServerContext* context, + const EchoRequest* request) override { + ctx_ = context; + request_ = request; + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + server_try_cancel_ = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + server_coalescing_api_ = GetIntValueFromMetadata( + kServerUseCoalescingApi, context->client_metadata(), 0); + server_responses_to_send_ = GetIntValueFromMetadata( + kServerResponseStreamsToSend, context->client_metadata(), + kServerDefaultResponseStreamsToSend); + if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + return; + } + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx_->TryCancel(); + } + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } + } + void OnDone() override { delete this; } + void OnCancel() override { FinishOnce(Status::CANCELLED); } + void OnWriteDone(bool ok) override { + if (num_msgs_sent_ < server_responses_to_send_) { + NextWrite(); + } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel recover this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + void NextWrite() { + response_.set_message(request_->message() + + grpc::to_string(num_msgs_sent_)); + if (num_msgs_sent_ == server_responses_to_send_ - 1 && + server_coalescing_api_ != 0) { + num_msgs_sent_++; + StartWriteLast(&response_, WriteOptions()); + } else { + num_msgs_sent_++; + StartWrite(&response_); + } + } + ServerContext* ctx_; + const EchoRequest* request_; + EchoResponse response_; + int num_msgs_sent_{0}; + int server_try_cancel_; + int server_coalescing_api_; + int server_responses_to_send_; + std::mutex finish_mu_; + bool finished_{false}; + }; + return new Reactor; } -void TestServiceImpl::ServerTryCancel(ServerContext* context) { - EXPECT_FALSE(context->IsCancelled()); - context->TryCancel(); - gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); - // Now wait until it's really canceled - while (!context->IsCancelled()) { - gpr_sleep_until(gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), - gpr_time_from_micros(1000, GPR_TIMESPAN))); - } +experimental::ServerBidiReactor* +CallbackTestServiceImpl::BidiStream() { + class Reactor : public ::grpc::experimental::ServerBidiReactor { + public: + Reactor() {} + void OnStarted(ServerContext* context) override { + ctx_ = context; + // If 'server_try_cancel' is set in the metadata, the RPC is cancelled by + // the server by calling ServerContext::TryCancel() depending on the + // value: + // CANCEL_BEFORE_PROCESSING: The RPC is cancelled before the server + // reads any message from the client CANCEL_DURING_PROCESSING: The RPC + // is cancelled while the server is reading messages from the client + // CANCEL_AFTER_PROCESSING: The RPC is cancelled after the server reads + // all the messages from the client + server_try_cancel_ = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + server_write_last_ = GetIntValueFromMetadata( + kServerFinishAfterNReads, context->client_metadata(), 0); + if (server_try_cancel_ == CANCEL_BEFORE_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + return; + } + + if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + ctx_->TryCancel(); + } + + StartRead(&request_); + } + void OnDone() override { delete this; } + void OnCancel() override { FinishOnce(Status::CANCELLED); } + void OnReadDone(bool ok) override { + if (ok) { + num_msgs_read_++; + gpr_log(GPR_INFO, "recv msg %s", request_.message().c_str()); + response_.set_message(request_.message()); + if (num_msgs_read_ == server_write_last_) { + StartWriteLast(&response_, WriteOptions()); + } else { + StartWrite(&response_); + } + } else if (server_try_cancel_ == CANCEL_DURING_PROCESSING) { + // Let OnCancel handle this + } else if (server_try_cancel_ == CANCEL_AFTER_PROCESSING) { + ServerTryCancelNonblocking(ctx_); + } else { + FinishOnce(Status::OK); + } + } + void OnWriteDone(bool ok) override { StartRead(&request_); } + + private: + void FinishOnce(const Status& s) { + std::lock_guard l(finish_mu_); + if (!finished_) { + Finish(s); + finished_ = true; + } + } + + ServerContext* ctx_; + EchoRequest request_; + EchoResponse response_; + int num_msgs_read_{0}; + int server_try_cancel_; + int server_write_last_; + std::mutex finish_mu_; + bool finished_{false}; + }; + + return new Reactor; } } // namespace testing diff --git a/test/cpp/end2end/test_service_impl.h b/test/cpp/end2end/test_service_impl.h index ddfe94487e..fad7768b87 100644 --- a/test/cpp/end2end/test_service_impl.h +++ b/test/cpp/end2end/test_service_impl.h @@ -72,13 +72,6 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { } private: - int GetIntValueFromMetadata( - const char* key, - const std::multimap& metadata, - int default_value); - - void ServerTryCancel(ServerContext* context); - bool signal_client_; std::mutex mu_; std::unique_ptr host_; @@ -95,6 +88,15 @@ class CallbackTestServiceImpl EchoResponse* response, experimental::ServerCallbackRpcController* controller) override; + experimental::ServerReadReactor* RequestStream() + override; + + experimental::ServerWriteReactor* ResponseStream() + override; + + experimental::ServerBidiReactor* BidiStream() + override; + // Unimplemented is left unimplemented to test the returned error. bool signal_client() { std::unique_lock lock(mu_); @@ -106,11 +108,6 @@ class CallbackTestServiceImpl EchoResponse* response, experimental::ServerCallbackRpcController* controller); - int GetIntValueFromMetadata( - const char* key, - const std::multimap& metadata, - int default_value); - Alarm alarm_; bool signal_client_; std::mutex mu_; -- cgit v1.2.3 From 97de30d7b3f3fdbddf140cc988ccfaf29bd8edab Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Thu, 6 Dec 2018 15:51:31 -0800 Subject: Allow the interceptor to know the method type --- include/grpcpp/impl/codegen/channel_interface.h | 1 - include/grpcpp/impl/codegen/client_context.h | 6 ++- include/grpcpp/impl/codegen/client_interceptor.h | 48 +++++++++++++++++++--- include/grpcpp/impl/codegen/server_context.h | 4 +- include/grpcpp/impl/codegen/server_interceptor.h | 25 +++++++++-- include/grpcpp/impl/codegen/server_interface.h | 18 ++++---- src/cpp/client/channel_cc.cc | 5 ++- src/cpp/server/server_cc.cc | 17 +++++--- .../end2end/client_interceptors_end2end_test.cc | 1 + .../end2end/server_interceptors_end2end_test.cc | 27 +++++++++++- 10 files changed, 121 insertions(+), 31 deletions(-) (limited to 'src/cpp') diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index 728a7b9049..5353f5feaa 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -21,7 +21,6 @@ #include #include -#include #include #include diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 142cfa35dd..0a71f3d9b6 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -418,12 +419,13 @@ class ClientContext { void set_call(grpc_call* call, const std::shared_ptr& channel); experimental::ClientRpcInfo* set_client_rpc_info( - const char* method, grpc::ChannelInterface* channel, + const char* method, internal::RpcMethod::RpcType type, + grpc::ChannelInterface* channel, const std::vector< std::unique_ptr>& creators, size_t interceptor_pos) { - rpc_info_ = experimental::ClientRpcInfo(this, method, channel); + rpc_info_ = experimental::ClientRpcInfo(this, type, method, channel); rpc_info_.RegisterInterceptors(creators, interceptor_pos); return &rpc_info_; } diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f69c99ab22..2bae11a251 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -23,6 +23,7 @@ #include #include +#include #include namespace grpc { @@ -52,23 +53,56 @@ extern experimental::ClientInterceptorFactoryInterface* namespace experimental { class ClientRpcInfo { public: - ClientRpcInfo() {} + // TODO(yashykt): Stop default-constructing ClientRpcInfo and remove UNKNOWN + // from the list of possible Types. + enum class Type { + UNARY, + CLIENT_STREAMING, + SERVER_STREAMING, + BIDI_STREAMING, + UNKNOWN // UNKNOWN is not API and will be removed later + }; ~ClientRpcInfo(){}; ClientRpcInfo(const ClientRpcInfo&) = delete; ClientRpcInfo(ClientRpcInfo&&) = default; - ClientRpcInfo& operator=(ClientRpcInfo&&) = default; // Getter methods - const char* method() { return method_; } + const char* method() const { return method_; } ChannelInterface* channel() { return channel_; } grpc::ClientContext* client_context() { return ctx_; } + Type type() const { return type_; } private: - ClientRpcInfo(grpc::ClientContext* ctx, const char* method, - grpc::ChannelInterface* channel) - : ctx_(ctx), method_(method), channel_(channel) {} + static_assert(Type::UNARY == + static_cast(internal::RpcMethod::NORMAL_RPC), + "violated expectation about Type enum"); + static_assert(Type::CLIENT_STREAMING == + static_cast(internal::RpcMethod::CLIENT_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::SERVER_STREAMING == + static_cast(internal::RpcMethod::SERVER_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::BIDI_STREAMING == + static_cast(internal::RpcMethod::BIDI_STREAMING), + "violated expectation about Type enum"); + + // Default constructor should only be used by ClientContext + ClientRpcInfo() = default; + + // Constructor will only be called from ClientContext + ClientRpcInfo(grpc::ClientContext* ctx, internal::RpcMethod::RpcType type, + const char* method, grpc::ChannelInterface* channel) + : ctx_(ctx), + type_(static_cast(type)), + method_(method), + channel_(channel) {} + + // Move assignment should only be used by ClientContext + // TODO(yashykt): Delete move assignment + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + // Runs interceptor at pos \a pos. void RunInterceptor( experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { @@ -97,6 +131,8 @@ class ClientRpcInfo { } grpc::ClientContext* ctx_ = nullptr; + // TODO(yashykt): make type_ const once move-assignment is deleted + Type type_{Type::UNKNOWN}; const char* method_ = nullptr; grpc::ChannelInterface* channel_ = nullptr; std::vector> interceptors_; diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index 4a5f9e2dd9..ccb5925e7d 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -314,12 +314,12 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } experimental::ServerRpcInfo* set_server_rpc_info( - const char* method, + const char* method, internal::RpcMethod::RpcType type, const std::vector< std::unique_ptr>& creators) { if (creators.size() != 0) { - rpc_info_ = new experimental::ServerRpcInfo(this, method); + rpc_info_ = new experimental::ServerRpcInfo(this, method, type); rpc_info_->RegisterInterceptors(creators); } return rpc_info_; diff --git a/include/grpcpp/impl/codegen/server_interceptor.h b/include/grpcpp/impl/codegen/server_interceptor.h index 5fb5df28b7..cd7c0600b6 100644 --- a/include/grpcpp/impl/codegen/server_interceptor.h +++ b/include/grpcpp/impl/codegen/server_interceptor.h @@ -23,6 +23,7 @@ #include #include +#include #include namespace grpc { @@ -44,6 +45,8 @@ class ServerInterceptorFactoryInterface { class ServerRpcInfo { public: + enum class Type { UNARY, CLIENT_STREAMING, SERVER_STREAMING, BIDI_STREAMING }; + ~ServerRpcInfo(){}; ServerRpcInfo(const ServerRpcInfo&) = delete; @@ -51,12 +54,27 @@ class ServerRpcInfo { ServerRpcInfo& operator=(ServerRpcInfo&&) = default; // Getter methods - const char* method() { return method_; } + const char* method() const { return method_; } + Type type() const { return type_; } grpc::ServerContext* server_context() { return ctx_; } private: - ServerRpcInfo(grpc::ServerContext* ctx, const char* method) - : ctx_(ctx), method_(method) { + static_assert(Type::UNARY == + static_cast(internal::RpcMethod::NORMAL_RPC), + "violated expectation about Type enum"); + static_assert(Type::CLIENT_STREAMING == + static_cast(internal::RpcMethod::CLIENT_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::SERVER_STREAMING == + static_cast(internal::RpcMethod::SERVER_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::BIDI_STREAMING == + static_cast(internal::RpcMethod::BIDI_STREAMING), + "violated expectation about Type enum"); + + ServerRpcInfo(grpc::ServerContext* ctx, const char* method, + internal::RpcMethod::RpcType type) + : ctx_(ctx), method_(method), type_(static_cast(type)) { ref_.store(1); } @@ -86,6 +104,7 @@ class ServerRpcInfo { grpc::ServerContext* ctx_ = nullptr; const char* method_ = nullptr; + const Type type_; std::atomic_int ref_; std::vector> interceptors_; diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 55c94f4d2f..e0e2629827 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -174,13 +174,14 @@ class ServerInterface : public internal::CallHook { bool done_intercepting_; }; + /// RegisteredAsyncRequest is not part of the C++ API class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, - const char* name); + const char* name, internal::RpcMethod::RpcType type); virtual bool FinalizeResult(void** tag, bool* status) override { /* If we are done intercepting, then there is nothing more for us to do */ @@ -189,7 +190,7 @@ class ServerInterface : public internal::CallHook { } call_wrapper_ = internal::Call( call_, server_, call_cq_, server_->max_receive_message_size(), - context_->set_server_rpc_info(name_, + context_->set_server_rpc_info(name_, type_, *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } @@ -198,6 +199,7 @@ class ServerInterface : public internal::CallHook { void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); const char* name_; + const internal::RpcMethod::RpcType type_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { @@ -207,9 +209,9 @@ class ServerInterface : public internal::CallHook { internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, - notification_cq, tag, - registered_method->name()) { + : RegisteredAsyncRequest( + server, context, stream, call_cq, notification_cq, tag, + registered_method->name(), registered_method->method_type()) { IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } @@ -225,9 +227,9 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, - notification_cq, tag, - registered_method->name()), + : RegisteredAsyncRequest( + server, context, stream, call_cq, notification_cq, tag, + registered_method->name(), registered_method->method_type()), registered_method_(registered_method), server_(server), context_(context), diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index d1c55319f7..a31d0b30b1 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -149,8 +149,9 @@ internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method, // ClientRpcInfo should be set before call because set_call also checks // whether the call has been cancelled, and if the call was cancelled, we // should notify the interceptors too/ - auto* info = context->set_client_rpc_info( - method.name(), this, interceptor_creators_, interceptor_pos); + auto* info = + context->set_client_rpc_info(method.name(), method.method_type(), this, + interceptor_creators_, interceptor_pos); context->set_call(c_call, shared_from_this()); return internal::Call(c_call, this, cq, info); diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 69af43a656..1e3c57446f 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -236,9 +236,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { : nullptr), request_(nullptr), method_(mrd->method_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size(), - ctx_.set_server_rpc_info(method_->name(), - server->interceptor_creators_)), + call_( + mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), method_->method_type(), + server->interceptor_creators_)), server_(server), global_callbacks_(nullptr), resources_(false) { @@ -427,7 +428,8 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { req_->call_, req_->server_, req_->cq_, req_->server_->max_receive_message_size(), req_->ctx_.set_server_rpc_info( - req_->method_->name(), req_->server_->interceptor_creators_)); + req_->method_->name(), req_->method_->method_type(), + req_->server_->interceptor_creators_)); req_->interceptor_methods_.SetCall(call_); req_->interceptor_methods_.SetReverse(); @@ -1041,10 +1043,12 @@ void ServerInterface::BaseAsyncRequest:: ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag, const char* name) + ServerCompletionQueue* notification_cq, void* tag, const char* name, + internal::RpcMethod::RpcType type) : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, true), - name_(name) {} + name_(name), + type_(type) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -1091,6 +1095,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, call_, server_, call_cq_, server_->max_receive_message_size(), context_->set_server_rpc_info( static_cast(context_)->method_.c_str(), + internal::RpcMethod::BIDI_STREAMING, *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 3a191d1e03..f55ece1c2b 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -50,6 +50,7 @@ class HijackingInterceptor : public experimental::Interceptor { info_ = info; // Make sure it is the right method EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY); } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc index c98b6143c6..1e0e366870 100644 --- a/test/cpp/end2end/server_interceptors_end2end_test.cc +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -44,7 +44,32 @@ namespace { class LoggingInterceptor : public experimental::Interceptor { public: - LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; } + LoggingInterceptor(experimental::ServerRpcInfo* info) { + info_ = info; + + // Check the method name and compare to the type + const char* method = info->method(); + experimental::ServerRpcInfo::Type type = info->type(); + + // Check that we use one of our standard methods with expected type. + // We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService + // being tested (the GenericRpc test). + // The empty method is for the Unimplemented requests that arise + // when draining the CQ. + EXPECT_TRUE( + (strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 && + (type == experimental::ServerRpcInfo::Type::UNARY || + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) || + (strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 && + type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 && + type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) || + strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 || + (strcmp(method, "") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)); + } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { if (methods->QueryInterceptionHookPoint( -- cgit v1.2.3 From f1f557bc43992ade0f09c9240c2d5c71c3478f0a Mon Sep 17 00:00:00 2001 From: yang-g Date: Thu, 6 Dec 2018 16:16:58 -0800 Subject: Add a Shutdown call to HealthCheckServiceInterface --- include/grpcpp/health_check_service_interface.h | 4 + .../server/health/default_health_check_service.cc | 18 ++++ .../server/health/default_health_check_service.h | 3 + test/cpp/end2end/health_service_end2end_test.cc | 107 +++++++++++++++++++++ 4 files changed, 132 insertions(+) (limited to 'src/cpp') diff --git a/include/grpcpp/health_check_service_interface.h b/include/grpcpp/health_check_service_interface.h index b45a699bda..dfd4c3983a 100644 --- a/include/grpcpp/health_check_service_interface.h +++ b/include/grpcpp/health_check_service_interface.h @@ -37,6 +37,10 @@ class HealthCheckServiceInterface { bool serving) = 0; /// Apply to all registered service names. virtual void SetServingStatus(bool serving) = 0; + + /// Set all registered service names to not serving and prevent future + /// state changes. + virtual void Shutdown() {} }; /// Enable/disable the default health checking service. This applies to all C++ diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc index c951c69d51..db6286d240 100644 --- a/src/cpp/server/health/default_health_check_service.cc +++ b/src/cpp/server/health/default_health_check_service.cc @@ -42,18 +42,36 @@ DefaultHealthCheckService::DefaultHealthCheckService() { void DefaultHealthCheckService::SetServingStatus( const grpc::string& service_name, bool serving) { std::unique_lock lock(mu_); + if (shutdown_) { + return; + } services_map_[service_name].SetServingStatus(serving ? SERVING : NOT_SERVING); } void DefaultHealthCheckService::SetServingStatus(bool serving) { const ServingStatus status = serving ? SERVING : NOT_SERVING; std::unique_lock lock(mu_); + if (shutdown_) { + return; + } for (auto& p : services_map_) { ServiceData& service_data = p.second; service_data.SetServingStatus(status); } } +void DefaultHealthCheckService::Shutdown() { + std::unique_lock lock(mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto& p : services_map_) { + ServiceData& service_data = p.second; + service_data.SetServingStatus(NOT_SERVING); + } +} + DefaultHealthCheckService::ServingStatus DefaultHealthCheckService::GetServingStatus( const grpc::string& service_name) const { diff --git a/src/cpp/server/health/default_health_check_service.h b/src/cpp/server/health/default_health_check_service.h index 450bd543f5..9551cd2e2c 100644 --- a/src/cpp/server/health/default_health_check_service.h +++ b/src/cpp/server/health/default_health_check_service.h @@ -237,6 +237,8 @@ class DefaultHealthCheckService final : public HealthCheckServiceInterface { bool serving) override; void SetServingStatus(bool serving) override; + void Shutdown() override; + ServingStatus GetServingStatus(const grpc::string& service_name) const; HealthCheckServiceImpl* GetHealthCheckService( @@ -272,6 +274,7 @@ class DefaultHealthCheckService final : public HealthCheckServiceInterface { const std::shared_ptr& handler); mutable std::mutex mu_; + bool shutdown_ = false; // Guarded by mu_. std::map services_map_; // Guarded by mu_. std::unique_ptr impl_; }; diff --git a/test/cpp/end2end/health_service_end2end_test.cc b/test/cpp/end2end/health_service_end2end_test.cc index 89c4bef09c..a439d42b03 100644 --- a/test/cpp/end2end/health_service_end2end_test.cc +++ b/test/cpp/end2end/health_service_end2end_test.cc @@ -90,18 +90,36 @@ class HealthCheckServiceImpl : public ::grpc::health::v1::Health::Service { void SetStatus(const grpc::string& service_name, HealthCheckResponse::ServingStatus status) { std::lock_guard lock(mu_); + if (shutdown_) { + return; + } status_map_[service_name] = status; } void SetAll(HealthCheckResponse::ServingStatus status) { std::lock_guard lock(mu_); + if (shutdown_) { + return; + } for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) { iter->second = status; } } + void Shutdown() { + std::lock_guard lock(mu_); + if (shutdown_) { + return; + } + shutdown_ = true; + for (auto iter = status_map_.begin(); iter != status_map_.end(); ++iter) { + iter->second = HealthCheckResponse::NOT_SERVING; + } + } + private: std::mutex mu_; + bool shutdown_ = false; std::map status_map_; }; @@ -125,6 +143,8 @@ class CustomHealthCheckService : public HealthCheckServiceInterface { : HealthCheckResponse::NOT_SERVING); } + void Shutdown() override { impl_->Shutdown(); } + private: HealthCheckServiceImpl* impl_; // not owned }; @@ -260,6 +280,71 @@ class HealthServiceEnd2endTest : public ::testing::Test { context.TryCancel(); } + // Verify that after HealthCheckServiceInterface::Shutdown is called + // 1. unary client will see NOT_SERVING. + // 2. unary client still sees NOT_SERVING after a SetServing(true) is called. + // 3. streaming (Watch) client will see an update. + // This has to be called last. + void VerifyHealthCheckServiceShutdown() { + const grpc::string kServiceName("service_name"); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service != nullptr); + const grpc::string kHealthyService("healthy_service"); + const grpc::string kUnhealthyService("unhealthy_service"); + const grpc::string kNotRegisteredService("not_registered"); + service->SetServingStatus(kHealthyService, true); + service->SetServingStatus(kUnhealthyService, false); + + ResetStubs(); + + // Start Watch for service. + ClientContext context; + HealthCheckRequest request; + request.set_service(kServiceName); + std::unique_ptr<::grpc::ClientReaderInterface> reader = + hc_stub_->Watch(&context, request); + + // Initial response will be SERVICE_UNKNOWN. + HealthCheckResponse response; + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVICE_UNKNOWN, response.status()); + + // Set service to SERVING and make sure we get an update. + service->SetServingStatus(kServiceName, true); + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.SERVING, response.status()); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + // Shutdown health check service. + service->Shutdown(); + + // Watch client gets another update. + EXPECT_TRUE(reader->Read(&response)); + EXPECT_EQ(response.NOT_SERVING, response.status()); + // Finish Watch call. + context.TryCancel(); + + SendHealthCheckRpc("", Status::OK, HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kUnhealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + SendHealthCheckRpc(kNotRegisteredService, + Status(StatusCode::NOT_FOUND, "")); + + // Setting status after Shutdown has no effect. + service->SetServingStatus(kHealthyService, true); + SendHealthCheckRpc(kHealthyService, Status::OK, + HealthCheckResponse::NOT_SERVING); + } + TestServiceImpl echo_test_service_; HealthCheckServiceImpl health_check_service_impl_; std::unique_ptr hc_stub_; @@ -295,6 +380,13 @@ TEST_F(HealthServiceEnd2endTest, DefaultHealthService) { Status(StatusCode::INVALID_ARGUMENT, "")); } +TEST_F(HealthServiceEnd2endTest, DefaultHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + SetUpServer(true, false, false, nullptr); + VerifyHealthCheckServiceShutdown(); +} + // Provide an empty service to disable the default service. TEST_F(HealthServiceEnd2endTest, ExplicitlyDisableViaOverride) { EnableDefaultHealthCheckService(true); @@ -326,6 +418,21 @@ TEST_F(HealthServiceEnd2endTest, ExplicitlyOverride) { VerifyHealthCheckServiceStreaming(); } +TEST_F(HealthServiceEnd2endTest, ExplicitlyHealthServiceShutdown) { + EnableDefaultHealthCheckService(true); + EXPECT_TRUE(DefaultHealthCheckServiceEnabled()); + std::unique_ptr override_service( + new CustomHealthCheckService(&health_check_service_impl_)); + HealthCheckServiceInterface* underlying_service = override_service.get(); + SetUpServer(false, false, true, std::move(override_service)); + HealthCheckServiceInterface* service = server_->GetHealthCheckService(); + EXPECT_TRUE(service == underlying_service); + + ResetStubs(); + + VerifyHealthCheckServiceShutdown(); +} + } // namespace } // namespace testing } // namespace grpc -- cgit v1.2.3 From 2246607dedfbad98c3aa4f39997907a203985863 Mon Sep 17 00:00:00 2001 From: yang-g Date: Fri, 7 Dec 2018 08:54:47 -0800 Subject: Review comments --- src/cpp/server/health/default_health_check_service.cc | 3 ++- test/cpp/end2end/health_service_end2end_test.cc | 19 +++++++++++-------- test/cpp/end2end/test_health_check_service_impl.cc | 3 ++- 3 files changed, 15 insertions(+), 10 deletions(-) (limited to 'src/cpp') diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc index db6286d240..44aebd2f9d 100644 --- a/src/cpp/server/health/default_health_check_service.cc +++ b/src/cpp/server/health/default_health_check_service.cc @@ -43,7 +43,8 @@ void DefaultHealthCheckService::SetServingStatus( const grpc::string& service_name, bool serving) { std::unique_lock lock(mu_); if (shutdown_) { - return; + // Set to NOT_SERVING in case service_name is not in the map. + serving = false; } services_map_[service_name].SetServingStatus(serving ? SERVING : NOT_SERVING); } diff --git a/test/cpp/end2end/health_service_end2end_test.cc b/test/cpp/end2end/health_service_end2end_test.cc index 94c92327c8..b96ff53a3e 100644 --- a/test/cpp/end2end/health_service_end2end_test.cc +++ b/test/cpp/end2end/health_service_end2end_test.cc @@ -211,14 +211,16 @@ class HealthServiceEnd2endTest : public ::testing::Test { // 1. unary client will see NOT_SERVING. // 2. unary client still sees NOT_SERVING after a SetServing(true) is called. // 3. streaming (Watch) client will see an update. + // 4. setting a new service to serving after shutdown will add the service + // name but return NOT_SERVING to client. // This has to be called last. void VerifyHealthCheckServiceShutdown() { - const grpc::string kServiceName("service_name"); HealthCheckServiceInterface* service = server_->GetHealthCheckService(); EXPECT_TRUE(service != nullptr); const grpc::string kHealthyService("healthy_service"); const grpc::string kUnhealthyService("unhealthy_service"); const grpc::string kNotRegisteredService("not_registered"); + const grpc::string kNewService("add_after_shutdown"); service->SetServingStatus(kHealthyService, true); service->SetServingStatus(kUnhealthyService, false); @@ -227,18 +229,12 @@ class HealthServiceEnd2endTest : public ::testing::Test { // Start Watch for service. ClientContext context; HealthCheckRequest request; - request.set_service(kServiceName); + request.set_service(kHealthyService); std::unique_ptr<::grpc::ClientReaderInterface> reader = hc_stub_->Watch(&context, request); - // Initial response will be SERVICE_UNKNOWN. HealthCheckResponse response; EXPECT_TRUE(reader->Read(&response)); - EXPECT_EQ(response.SERVICE_UNKNOWN, response.status()); - - // Set service to SERVING and make sure we get an update. - service->SetServingStatus(kServiceName, true); - EXPECT_TRUE(reader->Read(&response)); EXPECT_EQ(response.SERVING, response.status()); SendHealthCheckRpc("", Status::OK, HealthCheckResponse::SERVING); @@ -248,6 +244,7 @@ class HealthServiceEnd2endTest : public ::testing::Test { HealthCheckResponse::NOT_SERVING); SendHealthCheckRpc(kNotRegisteredService, Status(StatusCode::NOT_FOUND, "")); + SendHealthCheckRpc(kNewService, Status(StatusCode::NOT_FOUND, "")); // Shutdown health check service. service->Shutdown(); @@ -270,6 +267,12 @@ class HealthServiceEnd2endTest : public ::testing::Test { service->SetServingStatus(kHealthyService, true); SendHealthCheckRpc(kHealthyService, Status::OK, HealthCheckResponse::NOT_SERVING); + + // Adding serving status for a new service after shutdown will return + // NOT_SERVING. + service->SetServingStatus(kNewService, true); + SendHealthCheckRpc(kNewService, Status::OK, + HealthCheckResponse::NOT_SERVING); } TestServiceImpl echo_test_service_; diff --git a/test/cpp/end2end/test_health_check_service_impl.cc b/test/cpp/end2end/test_health_check_service_impl.cc index fa70a44d24..0801e30199 100644 --- a/test/cpp/end2end/test_health_check_service_impl.cc +++ b/test/cpp/end2end/test_health_check_service_impl.cc @@ -37,6 +37,7 @@ Status HealthCheckServiceImpl::Check(ServerContext* context, response->set_status(iter->second); return Status::OK; } + Status HealthCheckServiceImpl::Watch( ServerContext* context, const HealthCheckRequest* request, ::grpc::ServerWriter* writer) { @@ -66,7 +67,7 @@ void HealthCheckServiceImpl::SetStatus( HealthCheckResponse::ServingStatus status) { std::lock_guard lock(mu_); if (shutdown_) { - return; + status = HealthCheckResponse::NOT_SERVING; } status_map_[service_name] = status; } -- cgit v1.2.3 From c7f7db65e0fdc058b4caa2b76baaa308465fda9e Mon Sep 17 00:00:00 2001 From: ncteisen Date: Tue, 11 Dec 2018 11:28:12 -0800 Subject: Add test and fix bug --- src/core/lib/channel/channelz.cc | 7 ++- src/cpp/server/channelz/channelz_service.cc | 4 +- test/cpp/end2end/channelz_service_test.cc | 82 +++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 6 deletions(-) (limited to 'src/cpp') diff --git a/src/core/lib/channel/channelz.cc b/src/core/lib/channel/channelz.cc index 10d5e1b581..8a596ad460 100644 --- a/src/core/lib/channel/channelz.cc +++ b/src/core/lib/channel/channelz.cc @@ -205,7 +205,7 @@ ServerNode::~ServerNode() {} char* ServerNode::RenderServerSockets(intptr_t start_socket_id, intptr_t max_results) { - // if user does not set max_results, we choose 1000. + // if user does not set max_results, we choose 500. size_t pagination_limit = max_results == 0 ? 500 : max_results; grpc_json* top_level_json = grpc_json_create(GRPC_JSON_OBJECT); grpc_json* json = top_level_json; @@ -219,9 +219,8 @@ char* ServerNode::RenderServerSockets(intptr_t start_socket_id, grpc_json* array_parent = grpc_json_create_child( nullptr, json, "socketRef", nullptr, GRPC_JSON_ARRAY, false); for (i = 0; i < GPR_MIN(socket_refs.size(), pagination_limit); ++i) { - grpc_json* socket_ref_json = - grpc_json_create_child(json_iterator, array_parent, nullptr, nullptr, - GRPC_JSON_OBJECT, false); + grpc_json* socket_ref_json = grpc_json_create_child( + nullptr, array_parent, nullptr, nullptr, GRPC_JSON_OBJECT, false); json_iterator = grpc_json_add_number_string_child( socket_ref_json, nullptr, "socketId", socket_refs[i]->uuid()); grpc_json_create_child(json_iterator, socket_ref_json, "name", diff --git a/src/cpp/server/channelz/channelz_service.cc b/src/cpp/server/channelz/channelz_service.cc index 9ecb9de7e4..c44a9ac0de 100644 --- a/src/cpp/server/channelz/channelz_service.cc +++ b/src/cpp/server/channelz/channelz_service.cc @@ -79,8 +79,8 @@ Status ChannelzService::GetServer(ServerContext* unused, Status ChannelzService::GetServerSockets( ServerContext* unused, const channelz::v1::GetServerSocketsRequest* request, channelz::v1::GetServerSocketsResponse* response) { - char* json_str = grpc_channelz_get_server_sockets(request->server_id(), - request->start_socket_id()); + char* json_str = grpc_channelz_get_server_sockets( + request->server_id(), request->start_socket_id(), request->max_results()); if (json_str == nullptr) { return Status(StatusCode::INTERNAL, "grpc_channelz_get_server_sockets returned null"); diff --git a/test/cpp/end2end/channelz_service_test.cc b/test/cpp/end2end/channelz_service_test.cc index 29b59e4e5e..425334d972 100644 --- a/test/cpp/end2end/channelz_service_test.cc +++ b/test/cpp/end2end/channelz_service_test.cc @@ -54,6 +54,14 @@ using grpc::channelz::v1::GetSubchannelResponse; using grpc::channelz::v1::GetTopChannelsRequest; using grpc::channelz::v1::GetTopChannelsResponse; +// This code snippet can be used to print out any responses for +// visual debugging. +// +// +// string out_str; +// google::protobuf::TextFormat::PrintToString(resp, &out_str); +// std::cout << "resp: " << out_str << "\n"; + namespace grpc { namespace testing { namespace { @@ -164,6 +172,19 @@ class ChannelzServerTest : public ::testing::Test { echo_stub_ = grpc::testing::EchoTestService::NewStub(channel); } + std::unique_ptr NewEchoStub() { + static int salt = 0; + string target = "dns:localhost:" + to_string(proxy_port_); + ChannelArguments args; + // disable channelz. We only want to focus on proxy to backend outbound. + args.SetInt(GRPC_ARG_ENABLE_CHANNELZ, 0); + // This ensures that gRPC will not do connection sharing. + args.SetInt("salt", salt++); + std::shared_ptr channel = + CreateCustomChannel(target, InsecureChannelCredentials(), args); + return grpc::testing::EchoTestService::NewStub(channel); + } + void SendSuccessfulEcho(int channel_idx) { EchoRequest request; EchoResponse response; @@ -651,6 +672,67 @@ TEST_F(ChannelzServerTest, GetServerSocketsTest) { EXPECT_EQ(get_server_sockets_response.socket_ref_size(), 1); } +TEST_F(ChannelzServerTest, GetServerSocketsPaginationTest) { + ResetStubs(); + ConfigureProxy(1); + std::vector> stubs; + const int kNumServerSocketsCreated = 20; + for (int i = 0; i < kNumServerSocketsCreated; ++i) { + stubs.push_back(NewEchoStub()); + EchoRequest request; + EchoResponse response; + request.set_message("Hello channelz"); + request.mutable_param()->set_backend_channel_idx(0); + ClientContext context; + Status s = stubs.back()->Echo(&context, request, &response); + EXPECT_EQ(response.message(), request.message()); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + } + GetServersRequest get_server_request; + GetServersResponse get_server_response; + get_server_request.set_start_server_id(0); + ClientContext get_server_context; + Status s = channelz_stub_->GetServers(&get_server_context, get_server_request, + &get_server_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_response.server_size(), 1); + // Make a request that gets all of the serversockets + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + // We add one to account the the channelz stub that will end up creating + // a serversocket. + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), + kNumServerSocketsCreated + 1); + EXPECT_TRUE(get_server_sockets_response.end()); + } + // Now we make a request that exercises pagination. + { + GetServerSocketsRequest get_server_sockets_request; + GetServerSocketsResponse get_server_sockets_response; + get_server_sockets_request.set_server_id( + get_server_response.server(0).ref().server_id()); + get_server_sockets_request.set_start_socket_id(0); + const int kMaxResults = 10; + get_server_sockets_request.set_max_results(kMaxResults); + ClientContext get_server_sockets_context; + s = channelz_stub_->GetServerSockets(&get_server_sockets_context, + get_server_sockets_request, + &get_server_sockets_response); + EXPECT_TRUE(s.ok()) << "s.error_message() = " << s.error_message(); + EXPECT_EQ(get_server_sockets_response.socket_ref_size(), kMaxResults); + EXPECT_FALSE(get_server_sockets_response.end()); + } +} + TEST_F(ChannelzServerTest, GetServerListenSocketsTest) { ResetStubs(); ConfigureProxy(1); -- cgit v1.2.3 From d550af373cc993ee80e9e350d60f4ca662b1ec28 Mon Sep 17 00:00:00 2001 From: "Nicolas \"Pixel\" Noble" Date: Wed, 12 Dec 2018 04:09:42 +0100 Subject: Moving ::grpc::Alarm to ::grpc_impl::Alarm. --- BUILD | 1 + CMakeLists.txt | 3 + Makefile | 3 + build.yaml | 1 + gRPC-C++.podspec | 1 + include/grpcpp/alarm.h | 93 +---------------- include/grpcpp/alarm_impl.h | 116 +++++++++++++++++++++ src/cpp/common/alarm.cc | 13 +-- tools/doxygen/Doxyfile.c++ | 1 + tools/doxygen/Doxyfile.c++.internal | 1 + tools/run_tests/generated/sources_and_headers.json | 2 + 11 files changed, 140 insertions(+), 95 deletions(-) create mode 100644 include/grpcpp/alarm_impl.h (limited to 'src/cpp') diff --git a/BUILD b/BUILD index 9a2c16c601..cb03e560d5 100644 --- a/BUILD +++ b/BUILD @@ -204,6 +204,7 @@ GRPCXX_PUBLIC_HDRS = [ "include/grpc++/support/sync_stream.h", "include/grpc++/support/time.h", "include/grpcpp/alarm.h", + "include/grpcpp/alarm_impl.h", "include/grpcpp/channel.h", "include/grpcpp/client_context.h", "include/grpcpp/completion_queue.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 23b4bb77e7..e843397e00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2985,6 +2985,7 @@ foreach(_hdr include/grpc++/support/sync_stream.h include/grpc++/support/time.h include/grpcpp/alarm.h + include/grpcpp/alarm_impl.h include/grpcpp/channel.h include/grpcpp/client_context.h include/grpcpp/completion_queue.h @@ -3569,6 +3570,7 @@ foreach(_hdr include/grpc++/support/sync_stream.h include/grpc++/support/time.h include/grpcpp/alarm.h + include/grpcpp/alarm_impl.h include/grpcpp/channel.h include/grpcpp/client_context.h include/grpcpp/completion_queue.h @@ -4504,6 +4506,7 @@ foreach(_hdr include/grpc++/support/sync_stream.h include/grpc++/support/time.h include/grpcpp/alarm.h + include/grpcpp/alarm_impl.h include/grpcpp/channel.h include/grpcpp/client_context.h include/grpcpp/completion_queue.h diff --git a/Makefile b/Makefile index 4f7cd13012..438e5729d5 100644 --- a/Makefile +++ b/Makefile @@ -5340,6 +5340,7 @@ PUBLIC_HEADERS_CXX += \ include/grpc++/support/sync_stream.h \ include/grpc++/support/time.h \ include/grpcpp/alarm.h \ + include/grpcpp/alarm_impl.h \ include/grpcpp/channel.h \ include/grpcpp/client_context.h \ include/grpcpp/completion_queue.h \ @@ -5933,6 +5934,7 @@ PUBLIC_HEADERS_CXX += \ include/grpc++/support/sync_stream.h \ include/grpc++/support/time.h \ include/grpcpp/alarm.h \ + include/grpcpp/alarm_impl.h \ include/grpcpp/channel.h \ include/grpcpp/client_context.h \ include/grpcpp/completion_queue.h \ @@ -6835,6 +6837,7 @@ PUBLIC_HEADERS_CXX += \ include/grpc++/support/sync_stream.h \ include/grpc++/support/time.h \ include/grpcpp/alarm.h \ + include/grpcpp/alarm_impl.h \ include/grpcpp/channel.h \ include/grpcpp/client_context.h \ include/grpcpp/completion_queue.h \ diff --git a/build.yaml b/build.yaml index 381e649b39..74a5f5e6dc 100644 --- a/build.yaml +++ b/build.yaml @@ -1323,6 +1323,7 @@ filegroups: - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpcpp/alarm.h + - include/grpcpp/alarm_impl.h - include/grpcpp/channel.h - include/grpcpp/client_context.h - include/grpcpp/completion_queue.h diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index 6647201714..3da94dda4b 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -78,6 +78,7 @@ Pod::Spec.new do |s| ss.header_mappings_dir = 'include/grpcpp' ss.source_files = 'include/grpcpp/alarm.h', + 'include/grpcpp/alarm_impl.h', 'include/grpcpp/channel.h', 'include/grpcpp/client_context.h', 'include/grpcpp/completion_queue.h', diff --git a/include/grpcpp/alarm.h b/include/grpcpp/alarm.h index 365feb4eb9..2343c1149c 100644 --- a/include/grpcpp/alarm.h +++ b/include/grpcpp/alarm.h @@ -1,6 +1,6 @@ /* * - * Copyright 2015 gRPC authors. + * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,99 +16,14 @@ * */ -/// An Alarm posts the user provided tag to its associated completion queue upon -/// expiry or cancellation. #ifndef GRPCPP_ALARM_H #define GRPCPP_ALARM_H -#include - -#include -#include -#include -#include -#include -#include +#include namespace grpc { -/// A thin wrapper around \a grpc_alarm (see / \a / src/core/surface/alarm.h). -class Alarm : private GrpcLibraryCodegen { - public: - /// Create an unset completion queue alarm - Alarm(); - - /// Destroy the given completion queue alarm, cancelling it in the process. - ~Alarm(); - - /// DEPRECATED: Create and set a completion queue alarm instance associated to - /// \a cq. - /// This form is deprecated because it is inherently racy. - /// \internal We rely on the presence of \a cq for grpc initialization. If \a - /// cq were ever to be removed, a reference to a static - /// internal::GrpcLibraryInitializer instance would need to be introduced - /// here. \endinternal. - template - Alarm(CompletionQueue* cq, const T& deadline, void* tag) : Alarm() { - SetInternal(cq, TimePoint(deadline).raw_time(), tag); - } - - /// Trigger an alarm instance on completion queue \a cq at the specified time. - /// Once the alarm expires (at \a deadline) or it's cancelled (see \a Cancel), - /// an event with tag \a tag will be added to \a cq. If the alarm expired, the - /// event's success bit will be true, false otherwise (ie, upon cancellation). - template - void Set(CompletionQueue* cq, const T& deadline, void* tag) { - SetInternal(cq, TimePoint(deadline).raw_time(), tag); - } - - /// Alarms aren't copyable. - Alarm(const Alarm&) = delete; - Alarm& operator=(const Alarm&) = delete; - - /// Alarms are movable. - Alarm(Alarm&& rhs) : alarm_(rhs.alarm_) { rhs.alarm_ = nullptr; } - Alarm& operator=(Alarm&& rhs) { - alarm_ = rhs.alarm_; - rhs.alarm_ = nullptr; - return *this; - } - - /// Cancel a completion queue alarm. Calling this function over an alarm that - /// has already fired has no effect. - void Cancel(); - - /// NOTE: class experimental_type is not part of the public API of this class - /// TODO(vjpai): Move these contents to the public API of Alarm when - /// they are no longer experimental - class experimental_type { - public: - explicit experimental_type(Alarm* alarm) : alarm_(alarm) {} - - /// Set an alarm to invoke callback \a f. The argument to the callback - /// states whether the alarm expired at \a deadline (true) or was cancelled - /// (false) - template - void Set(const T& deadline, std::function f) { - alarm_->SetInternal(TimePoint(deadline).raw_time(), std::move(f)); - } - - private: - Alarm* alarm_; - }; - - /// NOTE: The function experimental() is not stable public API. It is a view - /// to the experimental components of this class. It may be changed or removed - /// at any time. - experimental_type experimental() { return experimental_type(this); } - - private: - void SetInternal(CompletionQueue* cq, gpr_timespec deadline, void* tag); - void SetInternal(gpr_timespec deadline, std::function f); - - internal::CompletionQueueTag* alarm_; -}; - -} // namespace grpc +typedef ::grpc_impl::Alarm Alarm; +} #endif // GRPCPP_ALARM_H diff --git a/include/grpcpp/alarm_impl.h b/include/grpcpp/alarm_impl.h new file mode 100644 index 0000000000..7844e7c886 --- /dev/null +++ b/include/grpcpp/alarm_impl.h @@ -0,0 +1,116 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/// An Alarm posts the user provided tag to its associated completion queue upon +/// expiry or cancellation. +#ifndef GRPCPP_ALARM_IMPL_H +#define GRPCPP_ALARM_IMPL_H + +#include + +#include +#include +#include +#include +#include +#include + +namespace grpc_impl { + +/// A thin wrapper around \a grpc_alarm (see / \a / src/core/surface/alarm.h). +class Alarm : private ::grpc::GrpcLibraryCodegen { + public: + /// Create an unset completion queue alarm + Alarm(); + + /// Destroy the given completion queue alarm, cancelling it in the process. + ~Alarm(); + + /// DEPRECATED: Create and set a completion queue alarm instance associated to + /// \a cq. + /// This form is deprecated because it is inherently racy. + /// \internal We rely on the presence of \a cq for grpc initialization. If \a + /// cq were ever to be removed, a reference to a static + /// internal::GrpcLibraryInitializer instance would need to be introduced + /// here. \endinternal. + template + Alarm(::grpc::CompletionQueue* cq, const T& deadline, void* tag) : Alarm() { + SetInternal(cq, ::grpc::TimePoint(deadline).raw_time(), tag); + } + + /// Trigger an alarm instance on completion queue \a cq at the specified time. + /// Once the alarm expires (at \a deadline) or it's cancelled (see \a Cancel), + /// an event with tag \a tag will be added to \a cq. If the alarm expired, the + /// event's success bit will be true, false otherwise (ie, upon cancellation). + template + void Set(::grpc::CompletionQueue* cq, const T& deadline, void* tag) { + SetInternal(cq, ::grpc::TimePoint(deadline).raw_time(), tag); + } + + /// Alarms aren't copyable. + Alarm(const Alarm&) = delete; + Alarm& operator=(const Alarm&) = delete; + + /// Alarms are movable. + Alarm(Alarm&& rhs) : alarm_(rhs.alarm_) { rhs.alarm_ = nullptr; } + Alarm& operator=(Alarm&& rhs) { + alarm_ = rhs.alarm_; + rhs.alarm_ = nullptr; + return *this; + } + + /// Cancel a completion queue alarm. Calling this function over an alarm that + /// has already fired has no effect. + void Cancel(); + + /// NOTE: class experimental_type is not part of the public API of this class + /// TODO(vjpai): Move these contents to the public API of Alarm when + /// they are no longer experimental + class experimental_type { + public: + explicit experimental_type(Alarm* alarm) : alarm_(alarm) {} + + /// Set an alarm to invoke callback \a f. The argument to the callback + /// states whether the alarm expired at \a deadline (true) or was cancelled + /// (false) + template + void Set(const T& deadline, std::function f) { + alarm_->SetInternal(::grpc::TimePoint(deadline).raw_time(), + std::move(f)); + } + + private: + Alarm* alarm_; + }; + + /// NOTE: The function experimental() is not stable public API. It is a view + /// to the experimental components of this class. It may be changed or removed + /// at any time. + experimental_type experimental() { return experimental_type(this); } + + private: + void SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline, + void* tag); + void SetInternal(gpr_timespec deadline, std::function f); + + ::grpc::internal::CompletionQueueTag* alarm_; +}; + +} // namespace grpc_impl + +#endif // GRPCPP_ALARM_IMPL_H diff --git a/src/cpp/common/alarm.cc b/src/cpp/common/alarm.cc index 5819a4210b..148f0b9bc9 100644 --- a/src/cpp/common/alarm.cc +++ b/src/cpp/common/alarm.cc @@ -31,10 +31,10 @@ #include #include "src/core/lib/debug/trace.h" -namespace grpc { +namespace grpc_impl { namespace internal { -class AlarmImpl : public CompletionQueueTag { +class AlarmImpl : public ::grpc::internal::CompletionQueueTag { public: AlarmImpl() : cq_(nullptr), tag_(nullptr) { gpr_ref_init(&refs_, 1); @@ -51,7 +51,7 @@ class AlarmImpl : public CompletionQueueTag { Unref(); return true; } - void Set(CompletionQueue* cq, gpr_timespec deadline, void* tag) { + void Set(::grpc::CompletionQueue* cq, gpr_timespec deadline, void* tag) { grpc_core::ExecCtx exec_ctx; GRPC_CQ_INTERNAL_REF(cq->cq(), "alarm"); cq_ = cq->cq(); @@ -114,13 +114,14 @@ class AlarmImpl : public CompletionQueueTag { }; } // namespace internal -static internal::GrpcLibraryInitializer g_gli_initializer; +static ::grpc::internal::GrpcLibraryInitializer g_gli_initializer; Alarm::Alarm() : alarm_(new internal::AlarmImpl()) { g_gli_initializer.summon(); } -void Alarm::SetInternal(CompletionQueue* cq, gpr_timespec deadline, void* tag) { +void Alarm::SetInternal(::grpc::CompletionQueue* cq, gpr_timespec deadline, + void* tag) { // Note that we know that alarm_ is actually an internal::AlarmImpl // but we declared it as the base pointer to avoid a forward declaration // or exposing core data structures in the C++ public headers. @@ -145,4 +146,4 @@ Alarm::~Alarm() { } void Alarm::Cancel() { static_cast(alarm_)->Cancel(); } -} // namespace grpc +} // namespace grpc_impl diff --git a/tools/doxygen/Doxyfile.c++ b/tools/doxygen/Doxyfile.c++ index 5c19711ee9..a5f817997d 100644 --- a/tools/doxygen/Doxyfile.c++ +++ b/tools/doxygen/Doxyfile.c++ @@ -925,6 +925,7 @@ include/grpc/support/thd_id.h \ include/grpc/support/time.h \ include/grpc/support/workaround_list.h \ include/grpcpp/alarm.h \ +include/grpcpp/alarm_impl.h \ include/grpcpp/channel.h \ include/grpcpp/client_context.h \ include/grpcpp/completion_queue.h \ diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 72b247c48d..7b10bdc699 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -926,6 +926,7 @@ include/grpc/support/thd_id.h \ include/grpc/support/time.h \ include/grpc/support/workaround_list.h \ include/grpcpp/alarm.h \ +include/grpcpp/alarm_impl.h \ include/grpcpp/channel.h \ include/grpcpp/client_context.h \ include/grpcpp/completion_queue.h \ diff --git a/tools/run_tests/generated/sources_and_headers.json b/tools/run_tests/generated/sources_and_headers.json index 4d9bbb0f09..a72e574f1e 100644 --- a/tools/run_tests/generated/sources_and_headers.json +++ b/tools/run_tests/generated/sources_and_headers.json @@ -11431,6 +11431,7 @@ "include/grpc++/support/sync_stream.h", "include/grpc++/support/time.h", "include/grpcpp/alarm.h", + "include/grpcpp/alarm_impl.h", "include/grpcpp/channel.h", "include/grpcpp/client_context.h", "include/grpcpp/completion_queue.h", @@ -11536,6 +11537,7 @@ "include/grpc++/support/sync_stream.h", "include/grpc++/support/time.h", "include/grpcpp/alarm.h", + "include/grpcpp/alarm_impl.h", "include/grpcpp/channel.h", "include/grpcpp/client_context.h", "include/grpcpp/completion_queue.h", -- cgit v1.2.3 From 9decf48632e2106a56515e67c4147e1a6506b47d Mon Sep 17 00:00:00 2001 From: Soheil Hassas Yeganeh Date: Thu, 6 Dec 2018 01:17:51 -0500 Subject: Move security credentials, connectors, and auth context to C++ This is to use `grpc_core::RefCount` to improve performnace. This commit also replaces explicit C vtables, with C++ vtable with its own compile time assertions and performance benefits. It also makes use of `RefCountedPtr` wherever possible. --- .../lb_policy/grpclb/grpclb_channel_secure.cc | 10 +- .../lb_policy/xds/xds_channel_secure.cc | 10 +- .../chttp2/client/secure/secure_channel_create.cc | 17 +- .../chttp2/server/secure/server_secure_chttp2.cc | 19 +- src/core/lib/gprpp/ref_counted_ptr.h | 8 +- src/core/lib/http/httpcli_security_connector.cc | 195 +++--- src/core/lib/http/parser.h | 10 +- src/core/lib/security/context/security_context.cc | 183 +++--- src/core/lib/security/context/security_context.h | 94 ++- .../security/credentials/alts/alts_credentials.cc | 84 ++- .../security/credentials/alts/alts_credentials.h | 47 +- .../credentials/composite/composite_credentials.cc | 297 ++++----- .../credentials/composite/composite_credentials.h | 111 +++- src/core/lib/security/credentials/credentials.cc | 160 +---- src/core/lib/security/credentials/credentials.h | 214 +++--- .../security/credentials/fake/fake_credentials.cc | 117 ++-- .../security/credentials/fake/fake_credentials.h | 28 +- .../google_default/google_default_credentials.cc | 83 ++- .../google_default/google_default_credentials.h | 33 +- .../security/credentials/iam/iam_credentials.cc | 62 +- .../lib/security/credentials/iam/iam_credentials.h | 22 +- .../security/credentials/jwt/jwt_credentials.cc | 129 ++-- .../lib/security/credentials/jwt/jwt_credentials.h | 39 +- .../credentials/local/local_credentials.cc | 51 +- .../security/credentials/local/local_credentials.h | 43 +- .../credentials/oauth2/oauth2_credentials.cc | 279 ++++---- .../credentials/oauth2/oauth2_credentials.h | 103 ++- .../credentials/plugin/plugin_credentials.cc | 136 ++-- .../credentials/plugin/plugin_credentials.h | 57 +- .../security/credentials/ssl/ssl_credentials.cc | 149 ++--- .../lib/security/credentials/ssl/ssl_credentials.h | 73 ++- .../alts/alts_security_connector.cc | 329 +++++----- .../alts/alts_security_connector.h | 22 +- .../fake/fake_security_connector.cc | 424 ++++++------ .../fake/fake_security_connector.h | 15 +- .../local/local_security_connector.cc | 278 ++++---- .../local/local_security_connector.h | 19 +- .../security_connector/security_connector.cc | 165 ++--- .../security_connector/security_connector.h | 206 +++--- .../ssl/ssl_security_connector.cc | 718 ++++++++++----------- .../ssl/ssl_security_connector.h | 26 +- .../lib/security/security_connector/ssl_utils.cc | 22 +- .../lib/security/security_connector/ssl_utils.h | 4 +- .../lib/security/transport/client_auth_filter.cc | 100 +-- .../lib/security/transport/security_handshaker.cc | 147 +++-- .../lib/security/transport/server_auth_filter.cc | 28 +- src/cpp/client/secure_credentials.cc | 6 +- src/cpp/client/secure_credentials.h | 9 +- src/cpp/common/secure_auth_context.cc | 38 +- src/cpp/common/secure_auth_context.h | 11 +- src/cpp/common/secure_create_auth_context.cc | 5 +- src/cpp/server/secure_server_credentials.cc | 2 +- test/core/security/alts_security_connector_test.cc | 41 +- test/core/security/auth_context_test.cc | 116 ++-- test/core/security/credentials_test.cc | 232 ++++--- test/core/security/oauth2_utils.cc | 5 +- .../security/print_google_default_creds_token.cc | 9 +- test/core/security/security_connector_test.cc | 95 +-- test/core/security/ssl_server_fuzzer.cc | 11 +- test/core/surface/secure_channel_create_test.cc | 2 +- test/cpp/common/auth_property_iterator_test.cc | 17 +- test/cpp/common/secure_auth_context_test.cc | 14 +- test/cpp/end2end/grpclb_end2end_test.cc | 4 +- 63 files changed, 2940 insertions(+), 3043 deletions(-) (limited to 'src/cpp') diff --git a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc index 6e8fbdcab7..657ff69312 100644 --- a/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.cc @@ -88,22 +88,18 @@ grpc_channel_args* grpc_lb_policy_grpclb_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc index 9a11f8e39f..55c646e6ee 100644 --- a/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc +++ b/src/core/ext/filters/client_channel/lb_policy/xds/xds_channel_secure.cc @@ -87,22 +87,18 @@ grpc_channel_args* grpc_lb_policy_xds_modify_lb_channel_args( // bearer token credentials. grpc_channel_credentials* channel_credentials = grpc_channel_credentials_find_in_args(args); - grpc_channel_credentials* creds_sans_call_creds = nullptr; + grpc_core::RefCountedPtr creds_sans_call_creds; if (channel_credentials != nullptr) { creds_sans_call_creds = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_credentials); + channel_credentials->duplicate_without_call_credentials(); GPR_ASSERT(creds_sans_call_creds != nullptr); args_to_remove[num_args_to_remove++] = GRPC_ARG_CHANNEL_CREDENTIALS; args_to_add[num_args_to_add++] = - grpc_channel_credentials_to_arg(creds_sans_call_creds); + grpc_channel_credentials_to_arg(creds_sans_call_creds.get()); } grpc_channel_args* result = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, num_args_to_remove, args_to_add, num_args_to_add); // Clean up. grpc_channel_args_destroy(args); - if (creds_sans_call_creds != nullptr) { - grpc_channel_credentials_unref(creds_sans_call_creds); - } return result; } diff --git a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc index e73eee4353..9612698e96 100644 --- a/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc +++ b/src/core/ext/transport/chttp2/client/secure/secure_channel_create.cc @@ -110,14 +110,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( grpc_channel_args* args_with_authority = grpc_channel_args_copy_and_add(args->args, args_to_add, num_args_to_add); grpc_uri_destroy(server_uri); - grpc_channel_security_connector* subchannel_security_connector = nullptr; // Create the security connector using the credentials and target name. grpc_channel_args* new_args_from_connector = nullptr; - const grpc_security_status security_status = - grpc_channel_credentials_create_security_connector( - channel_credentials, authority.get(), args_with_authority, - &subchannel_security_connector, &new_args_from_connector); - if (security_status != GRPC_SECURITY_OK) { + grpc_core::RefCountedPtr + subchannel_security_connector = + channel_credentials->create_security_connector( + /*call_creds=*/nullptr, authority.get(), args_with_authority, + &new_args_from_connector); + if (subchannel_security_connector == nullptr) { gpr_log(GPR_ERROR, "Failed to create secure subchannel for secure name '%s'", authority.get()); @@ -125,15 +125,14 @@ static grpc_subchannel_args* get_secure_naming_subchannel_args( return nullptr; } grpc_arg new_security_connector_arg = - grpc_security_connector_to_arg(&subchannel_security_connector->base); + grpc_security_connector_to_arg(subchannel_security_connector.get()); grpc_channel_args* new_args = grpc_channel_args_copy_and_add( new_args_from_connector != nullptr ? new_args_from_connector : args_with_authority, &new_security_connector_arg, 1); - GRPC_SECURITY_CONNECTOR_UNREF(&subchannel_security_connector->base, - "lb_channel_create"); + subchannel_security_connector.reset(DEBUG_LOCATION, "lb_channel_create"); if (new_args_from_connector != nullptr) { grpc_channel_args_destroy(new_args_from_connector); } diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc index 6689a17da6..98fdb62070 100644 --- a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc +++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.cc @@ -31,6 +31,7 @@ #include "src/core/ext/transport/chttp2/transport/chttp2_transport.h" #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/surface/api_trace.h" @@ -40,9 +41,8 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, grpc_server_credentials* creds) { grpc_core::ExecCtx exec_ctx; grpc_error* err = GRPC_ERROR_NONE; - grpc_server_security_connector* sc = nullptr; + grpc_core::RefCountedPtr sc; int port_num = 0; - grpc_security_status status; grpc_channel_args* args = nullptr; GRPC_API_TRACE( "grpc_server_add_secure_http2_port(" @@ -54,30 +54,27 @@ int grpc_server_add_secure_http2_port(grpc_server* server, const char* addr, "No credentials specified for secure server port (creds==NULL)"); goto done; } - status = grpc_server_credentials_create_security_connector(creds, &sc); - if (status != GRPC_SECURITY_OK) { + sc = creds->create_security_connector(); + if (sc == nullptr) { char* msg; gpr_asprintf(&msg, "Unable to create secure server with credentials of type %s.", - creds->type); - err = grpc_error_set_int(GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg), - GRPC_ERROR_INT_SECURITY_STATUS, status); + creds->type()); + err = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); gpr_free(msg); goto done; } // Create channel args. grpc_arg args_to_add[2]; args_to_add[0] = grpc_server_credentials_to_arg(creds); - args_to_add[1] = grpc_security_connector_to_arg(&sc->base); + args_to_add[1] = grpc_security_connector_to_arg(sc.get()); args = grpc_channel_args_copy_and_add(grpc_server_get_channel_args(server), args_to_add, GPR_ARRAY_SIZE(args_to_add)); // Add server port. err = grpc_chttp2_server_add_port(server, addr, args, &port_num); done: - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server"); - } + sc.reset(DEBUG_LOCATION, "server"); if (err != GRPC_ERROR_NONE) { const char* msg = grpc_error_string(err); diff --git a/src/core/lib/gprpp/ref_counted_ptr.h b/src/core/lib/gprpp/ref_counted_ptr.h index 1ed5d584c7..19f38d7f01 100644 --- a/src/core/lib/gprpp/ref_counted_ptr.h +++ b/src/core/lib/gprpp/ref_counted_ptr.h @@ -50,7 +50,7 @@ class RefCountedPtr { } template RefCountedPtr(RefCountedPtr&& other) { - value_ = other.value_; + value_ = static_cast(other.value_); other.value_ = nullptr; } @@ -77,7 +77,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor::value, "T does not have a virtual dtor"); if (other.value_ != nullptr) other.value_->IncrementRefCount(); - value_ = other.value_; + value_ = static_cast(other.value_); } // Copy assignment. @@ -118,7 +118,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor::value, "T does not have a virtual dtor"); if (value_ != nullptr) value_->Unref(); - value_ = value; + value_ = static_cast(value); } template void reset(const DebugLocation& location, const char* reason, @@ -126,7 +126,7 @@ class RefCountedPtr { static_assert(std::has_virtual_destructor::value, "T does not have a virtual dtor"); if (value_ != nullptr) value_->Unref(location, reason); - value_ = value; + value_ = static_cast(value); } // TODO(roth): This method exists solely as a transition mechanism to allow diff --git a/src/core/lib/http/httpcli_security_connector.cc b/src/core/lib/http/httpcli_security_connector.cc index 1c798d368b..6802851392 100644 --- a/src/core/lib/http/httpcli_security_connector.cc +++ b/src/core/lib/http/httpcli_security_connector.cc @@ -29,119 +29,125 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker_registry.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" +#include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/ssl_utils.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" #include "src/core/tsi/ssl_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* handshaker_factory; - char* secure_peer_name; -} grpc_httpcli_ssl_channel_security_connector; - -static void httpcli_ssl_destroy(grpc_security_connector* sc) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast(sc); - if (c->handshaker_factory != nullptr) { - tsi_ssl_client_handshaker_factory_unref(c->handshaker_factory); - c->handshaker_factory = nullptr; +class grpc_httpcli_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + explicit grpc_httpcli_ssl_channel_security_connector(char* secure_peer_name) + : grpc_channel_security_connector( + /*url_scheme=*/nullptr, + /*channel_creds=*/nullptr, + /*request_metadata_creds=*/nullptr), + secure_peer_name_(secure_peer_name) {} + + ~grpc_httpcli_ssl_channel_security_connector() override { + if (handshaker_factory_ != nullptr) { + tsi_ssl_client_handshaker_factory_unref(handshaker_factory_); + } + if (secure_peer_name_ != nullptr) { + gpr_free(secure_peer_name_); + } + } + + tsi_result InitHandshakerFactory(const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store) { + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + return tsi_create_ssl_client_handshaker_factory_with_options( + &options, &handshaker_factory_); } - if (c->secure_peer_name != nullptr) gpr_free(c->secure_peer_name); - gpr_free(sc); -} -static void httpcli_ssl_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast(sc); - tsi_handshaker* handshaker = nullptr; - if (c->handshaker_factory != nullptr) { - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->handshaker_factory, c->secure_peer_name, &handshaker); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + tsi_handshaker* handshaker = nullptr; + if (handshaker_factory_ != nullptr) { + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + handshaker_factory_, secure_peer_name_, &handshaker); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + } } + grpc_handshake_manager_add( + handshake_mgr, grpc_security_handshaker_create(handshaker, this)); } - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(handshaker, &sc->base)); -} -static void httpcli_ssl_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_httpcli_ssl_channel_security_connector* c = - reinterpret_cast(sc); - grpc_error* error = GRPC_ERROR_NONE; - - /* Check the peer name. */ - if (c->secure_peer_name != nullptr && - !tsi_ssl_peer_matches_name(&peer, c->secure_peer_name)) { - char* msg; - gpr_asprintf(&msg, "Peer name %s is not in peer certificate", - c->secure_peer_name); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); + tsi_ssl_client_handshaker_factory* handshaker_factory() const { + return handshaker_factory_; } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static int httpcli_ssl_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_httpcli_ssl_channel_security_connector* c1 = - reinterpret_cast(sc1); - grpc_httpcli_ssl_channel_security_connector* c2 = - reinterpret_cast(sc2); - return strcmp(c1->secure_peer_name, c2->secure_peer_name); -} + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* /*auth_context*/, + grpc_closure* on_peer_checked) override { + grpc_error* error = GRPC_ERROR_NONE; + + /* Check the peer name. */ + if (secure_peer_name_ != nullptr && + !tsi_ssl_peer_matches_name(&peer, secure_peer_name_)) { + char* msg; + gpr_asprintf(&msg, "Peer name %s is not in peer certificate", + secure_peer_name_); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); + } -static grpc_security_connector_vtable httpcli_ssl_vtable = { - httpcli_ssl_destroy, httpcli_ssl_check_peer, httpcli_ssl_cmp}; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast( + other_sc); + return strcmp(secure_peer_name_, other->secure_peer_name_); + } -static grpc_security_status httpcli_ssl_channel_security_connector_create( - const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, - const char* secure_peer_name, grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_httpcli_ssl_channel_security_connector* c; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + *error = GRPC_ERROR_NONE; + return true; + } - if (secure_peer_name != nullptr && pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, - "Cannot assert a secure peer name without a trust root."); - return GRPC_SECURITY_ERROR; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); } - c = static_cast( - gpr_zalloc(sizeof(grpc_httpcli_ssl_channel_security_connector))); + const char* secure_peer_name() const { return secure_peer_name_; } - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &httpcli_ssl_vtable; - if (secure_peer_name != nullptr) { - c->secure_peer_name = gpr_strdup(secure_peer_name); + private: + tsi_ssl_client_handshaker_factory* handshaker_factory_ = nullptr; + char* secure_peer_name_; +}; + +static grpc_core::RefCountedPtr +httpcli_ssl_channel_security_connector_create( + const char* pem_root_certs, const tsi_ssl_root_certs_store* root_store, + const char* secure_peer_name) { + if (secure_peer_name != nullptr && pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, + "Cannot assert a secure peer name without a trust root."); + return nullptr; } - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.pem_root_certs = pem_root_certs; - options.root_store = root_store; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->handshaker_factory); + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + secure_peer_name == nullptr ? nullptr : gpr_strdup(secure_peer_name)); + tsi_result result = c->InitHandshakerFactory(pem_root_certs, root_store); if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - httpcli_ssl_destroy(&c->base.base); - *sc = nullptr; - return GRPC_SECURITY_ERROR; + return nullptr; } - // We don't actually need a channel credentials object in this case, - // but we set it to a non-nullptr address so that we don't trigger - // assertions in grpc_channel_security_connector_cmp(). - c->base.channel_creds = (grpc_channel_credentials*)1; - c->base.add_handshakers = httpcli_ssl_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return c; } /* handshaker */ @@ -186,10 +192,11 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, } c->func = on_done; c->arg = arg; - grpc_channel_security_connector* sc = nullptr; - GPR_ASSERT(httpcli_ssl_channel_security_connector_create( - pem_root_certs, root_store, host, &sc) == GRPC_SECURITY_OK); - grpc_arg channel_arg = grpc_security_connector_to_arg(&sc->base); + grpc_core::RefCountedPtr sc = + httpcli_ssl_channel_security_connector_create(pem_root_certs, root_store, + host); + GPR_ASSERT(sc != nullptr); + grpc_arg channel_arg = grpc_security_connector_to_arg(sc.get()); grpc_channel_args args = {1, &channel_arg}; c->handshake_mgr = grpc_handshake_manager_create(); grpc_handshakers_add(HANDSHAKER_CLIENT, &args, @@ -197,7 +204,7 @@ static void ssl_handshake(void* arg, grpc_endpoint* tcp, const char* host, grpc_handshake_manager_do_handshake( c->handshake_mgr, tcp, nullptr /* channel_args */, deadline, nullptr /* acceptor */, on_handshake_done, c /* user_data */); - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "httpcli"); + sc.reset(DEBUG_LOCATION, "httpcli"); } const grpc_httpcli_handshaker grpc_httpcli_ssl = {"https", ssl_handshake}; diff --git a/src/core/lib/http/parser.h b/src/core/lib/http/parser.h index 1d2e13e831..a8f47c96c8 100644 --- a/src/core/lib/http/parser.h +++ b/src/core/lib/http/parser.h @@ -70,13 +70,13 @@ typedef struct grpc_http_request { /* A response */ typedef struct grpc_http_response { /* HTTP status code */ - int status; + int status = 0; /* Headers: count and key/values */ - size_t hdr_count; - grpc_http_header* hdrs; + size_t hdr_count = 0; + grpc_http_header* hdrs = nullptr; /* Body: length and contents; contents are NOT null-terminated */ - size_t body_length; - char* body; + size_t body_length = 0; + char* body = nullptr; } grpc_http_response; typedef struct { diff --git a/src/core/lib/security/context/security_context.cc b/src/core/lib/security/context/security_context.cc index 16f40b4f55..8443ee0695 100644 --- a/src/core/lib/security/context/security_context.cc +++ b/src/core/lib/security/context/security_context.cc @@ -23,6 +23,8 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/arena.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/surface/api_trace.h" #include "src/core/lib/surface/call.h" @@ -50,13 +52,11 @@ grpc_call_error grpc_call_set_credentials(grpc_call* call, ctx = static_cast( grpc_call_context_get(call, GRPC_CONTEXT_SECURITY)); if (ctx == nullptr) { - ctx = grpc_client_security_context_create(grpc_call_get_arena(call)); - ctx->creds = grpc_call_credentials_ref(creds); + ctx = grpc_client_security_context_create(grpc_call_get_arena(call), creds); grpc_call_context_set(call, GRPC_CONTEXT_SECURITY, ctx, grpc_client_security_context_destroy); } else { - grpc_call_credentials_unref(ctx->creds); - ctx->creds = grpc_call_credentials_ref(creds); + ctx->creds = creds != nullptr ? creds->Ref() : nullptr; } return GRPC_CALL_OK; @@ -66,33 +66,45 @@ grpc_auth_context* grpc_call_auth_context(grpc_call* call) { void* sec_ctx = grpc_call_context_get(call, GRPC_CONTEXT_SECURITY); GRPC_API_TRACE("grpc_call_auth_context(call=%p)", 1, (call)); if (sec_ctx == nullptr) return nullptr; - return grpc_call_is_client(call) - ? GRPC_AUTH_CONTEXT_REF( - ((grpc_client_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context client") - : GRPC_AUTH_CONTEXT_REF( - ((grpc_server_security_context*)sec_ctx)->auth_context, - "grpc_call_auth_context server"); + if (grpc_call_is_client(call)) { + auto* sc = static_cast(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context client") + .release(); + } + } else { + auto* sc = static_cast(sec_ctx); + if (sc->auth_context == nullptr) { + return nullptr; + } else { + return sc->auth_context + ->Ref(DEBUG_LOCATION, "grpc_call_auth_context server") + .release(); + } + } } void grpc_auth_context_release(grpc_auth_context* context) { GRPC_API_TRACE("grpc_auth_context_release(context=%p)", 1, (context)); - GRPC_AUTH_CONTEXT_UNREF(context, "grpc_auth_context_unref"); + if (context == nullptr) return; + context->Unref(DEBUG_LOCATION, "grpc_auth_context_unref"); } /* --- grpc_client_security_context --- */ grpc_client_security_context::~grpc_client_security_context() { - grpc_call_credentials_unref(creds); - GRPC_AUTH_CONTEXT_UNREF(auth_context, "client_security_context"); + auth_context.reset(DEBUG_LOCATION, "client_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } } grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena) { + gpr_arena* arena, grpc_call_credentials* creds) { return new (gpr_arena_alloc(arena, sizeof(grpc_client_security_context))) - grpc_client_security_context(); + grpc_client_security_context(creds != nullptr ? creds->Ref() : nullptr); } void grpc_client_security_context_destroy(void* ctx) { @@ -104,7 +116,7 @@ void grpc_client_security_context_destroy(void* ctx) { /* --- grpc_server_security_context --- */ grpc_server_security_context::~grpc_server_security_context() { - GRPC_AUTH_CONTEXT_UNREF(auth_context, "server_security_context"); + auth_context.reset(DEBUG_LOCATION, "server_security_context"); if (extension.instance != nullptr && extension.destroy != nullptr) { extension.destroy(extension.instance); } @@ -126,69 +138,11 @@ void grpc_server_security_context_destroy(void* ctx) { static grpc_auth_property_iterator empty_iterator = {nullptr, 0, nullptr}; -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained) { - grpc_auth_context* ctx = - static_cast(gpr_zalloc(sizeof(grpc_auth_context))); - gpr_ref_init(&ctx->refcount, 1); - if (chained != nullptr) { - ctx->chained = GRPC_AUTH_CONTEXT_REF(chained, "chained"); - ctx->peer_identity_property_name = - ctx->chained->peer_identity_property_name; - } - return ctx; -} - -#ifndef NDEBUG -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx, - const char* file, int line, - const char* reason) { - if (ctx == nullptr) return nullptr; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val + 1, reason); - } -#else -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* ctx) { - if (ctx == nullptr) return nullptr; -#endif - gpr_ref(&ctx->refcount); - return ctx; -} - -#ifndef NDEBUG -void grpc_auth_context_unref(grpc_auth_context* ctx, const char* file, int line, - const char* reason) { - if (ctx == nullptr) return; - if (grpc_trace_auth_context_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&ctx->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "AUTH_CONTEXT:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", ctx, val, - val - 1, reason); - } -#else -void grpc_auth_context_unref(grpc_auth_context* ctx) { - if (ctx == nullptr) return; -#endif - if (gpr_unref(&ctx->refcount)) { - size_t i; - GRPC_AUTH_CONTEXT_UNREF(ctx->chained, "chained"); - if (ctx->properties.array != nullptr) { - for (i = 0; i < ctx->properties.count; i++) { - grpc_auth_property_reset(&ctx->properties.array[i]); - } - gpr_free(ctx->properties.array); - } - gpr_free(ctx); - } -} - const char* grpc_auth_context_peer_identity_property_name( const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_identity_property_name(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name; + return ctx->peer_identity_property_name(); } int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, @@ -204,13 +158,13 @@ int grpc_auth_context_set_peer_identity_property_name(grpc_auth_context* ctx, name != nullptr ? name : "NULL"); return 0; } - ctx->peer_identity_property_name = prop->name; + ctx->set_peer_identity_property_name(prop->name); return 1; } int grpc_auth_context_peer_is_authenticated(const grpc_auth_context* ctx) { GRPC_API_TRACE("grpc_auth_context_peer_is_authenticated(ctx=%p)", 1, (ctx)); - return ctx->peer_identity_property_name == nullptr ? 0 : 1; + return ctx->is_authenticated(); } grpc_auth_property_iterator grpc_auth_context_property_iterator( @@ -226,16 +180,17 @@ const grpc_auth_property* grpc_auth_property_iterator_next( grpc_auth_property_iterator* it) { GRPC_API_TRACE("grpc_auth_property_iterator_next(it=%p)", 1, (it)); if (it == nullptr || it->ctx == nullptr) return nullptr; - while (it->index == it->ctx->properties.count) { - if (it->ctx->chained == nullptr) return nullptr; - it->ctx = it->ctx->chained; + while (it->index == it->ctx->properties().count) { + if (it->ctx->chained() == nullptr) return nullptr; + it->ctx = it->ctx->chained(); it->index = 0; } if (it->name == nullptr) { - return &it->ctx->properties.array[it->index++]; + return &it->ctx->properties().array[it->index++]; } else { - while (it->index < it->ctx->properties.count) { - const grpc_auth_property* prop = &it->ctx->properties.array[it->index++]; + while (it->index < it->ctx->properties().count) { + const grpc_auth_property* prop = + &it->ctx->properties().array[it->index++]; GPR_ASSERT(prop->name != nullptr); if (strcmp(it->name, prop->name) == 0) { return prop; @@ -262,49 +217,56 @@ grpc_auth_property_iterator grpc_auth_context_peer_identity( GRPC_API_TRACE("grpc_auth_context_peer_identity(ctx=%p)", 1, (ctx)); if (ctx == nullptr) return empty_iterator; return grpc_auth_context_find_properties_by_name( - ctx, ctx->peer_identity_property_name); + ctx, ctx->peer_identity_property_name()); } -static void ensure_auth_context_capacity(grpc_auth_context* ctx) { - if (ctx->properties.count == ctx->properties.capacity) { - ctx->properties.capacity = - GPR_MAX(ctx->properties.capacity + 8, ctx->properties.capacity * 2); - ctx->properties.array = static_cast( - gpr_realloc(ctx->properties.array, - ctx->properties.capacity * sizeof(grpc_auth_property))); +void grpc_auth_context::ensure_capacity() { + if (properties_.count == properties_.capacity) { + properties_.capacity = + GPR_MAX(properties_.capacity + 8, properties_.capacity * 2); + properties_.array = static_cast(gpr_realloc( + properties_.array, properties_.capacity * sizeof(grpc_auth_property))); } } +void grpc_auth_context::add_property(const char* name, const char* value, + size_t value_length) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; + prop->name = gpr_strdup(name); + prop->value = static_cast(gpr_malloc(value_length + 1)); + memcpy(prop->value, value, value_length); + prop->value[value_length] = '\0'; + prop->value_length = value_length; +} + void grpc_auth_context_add_property(grpc_auth_context* ctx, const char* name, const char* value, size_t value_length) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_property(ctx=%p, name=%s, value=%*.*s, " "value_length=%lu)", 6, (ctx, name, (int)value_length, (int)value_length, value, (unsigned long)value_length)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; + ctx->add_property(name, value, value_length); +} + +void grpc_auth_context::add_cstring_property(const char* name, + const char* value) { + ensure_capacity(); + grpc_auth_property* prop = &properties_.array[properties_.count++]; prop->name = gpr_strdup(name); - prop->value = static_cast(gpr_malloc(value_length + 1)); - memcpy(prop->value, value, value_length); - prop->value[value_length] = '\0'; - prop->value_length = value_length; + prop->value = gpr_strdup(value); + prop->value_length = strlen(value); } void grpc_auth_context_add_cstring_property(grpc_auth_context* ctx, const char* name, const char* value) { - grpc_auth_property* prop; GRPC_API_TRACE( "grpc_auth_context_add_cstring_property(ctx=%p, name=%s, value=%s)", 3, (ctx, name, value)); - ensure_auth_context_capacity(ctx); - prop = &ctx->properties.array[ctx->properties.count++]; - prop->name = gpr_strdup(name); - prop->value = gpr_strdup(value); - prop->value_length = strlen(value); + ctx->add_cstring_property(name, value); } void grpc_auth_property_reset(grpc_auth_property* property) { @@ -314,12 +276,17 @@ void grpc_auth_property_reset(grpc_auth_property* property) { } static void auth_context_pointer_arg_destroy(void* p) { - GRPC_AUTH_CONTEXT_UNREF((grpc_auth_context*)p, "auth_context_pointer_arg"); + if (p != nullptr) { + static_cast(p)->Unref(DEBUG_LOCATION, + "auth_context_pointer_arg"); + } } static void* auth_context_pointer_arg_copy(void* p) { - return GRPC_AUTH_CONTEXT_REF((grpc_auth_context*)p, - "auth_context_pointer_arg"); + auto* ctx = static_cast(p); + return ctx == nullptr + ? nullptr + : ctx->Ref(DEBUG_LOCATION, "auth_context_pointer_arg").release(); } static int auth_context_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } diff --git a/src/core/lib/security/context/security_context.h b/src/core/lib/security/context/security_context.h index e45415f63b..b43ee5e62d 100644 --- a/src/core/lib/security/context/security_context.h +++ b/src/core/lib/security/context/security_context.h @@ -21,6 +21,8 @@ #include +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/security/credentials/credentials.h" @@ -40,39 +42,59 @@ struct grpc_auth_property_array { size_t capacity = 0; }; -struct grpc_auth_context { - grpc_auth_context() { gpr_ref_init(&refcount, 0); } +void grpc_auth_property_reset(grpc_auth_property* property); - struct grpc_auth_context* chained = nullptr; - grpc_auth_property_array properties; - gpr_refcount refcount; - const char* peer_identity_property_name = nullptr; - grpc_pollset* pollset = nullptr; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_auth_context + : public grpc_core::RefCounted { + public: + explicit grpc_auth_context( + grpc_core::RefCountedPtr chained) + : grpc_core::RefCounted( + &grpc_trace_auth_context_refcount), + chained_(std::move(chained)) { + if (chained_ != nullptr) { + peer_identity_property_name_ = chained_->peer_identity_property_name_; + } + } + + ~grpc_auth_context() { + chained_.reset(DEBUG_LOCATION, "chained"); + if (properties_.array != nullptr) { + for (size_t i = 0; i < properties_.count; i++) { + grpc_auth_property_reset(&properties_.array[i]); + } + gpr_free(properties_.array); + } + } + + const grpc_auth_context* chained() const { return chained_.get(); } + const grpc_auth_property_array& properties() const { return properties_; } + + bool is_authenticated() const { + return peer_identity_property_name_ != nullptr; + } + const char* peer_identity_property_name() const { + return peer_identity_property_name_; + } + void set_peer_identity_property_name(const char* name) { + peer_identity_property_name_ = name; + } + + void ensure_capacity(); + void add_property(const char* name, const char* value, size_t value_length); + void add_cstring_property(const char* name, const char* value); + + private: + grpc_core::RefCountedPtr chained_; + grpc_auth_property_array properties_; + const char* peer_identity_property_name_ = nullptr; }; -/* Creation. */ -grpc_auth_context* grpc_auth_context_create(grpc_auth_context* chained); - -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_AUTH_CONTEXT_REF(p, r) \ - grpc_auth_context_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) \ - grpc_auth_context_unref((p), __FILE__, __LINE__, (r)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy, - const char* file, int line, - const char* reason); -void grpc_auth_context_unref(grpc_auth_context* policy, const char* file, - int line, const char* reason); -#else -#define GRPC_AUTH_CONTEXT_REF(p, r) grpc_auth_context_ref((p)) -#define GRPC_AUTH_CONTEXT_UNREF(p, r) grpc_auth_context_unref((p)) -grpc_auth_context* grpc_auth_context_ref(grpc_auth_context* policy); -void grpc_auth_context_unref(grpc_auth_context* policy); -#endif - -void grpc_auth_property_reset(grpc_auth_property* property); - /* --- grpc_security_context_extension --- Extension to the security context that may be set in a filter and accessed @@ -88,16 +110,18 @@ struct grpc_security_context_extension { Internal client-side security context. */ struct grpc_client_security_context { - grpc_client_security_context() = default; + explicit grpc_client_security_context( + grpc_core::RefCountedPtr creds) + : creds(std::move(creds)) {} ~grpc_client_security_context(); - grpc_call_credentials* creds = nullptr; - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr creds; + grpc_core::RefCountedPtr auth_context; grpc_security_context_extension extension; }; grpc_client_security_context* grpc_client_security_context_create( - gpr_arena* arena); + gpr_arena* arena, grpc_call_credentials* creds); void grpc_client_security_context_destroy(void* ctx); /* --- grpc_server_security_context --- @@ -108,7 +132,7 @@ struct grpc_server_security_context { grpc_server_security_context() = default; ~grpc_server_security_context(); - grpc_auth_context* auth_context = nullptr; + grpc_core::RefCountedPtr auth_context; grpc_security_context_extension extension; }; diff --git a/src/core/lib/security/credentials/alts/alts_credentials.cc b/src/core/lib/security/credentials/alts/alts_credentials.cc index 1fbef4ae0c..06546492bc 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.cc +++ b/src/core/lib/security/credentials/alts/alts_credentials.cc @@ -33,40 +33,47 @@ #define GRPC_CREDENTIALS_TYPE_ALTS "Alts" #define GRPC_ALTS_HANDSHAKER_SERVICE_URL "metadata.google.internal:8080" -static void alts_credentials_destruct(grpc_channel_credentials* creds) { - grpc_alts_credentials* alts_creds = - reinterpret_cast(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); -} - -static void alts_server_credentials_destruct(grpc_server_credentials* creds) { - grpc_alts_server_credentials* alts_creds = - reinterpret_cast(creds); - grpc_alts_credentials_options_destroy(alts_creds->options); - gpr_free(alts_creds->handshaker_service_url); +grpc_alts_credentials::grpc_alts_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_alts_credentials::~grpc_alts_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); } -static grpc_security_status alts_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr +grpc_alts_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_alts_channel_security_connector_create( - creds, request_metadata_creds, target_name, sc); + this->Ref(), std::move(call_creds), target_name); } -static grpc_security_status alts_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_alts_server_security_connector_create(creds, sc); +grpc_alts_server_credentials::grpc_alts_server_credentials( + const grpc_alts_credentials_options* options, + const char* handshaker_service_url) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_ALTS), + options_(grpc_alts_credentials_options_copy(options)), + handshaker_service_url_(handshaker_service_url == nullptr + ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) + : gpr_strdup(handshaker_service_url)) {} + +grpc_core::RefCountedPtr +grpc_alts_server_credentials::create_security_connector() { + return grpc_alts_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable alts_credentials_vtable = { - alts_credentials_destruct, alts_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable alts_server_credentials_vtable = { - alts_server_credentials_destruct, alts_server_create_security_connector}; +grpc_alts_server_credentials::~grpc_alts_server_credentials() { + grpc_alts_credentials_options_destroy(options_); + gpr_free(handshaker_service_url_); +} grpc_channel_credentials* grpc_alts_credentials_create_customized( const grpc_alts_credentials_options* options, @@ -74,17 +81,7 @@ grpc_channel_credentials* grpc_alts_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast( - gpr_zalloc(sizeof(grpc_alts_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New(options, handshaker_service_url); } grpc_server_credentials* grpc_alts_server_credentials_create_customized( @@ -93,17 +90,8 @@ grpc_server_credentials* grpc_alts_server_credentials_create_customized( if (!enable_untrusted_alts && !grpc_alts_is_running_on_gcp()) { return nullptr; } - auto creds = static_cast( - gpr_zalloc(sizeof(grpc_alts_server_credentials))); - creds->options = grpc_alts_credentials_options_copy(options); - creds->handshaker_service_url = - handshaker_service_url == nullptr - ? gpr_strdup(GRPC_ALTS_HANDSHAKER_SERVICE_URL) - : gpr_strdup(handshaker_service_url); - creds->base.type = GRPC_CREDENTIALS_TYPE_ALTS; - creds->base.vtable = &alts_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New(options, + handshaker_service_url); } grpc_channel_credentials* grpc_alts_credentials_create( diff --git a/src/core/lib/security/credentials/alts/alts_credentials.h b/src/core/lib/security/credentials/alts/alts_credentials.h index 810117f2be..cc6d5222b1 100644 --- a/src/core/lib/security/credentials/alts/alts_credentials.h +++ b/src/core/lib/security/credentials/alts/alts_credentials.h @@ -27,18 +27,45 @@ #include "src/core/lib/security/credentials/credentials.h" /* Main struct for grpc ALTS channel credential. */ -typedef struct grpc_alts_credentials { - grpc_channel_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_credentials; +class grpc_alts_credentials final : public grpc_channel_credentials { + public: + grpc_alts_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_credentials() override; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /* Main struct for grpc ALTS server credential. */ -typedef struct grpc_alts_server_credentials { - grpc_server_credentials base; - grpc_alts_credentials_options* options; - char* handshaker_service_url; -} grpc_alts_server_credentials; +class grpc_alts_server_credentials final : public grpc_server_credentials { + public: + grpc_alts_server_credentials(const grpc_alts_credentials_options* options, + const char* handshaker_service_url); + ~grpc_alts_server_credentials() override; + + grpc_core::RefCountedPtr + create_security_connector() override; + + const grpc_alts_credentials_options* options() const { return options_; } + grpc_alts_credentials_options* mutable_options() { return options_; } + const char* handshaker_service_url() const { return handshaker_service_url_; } + + private: + grpc_alts_credentials_options* options_; + char* handshaker_service_url_; +}; /** * This method creates an ALTS channel credential object with customized diff --git a/src/core/lib/security/credentials/composite/composite_credentials.cc b/src/core/lib/security/credentials/composite/composite_credentials.cc index b8f409260f..85dcd4693b 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.cc +++ b/src/core/lib/security/credentials/composite/composite_credentials.cc @@ -20,8 +20,10 @@ #include "src/core/lib/security/credentials/composite/composite_credentials.h" -#include +#include +#include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/polling_entity.h" #include "src/core/lib/surface/api_trace.h" @@ -31,36 +33,83 @@ /* -- Composite call credentials. -- */ -typedef struct { +static void composite_call_metadata_cb(void* arg, grpc_error* error); + +grpc_call_credentials_array::~grpc_call_credentials_array() { + for (size_t i = 0; i < num_creds_; ++i) { + creds_array_[i].~RefCountedPtr(); + } + if (creds_array_ != nullptr) { + gpr_free(creds_array_); + } +} + +grpc_call_credentials_array::grpc_call_credentials_array( + const grpc_call_credentials_array& that) + : num_creds_(that.num_creds_) { + reserve(that.capacity_); + for (size_t i = 0; i < num_creds_; ++i) { + new (&creds_array_[i]) + grpc_core::RefCountedPtr(that.creds_array_[i]); + } +} + +void grpc_call_credentials_array::reserve(size_t capacity) { + if (capacity_ >= capacity) { + return; + } + grpc_core::RefCountedPtr* new_arr = + static_cast*>(gpr_malloc( + sizeof(grpc_core::RefCountedPtr) * capacity)); + if (creds_array_ != nullptr) { + for (size_t i = 0; i < num_creds_; ++i) { + new (&new_arr[i]) grpc_core::RefCountedPtr( + std::move(creds_array_[i])); + creds_array_[i].~RefCountedPtr(); + } + gpr_free(creds_array_); + } + creds_array_ = new_arr; + capacity_ = capacity; +} + +namespace { +struct grpc_composite_call_credentials_metadata_context { + grpc_composite_call_credentials_metadata_context( + grpc_composite_call_credentials* composite_creds, + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata) + : composite_creds(composite_creds), + pollent(pollent), + auth_md_context(auth_md_context), + md_array(md_array), + on_request_metadata(on_request_metadata) { + GRPC_CLOSURE_INIT(&internal_on_request_metadata, composite_call_metadata_cb, + this, grpc_schedule_on_exec_ctx); + } + grpc_composite_call_credentials* composite_creds; - size_t creds_index; + size_t creds_index = 0; grpc_polling_entity* pollent; grpc_auth_metadata_context auth_md_context; grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_closure internal_on_request_metadata; -} grpc_composite_call_credentials_metadata_context; - -static void composite_call_destruct(grpc_call_credentials* creds) { - grpc_composite_call_credentials* c = - reinterpret_cast(creds); - for (size_t i = 0; i < c->inner.num_creds; i++) { - grpc_call_credentials_unref(c->inner.creds_array[i]); - } - gpr_free(c->inner.creds_array); -} +}; +} // namespace static void composite_call_metadata_cb(void* arg, grpc_error* error) { grpc_composite_call_credentials_metadata_context* ctx = static_cast(arg); if (error == GRPC_ERROR_NONE) { + const grpc_call_credentials_array& inner = ctx->composite_creds->inner(); /* See if we need to get some more metadata. */ - if (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, - &ctx->internal_on_request_metadata, &error)) { + if (ctx->creds_index < inner.size()) { + if (inner.get(ctx->creds_index++) + ->get_request_metadata( + ctx->pollent, ctx->auth_md_context, ctx->md_array, + &ctx->internal_on_request_metadata, &error)) { // Synchronous response, so call ourselves recursively. composite_call_metadata_cb(arg, error); GRPC_ERROR_UNREF(error); @@ -73,76 +122,86 @@ static void composite_call_metadata_cb(void* arg, grpc_error* error) { gpr_free(ctx); } -static bool composite_call_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context auth_md_context, +bool grpc_composite_call_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context auth_md_context, grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, grpc_error** error) { - grpc_composite_call_credentials* c = - reinterpret_cast(creds); grpc_composite_call_credentials_metadata_context* ctx; - ctx = static_cast( - gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context))); - ctx->composite_creds = c; - ctx->pollent = pollent; - ctx->auth_md_context = auth_md_context; - ctx->md_array = md_array; - ctx->on_request_metadata = on_request_metadata; - GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata, - composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx); + ctx = grpc_core::New( + this, pollent, auth_md_context, md_array, on_request_metadata); bool synchronous = true; - while (ctx->creds_index < ctx->composite_creds->inner.num_creds) { - grpc_call_credentials* inner_creds = - ctx->composite_creds->inner.creds_array[ctx->creds_index++]; - if (grpc_call_credentials_get_request_metadata( - inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array, - &ctx->internal_on_request_metadata, error)) { + const grpc_call_credentials_array& inner = ctx->composite_creds->inner(); + while (ctx->creds_index < inner.size()) { + if (inner.get(ctx->creds_index++) + ->get_request_metadata(ctx->pollent, ctx->auth_md_context, + ctx->md_array, + &ctx->internal_on_request_metadata, error)) { if (*error != GRPC_ERROR_NONE) break; } else { synchronous = false; // Async return. break; } } - if (synchronous) gpr_free(ctx); + if (synchronous) grpc_core::Delete(ctx); return synchronous; } -static void composite_call_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_composite_call_credentials* c = - reinterpret_cast(creds); - for (size_t i = 0; i < c->inner.num_creds; ++i) { - grpc_call_credentials_cancel_get_request_metadata( - c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error)); +void grpc_composite_call_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + for (size_t i = 0; i < inner_.size(); ++i) { + inner_.get(i)->cancel_get_request_metadata(md_array, GRPC_ERROR_REF(error)); } GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable composite_call_credentials_vtable = { - composite_call_destruct, composite_call_get_request_metadata, - composite_call_cancel_get_request_metadata}; +static size_t get_creds_array_size(const grpc_call_credentials* creds, + bool is_composite) { + return is_composite + ? static_cast(creds) + ->inner() + .size() + : 1; +} -static grpc_call_credentials_array get_creds_array( - grpc_call_credentials** creds_addr) { - grpc_call_credentials_array result; - grpc_call_credentials* creds = *creds_addr; - result.creds_array = creds_addr; - result.num_creds = 1; - if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - result = *grpc_composite_call_credentials_get_credentials(creds); +void grpc_composite_call_credentials::push_to_inner( + grpc_core::RefCountedPtr creds, bool is_composite) { + if (!is_composite) { + inner_.push_back(std::move(creds)); + return; } - return result; + auto composite_creds = + static_cast(creds.get()); + for (size_t i = 0; i < composite_creds->inner().size(); ++i) { + inner_.push_back(std::move(composite_creds->inner_.get_mutable(i))); + } +} + +grpc_composite_call_credentials::grpc_composite_call_credentials( + grpc_core::RefCountedPtr creds1, + grpc_core::RefCountedPtr creds2) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) { + const bool creds1_is_composite = + strcmp(creds1->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const bool creds2_is_composite = + strcmp(creds2->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0; + const size_t size = get_creds_array_size(creds1.get(), creds1_is_composite) + + get_creds_array_size(creds2.get(), creds2_is_composite); + inner_.reserve(size); + push_to_inner(std::move(creds1), creds1_is_composite); + push_to_inner(std::move(creds2), creds2_is_composite); +} + +static grpc_core::RefCountedPtr +composite_call_credentials_create( + grpc_core::RefCountedPtr creds1, + grpc_core::RefCountedPtr creds2) { + return grpc_core::MakeRefCounted( + std::move(creds1), std::move(creds2)); } grpc_call_credentials* grpc_composite_call_credentials_create( grpc_call_credentials* creds1, grpc_call_credentials* creds2, void* reserved) { - size_t i; - size_t creds_array_byte_size; - grpc_call_credentials_array creds1_array; - grpc_call_credentials_array creds2_array; - grpc_composite_call_credentials* c; GRPC_API_TRACE( "grpc_composite_call_credentials_create(creds1=%p, creds2=%p, " "reserved=%p)", @@ -150,120 +209,40 @@ grpc_call_credentials* grpc_composite_call_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(creds1 != nullptr); GPR_ASSERT(creds2 != nullptr); - c = static_cast( - gpr_zalloc(sizeof(grpc_composite_call_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE; - c->base.vtable = &composite_call_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - creds1_array = get_creds_array(&creds1); - creds2_array = get_creds_array(&creds2); - c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds; - creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*); - c->inner.creds_array = - static_cast(gpr_zalloc(creds_array_byte_size)); - for (i = 0; i < creds1_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds1_array.creds_array[i]; - c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds); - } - for (i = 0; i < creds2_array.num_creds; i++) { - grpc_call_credentials* cur_creds = creds2_array.creds_array[i]; - c->inner.creds_array[i + creds1_array.num_creds] = - grpc_call_credentials_ref(cur_creds); - } - return &c->base; -} -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) { - const grpc_composite_call_credentials* c = - reinterpret_cast(creds); - GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); - return &c->inner; -} - -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds) { - size_t i; - if (strcmp(creds->type, type) == 0) { - if (composite_creds != nullptr) *composite_creds = nullptr; - return creds; - } else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) { - const grpc_call_credentials_array* inner_creds_array = - grpc_composite_call_credentials_get_credentials(creds); - for (i = 0; i < inner_creds_array->num_creds; i++) { - if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) { - if (composite_creds != nullptr) *composite_creds = creds; - return inner_creds_array->creds_array[i]; - } - } - } - return nullptr; + return composite_call_credentials_create(creds1->Ref(), creds2->Ref()) + .release(); } /* -- Composite channel credentials. -- */ -static void composite_channel_destruct(grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast(creds); - grpc_channel_credentials_unref(c->inner_creds); - grpc_call_credentials_unref(c->call_creds); -} - -static grpc_security_status composite_channel_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr +grpc_composite_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_composite_channel_credentials* c = - reinterpret_cast(creds); - grpc_security_status status = GRPC_SECURITY_ERROR; - - GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr && - c->inner_creds->vtable != nullptr && - c->inner_creds->vtable->create_security_connector != nullptr); + grpc_channel_args** new_args) { + GPR_ASSERT(inner_creds_ != nullptr && call_creds_ != nullptr); /* If we are passed a call_creds, create a call composite to pass it downstream. */ if (call_creds != nullptr) { - grpc_call_credentials* composite_call_creds = - grpc_composite_call_credentials_create(c->call_creds, call_creds, - nullptr); - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, composite_call_creds, target, args, sc, new_args); - grpc_call_credentials_unref(composite_call_creds); + return inner_creds_->create_security_connector( + composite_call_credentials_create(call_creds_, std::move(call_creds)), + target, args, new_args); } else { - status = c->inner_creds->vtable->create_security_connector( - c->inner_creds, c->call_creds, target, args, sc, new_args); + return inner_creds_->create_security_connector(call_creds_, target, args, + new_args); } - return status; -} - -static grpc_channel_credentials* -composite_channel_duplicate_without_call_credentials( - grpc_channel_credentials* creds) { - grpc_composite_channel_credentials* c = - reinterpret_cast(creds); - return grpc_channel_credentials_ref(c->inner_creds); } -static grpc_channel_credentials_vtable composite_channel_credentials_vtable = { - composite_channel_destruct, composite_channel_create_security_connector, - composite_channel_duplicate_without_call_credentials}; - grpc_channel_credentials* grpc_composite_channel_credentials_create( grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds, void* reserved) { - grpc_composite_channel_credentials* c = - static_cast(gpr_zalloc(sizeof(*c))); GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr && reserved == nullptr); GRPC_API_TRACE( "grpc_composite_channel_credentials_create(channel_creds=%p, " "call_creds=%p, reserved=%p)", 3, (channel_creds, call_creds, reserved)); - c->base.type = channel_creds->type; - c->base.vtable = &composite_channel_credentials_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->inner_creds = grpc_channel_credentials_ref(channel_creds); - c->call_creds = grpc_call_credentials_ref(call_creds); - return &c->base; + return grpc_core::New( + channel_creds->Ref(), call_creds->Ref()); } diff --git a/src/core/lib/security/credentials/composite/composite_credentials.h b/src/core/lib/security/credentials/composite/composite_credentials.h index a952ad57f1..6b7fca1370 100644 --- a/src/core/lib/security/credentials/composite/composite_credentials.h +++ b/src/core/lib/security/credentials/composite/composite_credentials.h @@ -21,39 +21,104 @@ #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials** creds_array; - size_t num_creds; -} grpc_call_credentials_array; +// TODO(soheil): Replace this with InlinedVector once #16032 is resolved. +class grpc_call_credentials_array { + public: + grpc_call_credentials_array() = default; + grpc_call_credentials_array(const grpc_call_credentials_array& that); -const grpc_call_credentials_array* -grpc_composite_call_credentials_get_credentials( - grpc_call_credentials* composite_creds); + ~grpc_call_credentials_array(); -/* Returns creds if creds is of the specified type or the inner creds of the - specified type (if found), if the creds is of type COMPOSITE. - If composite_creds is not NULL, *composite_creds will point to creds if of - type COMPOSITE in case of success. */ -grpc_call_credentials* grpc_credentials_contains_type( - grpc_call_credentials* creds, const char* type, - grpc_call_credentials** composite_creds); + void reserve(size_t capacity); + + // Must reserve before pushing any data. + void push_back(grpc_core::RefCountedPtr cred) { + GPR_DEBUG_ASSERT(capacity_ > num_creds_); + new (&creds_array_[num_creds_++]) + grpc_core::RefCountedPtr(std::move(cred)); + } + + const grpc_core::RefCountedPtr& get(size_t i) const { + GPR_DEBUG_ASSERT(i < num_creds_); + return creds_array_[i]; + } + grpc_core::RefCountedPtr& get_mutable(size_t i) { + GPR_DEBUG_ASSERT(i < num_creds_); + return creds_array_[i]; + } + + size_t size() const { return num_creds_; } + + private: + grpc_core::RefCountedPtr* creds_array_ = nullptr; + size_t num_creds_ = 0; + size_t capacity_ = 0; +}; /* -- Composite channel credentials. -- */ -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* inner_creds; - grpc_call_credentials* call_creds; -} grpc_composite_channel_credentials; +class grpc_composite_channel_credentials : public grpc_channel_credentials { + public: + grpc_composite_channel_credentials( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr call_creds) + : grpc_channel_credentials(channel_creds->type()), + inner_creds_(std::move(channel_creds)), + call_creds_(std::move(call_creds)) {} + + ~grpc_composite_channel_credentials() override = default; + + grpc_core::RefCountedPtr + duplicate_without_call_credentials() override { + return inner_creds_; + } + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_channel_credentials* inner_creds() const { + return inner_creds_.get(); + } + const grpc_call_credentials* call_creds() const { return call_creds_.get(); } + grpc_call_credentials* mutable_call_creds() { return call_creds_.get(); } + + private: + grpc_core::RefCountedPtr inner_creds_; + grpc_core::RefCountedPtr call_creds_; +}; /* -- Composite call credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_call_credentials_array inner; -} grpc_composite_call_credentials; +class grpc_composite_call_credentials : public grpc_call_credentials { + public: + grpc_composite_call_credentials( + grpc_core::RefCountedPtr creds1, + grpc_core::RefCountedPtr creds2); + ~grpc_composite_call_credentials() override = default; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const grpc_call_credentials_array& inner() const { return inner_; } + + private: + void push_to_inner(grpc_core::RefCountedPtr creds, + bool is_composite); + + grpc_call_credentials_array inner_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_COMPOSITE_COMPOSITE_CREDENTIALS_H \ */ diff --git a/src/core/lib/security/credentials/credentials.cc b/src/core/lib/security/credentials/credentials.cc index c43cb440eb..90452d68d6 100644 --- a/src/core/lib/security/credentials/credentials.cc +++ b/src/core/lib/security/credentials/credentials.cc @@ -39,120 +39,24 @@ /* -- Common. -- */ -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds) { - grpc_credentials_metadata_request* r = - static_cast( - gpr_zalloc(sizeof(grpc_credentials_metadata_request))); - r->creds = grpc_call_credentials_ref(creds); - return r; -} - -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r) { - grpc_call_credentials_unref(r->creds); - grpc_http_response_destroy(&r->response); - gpr_free(r); -} - -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_channel_credentials_unref(grpc_channel_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } -} - void grpc_channel_credentials_release(grpc_channel_credentials* creds) { GRPC_API_TRACE("grpc_channel_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_channel_credentials_unref(creds); -} - -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_call_credentials_unref(grpc_call_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - gpr_free(creds); - } + if (creds) creds->Unref(); } void grpc_call_credentials_release(grpc_call_credentials* creds) { GRPC_API_TRACE("grpc_call_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials_unref(creds); -} - -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - if (creds == nullptr || creds->vtable->get_request_metadata == nullptr) { - return true; - } - return creds->vtable->get_request_metadata(creds, pollent, context, md_array, - on_request_metadata, error); -} - -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - if (creds == nullptr || - creds->vtable->cancel_get_request_metadata == nullptr) { - return; - } - creds->vtable->cancel_get_request_metadata(creds, md_array, error); -} - -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* channel_creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args) { - *new_args = nullptr; - if (channel_creds == nullptr) { - return GRPC_SECURITY_ERROR; - } - GPR_ASSERT(channel_creds->vtable->create_security_connector != nullptr); - return channel_creds->vtable->create_security_connector( - channel_creds, nullptr, target, args, sc, new_args); -} - -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* channel_creds) { - if (channel_creds != nullptr && channel_creds->vtable != nullptr && - channel_creds->vtable->duplicate_without_call_credentials != nullptr) { - return channel_creds->vtable->duplicate_without_call_credentials( - channel_creds); - } else { - return grpc_channel_credentials_ref(channel_creds); - } + if (creds) creds->Unref(); } static void credentials_pointer_arg_destroy(void* p) { - grpc_channel_credentials_unref(static_cast(p)); + static_cast(p)->Unref(); } static void* credentials_pointer_arg_copy(void* p) { - return grpc_channel_credentials_ref( - static_cast(p)); + return static_cast(p)->Ref().release(); } static int credentials_pointer_cmp(void* a, void* b) { return GPR_ICMP(a, b); } @@ -191,63 +95,35 @@ grpc_channel_credentials* grpc_channel_credentials_find_in_args( return nullptr; } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds) { - if (creds == nullptr) return nullptr; - gpr_ref(&creds->refcount); - return creds; -} - -void grpc_server_credentials_unref(grpc_server_credentials* creds) { - if (creds == nullptr) return; - if (gpr_unref(&creds->refcount)) { - if (creds->vtable->destruct != nullptr) { - creds->vtable->destruct(creds); - } - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - gpr_free(creds); - } -} - void grpc_server_credentials_release(grpc_server_credentials* creds) { GRPC_API_TRACE("grpc_server_credentials_release(creds=%p)", 1, (creds)); grpc_core::ExecCtx exec_ctx; - grpc_server_credentials_unref(creds); + if (creds) creds->Unref(); } -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - if (creds == nullptr || creds->vtable->create_security_connector == nullptr) { - gpr_log(GPR_ERROR, "Server credentials cannot create security context."); - return GRPC_SECURITY_ERROR; - } - return creds->vtable->create_security_connector(creds, sc); -} - -void grpc_server_credentials_set_auth_metadata_processor( - grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { +void grpc_server_credentials::set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor) { GRPC_API_TRACE( "grpc_server_credentials_set_auth_metadata_processor(" "creds=%p, " "processor=grpc_auth_metadata_processor { process: %p, state: %p })", - 3, (creds, (void*)(intptr_t)processor.process, processor.state)); - if (creds == nullptr) return; - if (creds->processor.destroy != nullptr && - creds->processor.state != nullptr) { - creds->processor.destroy(creds->processor.state); - } - creds->processor = processor; + 3, (this, (void*)(intptr_t)processor.process, processor.state)); + DestroyProcessor(); + processor_ = processor; +} + +void grpc_server_credentials_set_auth_metadata_processor( + grpc_server_credentials* creds, grpc_auth_metadata_processor processor) { + GPR_DEBUG_ASSERT(creds != nullptr); + creds->set_auth_metadata_processor(processor); } static void server_credentials_pointer_arg_destroy(void* p) { - grpc_server_credentials_unref(static_cast(p)); + static_cast(p)->Unref(); } static void* server_credentials_pointer_arg_copy(void* p) { - return grpc_server_credentials_ref(static_cast(p)); + return static_cast(p)->Ref().release(); } static int server_credentials_pointer_cmp(void* a, void* b) { diff --git a/src/core/lib/security/credentials/credentials.h b/src/core/lib/security/credentials/credentials.h index 3878958b38..4091ef3dfb 100644 --- a/src/core/lib/security/credentials/credentials.h +++ b/src/core/lib/security/credentials/credentials.h @@ -26,6 +26,7 @@ #include #include "src/core/lib/transport/metadata_batch.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/polling_entity.h" @@ -90,44 +91,46 @@ void grpc_override_well_known_credentials_path_getter( #define GRPC_ARG_CHANNEL_CREDENTIALS "grpc.channel_credentials" -typedef struct { - void (*destruct)(grpc_channel_credentials* c); - - grpc_security_status (*create_security_connector)( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_channel_credentials + : grpc_core::RefCounted { + public: + explicit grpc_channel_credentials(const char* type) : type_(type) {} + virtual ~grpc_channel_credentials() = default; + + // Creates a security connector for the channel. May also create new channel + // args for the channel to be used in place of the passed in const args if + // returned non NULL. In that case the caller is responsible for destroying + // new_args after channel creation. + virtual grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args); - - grpc_channel_credentials* (*duplicate_without_call_credentials)( - grpc_channel_credentials* c); -} grpc_channel_credentials_vtable; - -struct grpc_channel_credentials { - const grpc_channel_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; + grpc_channel_args** new_args) { + // Tell clang-tidy that call_creds cannot be passed as const-ref. + call_creds.reset(); + GRPC_ABSTRACT; + } + + // Creates a version of the channel credentials without any attached call + // credentials. This can be used in order to open a channel to a non-trusted + // gRPC load balancer. + virtual grpc_core::RefCountedPtr + duplicate_without_call_credentials() { + // By default we just increment the refcount. + return Ref(); + } + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_channel_credentials* grpc_channel_credentials_ref( - grpc_channel_credentials* creds); -void grpc_channel_credentials_unref(grpc_channel_credentials* creds); - -/* Creates a security connector for the channel. May also create new channel - args for the channel to be used in place of the passed in const args if - returned non NULL. In that case the caller is responsible for destroying - new_args after channel creation. */ -grpc_security_status grpc_channel_credentials_create_security_connector( - grpc_channel_credentials* creds, const char* target, - const grpc_channel_args* args, grpc_channel_security_connector** sc, - grpc_channel_args** new_args); - -/* Creates a version of the channel credentials without any attached call - credentials. This can be used in order to open a channel to a non-trusted - gRPC load balancer. */ -grpc_channel_credentials* -grpc_channel_credentials_duplicate_without_call_credentials( - grpc_channel_credentials* creds); - /* Util to encapsulate the channel credentials in a channel arg. */ grpc_arg grpc_channel_credentials_to_arg(grpc_channel_credentials* credentials); @@ -158,44 +161,39 @@ void grpc_credentials_mdelem_array_destroy(grpc_credentials_mdelem_array* list); /* --- grpc_call_credentials. --- */ -typedef struct { - void (*destruct)(grpc_call_credentials* c); - bool (*get_request_metadata)(grpc_call_credentials* c, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error); - void (*cancel_get_request_metadata)(grpc_call_credentials* c, - grpc_credentials_mdelem_array* md_array, - grpc_error* error); -} grpc_call_credentials_vtable; - -struct grpc_call_credentials { - const grpc_call_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_call_credentials + : public grpc_core::RefCounted { + public: + explicit grpc_call_credentials(const char* type) : type_(type) {} + virtual ~grpc_call_credentials() = default; + + // Returns true if completed synchronously, in which case \a error will + // be set to indicate the result. Otherwise, \a on_request_metadata will + // be invoked asynchronously when complete. \a md_array will be populated + // with the resulting metadata once complete. + virtual bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) GRPC_ABSTRACT; + + // Cancels a pending asynchronous operation started by + // grpc_call_credentials_get_request_metadata() with the corresponding + // value of \a md_array. + virtual void cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) GRPC_ABSTRACT; + + const char* type() const { return type_; } + + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* type_; }; -grpc_call_credentials* grpc_call_credentials_ref(grpc_call_credentials* creds); -void grpc_call_credentials_unref(grpc_call_credentials* creds); - -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_request_metadata will -/// be invoked asynchronously when complete. \a md_array will be populated -/// with the resulting metadata once complete. -bool grpc_call_credentials_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error); - -/// Cancels a pending asynchronous operation started by -/// grpc_call_credentials_get_request_metadata() with the corresponding -/// value of \a md_array. -void grpc_call_credentials_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error); - /* Metadata-only credentials with the specified key and value where asynchronicity can be simulated for testing. */ grpc_call_credentials* grpc_md_only_test_credentials_create( @@ -203,26 +201,40 @@ grpc_call_credentials* grpc_md_only_test_credentials_create( /* --- grpc_server_credentials. --- */ -typedef struct { - void (*destruct)(grpc_server_credentials* c); - grpc_security_status (*create_security_connector)( - grpc_server_credentials* c, grpc_server_security_connector** sc); -} grpc_server_credentials_vtable; - -struct grpc_server_credentials { - const grpc_server_credentials_vtable* vtable; - const char* type; - gpr_refcount refcount; - grpc_auth_metadata_processor processor; -}; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_server_credentials + : public grpc_core::RefCounted { + public: + explicit grpc_server_credentials(const char* type) : type_(type) {} -grpc_security_status grpc_server_credentials_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc); + virtual ~grpc_server_credentials() { DestroyProcessor(); } -grpc_server_credentials* grpc_server_credentials_ref( - grpc_server_credentials* creds); + virtual grpc_core::RefCountedPtr + create_security_connector() GRPC_ABSTRACT; -void grpc_server_credentials_unref(grpc_server_credentials* creds); + const char* type() const { return type_; } + + const grpc_auth_metadata_processor& auth_metadata_processor() const { + return processor_; + } + void set_auth_metadata_processor( + const grpc_auth_metadata_processor& processor); + + GRPC_ABSTRACT_BASE_CLASS + + private: + void DestroyProcessor() { + if (processor_.destroy != nullptr && processor_.state != nullptr) { + processor_.destroy(processor_.state); + } + } + + const char* type_; + grpc_auth_metadata_processor processor_ = + grpc_auth_metadata_processor(); // Zero-initialize the C struct. +}; #define GRPC_SERVER_CREDENTIALS_ARG "grpc.server_credentials" @@ -233,15 +245,27 @@ grpc_server_credentials* grpc_find_server_credentials_in_args( /* -- Credentials Metadata Request. -- */ -typedef struct { - grpc_call_credentials* creds; +struct grpc_credentials_metadata_request { + explicit grpc_credentials_metadata_request( + grpc_core::RefCountedPtr creds) + : creds(std::move(creds)) {} + ~grpc_credentials_metadata_request() { + grpc_http_response_destroy(&response); + } + + grpc_core::RefCountedPtr creds; grpc_http_response response; -} grpc_credentials_metadata_request; +}; -grpc_credentials_metadata_request* grpc_credentials_metadata_request_create( - grpc_call_credentials* creds); +inline grpc_credentials_metadata_request* +grpc_credentials_metadata_request_create( + grpc_core::RefCountedPtr creds) { + return grpc_core::New(std::move(creds)); +} -void grpc_credentials_metadata_request_destroy( - grpc_credentials_metadata_request* r); +inline void grpc_credentials_metadata_request_destroy( + grpc_credentials_metadata_request* r) { + grpc_core::Delete(r); +} #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/fake/fake_credentials.cc b/src/core/lib/security/credentials/fake/fake_credentials.cc index d3e0e8c816..337dd7679f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.cc +++ b/src/core/lib/security/credentials/fake/fake_credentials.cc @@ -33,49 +33,45 @@ /* -- Fake transport security credentials. -- */ -static grpc_security_status fake_transport_security_create_security_connector( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, - const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - *sc = - grpc_fake_channel_security_connector_create(c, call_creds, target, args); - return GRPC_SECURITY_OK; -} - -static grpc_security_status -fake_transport_security_server_create_security_connector( - grpc_server_credentials* c, grpc_server_security_connector** sc) { - *sc = grpc_fake_server_security_connector_create(c); - return GRPC_SECURITY_OK; -} +namespace { +class grpc_fake_channel_credentials final : public grpc_channel_credentials { + public: + grpc_fake_channel_credentials() + : grpc_channel_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_channel_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override { + return grpc_fake_channel_security_connector_create( + this->Ref(), std::move(call_creds), target, args); + } +}; + +class grpc_fake_server_credentials final : public grpc_server_credentials { + public: + grpc_fake_server_credentials() + : grpc_server_credentials( + GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY) {} + ~grpc_fake_server_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector() override { + return grpc_fake_server_security_connector_create(this->Ref()); + } +}; +} // namespace -static grpc_channel_credentials_vtable - fake_transport_security_credentials_vtable = { - nullptr, fake_transport_security_create_security_connector, nullptr}; - -static grpc_server_credentials_vtable - fake_transport_security_server_credentials_vtable = { - nullptr, fake_transport_security_server_create_security_connector}; - -grpc_channel_credentials* grpc_fake_transport_security_credentials_create( - void) { - grpc_channel_credentials* c = static_cast( - gpr_zalloc(sizeof(grpc_channel_credentials))); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - c->vtable = &fake_transport_security_credentials_vtable; - gpr_ref_init(&c->refcount, 1); - return c; +grpc_channel_credentials* grpc_fake_transport_security_credentials_create() { + return grpc_core::New(); } -grpc_server_credentials* grpc_fake_transport_security_server_credentials_create( - void) { - grpc_server_credentials* c = static_cast( - gpr_malloc(sizeof(grpc_server_credentials))); - memset(c, 0, sizeof(grpc_server_credentials)); - c->type = GRPC_CHANNEL_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; - gpr_ref_init(&c->refcount, 1); - c->vtable = &fake_transport_security_server_credentials_vtable; - return c; +grpc_server_credentials* +grpc_fake_transport_security_server_credentials_create() { + return grpc_core::New(); } grpc_arg grpc_fake_transport_expected_targets_arg(char* expected_targets) { @@ -92,46 +88,25 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only test credentials. -- */ -static void md_only_test_destruct(grpc_call_credentials* creds) { - grpc_md_only_test_credentials* c = - reinterpret_cast(creds); - GRPC_MDELEM_UNREF(c->md); -} - -static bool md_only_test_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_md_only_test_credentials* c = - reinterpret_cast(creds); - grpc_credentials_mdelem_array_add(md_array, c->md); - if (c->is_async) { +bool grpc_md_only_test_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, md_); + if (is_async_) { GRPC_CLOSURE_SCHED(on_request_metadata, GRPC_ERROR_NONE); return false; } return true; } -static void md_only_test_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_md_only_test_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable md_only_test_vtable = { - md_only_test_destruct, md_only_test_get_request_metadata, - md_only_test_cancel_get_request_metadata}; - grpc_call_credentials* grpc_md_only_test_credentials_create( const char* md_key, const char* md_value, bool is_async) { - grpc_md_only_test_credentials* c = - static_cast( - gpr_zalloc(sizeof(grpc_md_only_test_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &md_only_test_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->md = grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), - grpc_slice_from_copied_string(md_value)); - c->is_async = is_async; - return &c->base; + return grpc_core::New(md_key, md_value, + is_async); } diff --git a/src/core/lib/security/credentials/fake/fake_credentials.h b/src/core/lib/security/credentials/fake/fake_credentials.h index e89e6e24cc..b7f6a1909f 100644 --- a/src/core/lib/security/credentials/fake/fake_credentials.h +++ b/src/core/lib/security/credentials/fake/fake_credentials.h @@ -55,10 +55,28 @@ const char* grpc_fake_transport_get_expected_targets( /* -- Metadata-only Test credentials. -- */ -typedef struct { - grpc_call_credentials base; - grpc_mdelem md; - bool is_async; -} grpc_md_only_test_credentials; +class grpc_md_only_test_credentials : public grpc_call_credentials { + public: + grpc_md_only_test_credentials(const char* md_key, const char* md_value, + bool is_async) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + md_(grpc_mdelem_from_slices(grpc_slice_from_copied_string(md_key), + grpc_slice_from_copied_string(md_value))), + is_async_(is_async) {} + ~grpc_md_only_test_credentials() override { GRPC_MDELEM_UNREF(md_); } + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem md_; + bool is_async_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_FAKE_FAKE_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.cc b/src/core/lib/security/credentials/google_default/google_default_credentials.cc index 0674540d01..a86a17d586 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.cc +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/http/httpcli.h" #include "src/core/lib/http/parser.h" #include "src/core/lib/iomgr/load_file.h" @@ -72,20 +73,11 @@ typedef struct { grpc_http_response response; } metadata_server_detector; -static void google_default_credentials_destruct( - grpc_channel_credentials* creds) { - grpc_google_default_channel_credentials* c = - reinterpret_cast(creds); - grpc_channel_credentials_unref(c->alts_creds); - grpc_channel_credentials_unref(c->ssl_creds); -} - -static grpc_security_status google_default_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr +grpc_google_default_channel_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_google_default_channel_credentials* c = - reinterpret_cast(creds); + grpc_channel_args** new_args) { bool is_grpclb_load_balancer = grpc_channel_arg_get_bool( grpc_channel_args_find(args, GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER), false); @@ -95,22 +87,22 @@ static grpc_security_status google_default_create_security_connector( false); bool use_alts = is_grpclb_load_balancer || is_backend_from_grpclb_load_balancer; - grpc_security_status status = GRPC_SECURITY_ERROR; /* Return failure if ALTS is selected but not running on GCE. */ if (use_alts && !g_is_on_gce) { gpr_log(GPR_ERROR, "ALTS is selected, but not running on GCE."); - goto end; + return nullptr; } - status = use_alts ? c->alts_creds->vtable->create_security_connector( - c->alts_creds, call_creds, target, args, sc, new_args) - : c->ssl_creds->vtable->create_security_connector( - c->ssl_creds, call_creds, target, args, sc, new_args); -/* grpclb-specific channel args are removed from the channel args set - * to ensure backends and fallback adresses will have the same set of channel - * args. By doing that, it guarantees the connections to backends will not be - * torn down and re-connected when switching in and out of fallback mode. - */ -end: + + grpc_core::RefCountedPtr sc = + use_alts ? alts_creds_->create_security_connector(call_creds, target, + args, new_args) + : ssl_creds_->create_security_connector(call_creds, target, args, + new_args); + /* grpclb-specific channel args are removed from the channel args set + * to ensure backends and fallback adresses will have the same set of channel + * args. By doing that, it guarantees the connections to backends will not be + * torn down and re-connected when switching in and out of fallback mode. + */ if (use_alts) { static const char* args_to_remove[] = { GRPC_ARG_ADDRESS_IS_GRPCLB_LOAD_BALANCER, @@ -119,13 +111,9 @@ end: *new_args = grpc_channel_args_copy_and_add_and_remove( args, args_to_remove, GPR_ARRAY_SIZE(args_to_remove), nullptr, 0); } - return status; + return sc; } -static grpc_channel_credentials_vtable google_default_credentials_vtable = { - google_default_credentials_destruct, - google_default_create_security_connector, nullptr}; - static void on_metadata_server_detection_http_response(void* user_data, grpc_error* error) { metadata_server_detector* detector = @@ -215,11 +203,11 @@ static int is_metadata_server_reachable() { /* Takes ownership of creds_path if not NULL. */ static grpc_error* create_default_creds_from_path( - char* creds_path, grpc_call_credentials** creds) { + char* creds_path, grpc_core::RefCountedPtr* creds) { grpc_json* json = nullptr; grpc_auth_json_key key; grpc_auth_refresh_token token; - grpc_call_credentials* result = nullptr; + grpc_core::RefCountedPtr result; grpc_slice creds_data = grpc_empty_slice(); grpc_error* error = GRPC_ERROR_NONE; if (creds_path == nullptr) { @@ -276,9 +264,9 @@ end: return error; } -grpc_channel_credentials* grpc_google_default_credentials_create(void) { +grpc_channel_credentials* grpc_google_default_credentials_create() { grpc_channel_credentials* result = nullptr; - grpc_call_credentials* call_creds = nullptr; + grpc_core::RefCountedPtr call_creds; grpc_error* error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Failed to create Google credentials"); grpc_error* err; @@ -316,7 +304,8 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { gpr_mu_unlock(&g_state_mu); if (g_metadata_server_available) { - call_creds = grpc_google_compute_engine_credentials_create(nullptr); + call_creds = grpc_core::RefCountedPtr( + grpc_google_compute_engine_credentials_create(nullptr)); if (call_creds == nullptr) { error = grpc_error_add_child( error, GRPC_ERROR_CREATE_FROM_STATIC_STRING( @@ -327,23 +316,23 @@ grpc_channel_credentials* grpc_google_default_credentials_create(void) { end: if (call_creds != nullptr) { /* Create google default credentials. */ - auto creds = static_cast( - gpr_zalloc(sizeof(grpc_google_default_channel_credentials))); - creds->base.vtable = &google_default_credentials_vtable; - creds->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT; - gpr_ref_init(&creds->base.refcount, 1); - creds->ssl_creds = + grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create(nullptr, nullptr, nullptr, nullptr); - GPR_ASSERT(creds->ssl_creds != nullptr); + GPR_ASSERT(ssl_creds != nullptr); grpc_alts_credentials_options* options = grpc_alts_credentials_client_options_create(); - creds->alts_creds = grpc_alts_credentials_create(options); + grpc_channel_credentials* alts_creds = + grpc_alts_credentials_create(options); grpc_alts_credentials_options_destroy(options); - result = grpc_composite_channel_credentials_create(&creds->base, call_creds, - nullptr); + auto creds = + grpc_core::MakeRefCounted( + alts_creds != nullptr ? alts_creds->Ref() : nullptr, + ssl_creds != nullptr ? ssl_creds->Ref() : nullptr); + if (ssl_creds) ssl_creds->Unref(); + if (alts_creds) alts_creds->Unref(); + result = grpc_composite_channel_credentials_create( + creds.get(), call_creds.get(), nullptr); GPR_ASSERT(result != nullptr); - grpc_channel_credentials_unref(&creds->base); - grpc_call_credentials_unref(call_creds); } else { gpr_log(GPR_ERROR, "Could not create google default credentials: %s", grpc_error_string(error)); diff --git a/src/core/lib/security/credentials/google_default/google_default_credentials.h b/src/core/lib/security/credentials/google_default/google_default_credentials.h index b9e2efb04f..bf00f7285a 100644 --- a/src/core/lib/security/credentials/google_default/google_default_credentials.h +++ b/src/core/lib/security/credentials/google_default/google_default_credentials.h @@ -21,6 +21,7 @@ #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/credentials.h" #define GRPC_GOOGLE_CLOUD_SDK_CONFIG_DIRECTORY "gcloud" @@ -39,11 +40,33 @@ "/" GRPC_GOOGLE_WELL_KNOWN_CREDENTIALS_FILE #endif -typedef struct { - grpc_channel_credentials base; - grpc_channel_credentials* alts_creds; - grpc_channel_credentials* ssl_creds; -} grpc_google_default_channel_credentials; +class grpc_google_default_channel_credentials + : public grpc_channel_credentials { + public: + grpc_google_default_channel_credentials( + grpc_core::RefCountedPtr alts_creds, + grpc_core::RefCountedPtr ssl_creds) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_GOOGLE_DEFAULT), + alts_creds_(std::move(alts_creds)), + ssl_creds_(std::move(ssl_creds)) {} + + ~grpc_google_default_channel_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + const grpc_channel_credentials* alts_creds() const { + return alts_creds_.get(); + } + const grpc_channel_credentials* ssl_creds() const { return ssl_creds_.get(); } + + private: + grpc_core::RefCountedPtr alts_creds_; + grpc_core::RefCountedPtr ssl_creds_; +}; namespace grpc_core { namespace internal { diff --git a/src/core/lib/security/credentials/iam/iam_credentials.cc b/src/core/lib/security/credentials/iam/iam_credentials.cc index 5d92fa88c4..5cd561f676 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.cc +++ b/src/core/lib/security/credentials/iam/iam_credentials.cc @@ -22,6 +22,7 @@ #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include @@ -29,32 +30,37 @@ #include #include -static void iam_destruct(grpc_call_credentials* creds) { - grpc_google_iam_credentials* c = - reinterpret_cast(creds); - grpc_credentials_mdelem_array_destroy(&c->md_array); +grpc_google_iam_credentials::~grpc_google_iam_credentials() { + grpc_credentials_mdelem_array_destroy(&md_array_); } -static bool iam_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_google_iam_credentials* c = - reinterpret_cast(creds); - grpc_credentials_mdelem_array_append(md_array, &c->md_array); +bool grpc_google_iam_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_append(md_array, &md_array_); return true; } -static void iam_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_google_iam_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable iam_vtable = { - iam_destruct, iam_get_request_metadata, iam_cancel_get_request_metadata}; +grpc_google_iam_credentials::grpc_google_iam_credentials( + const char* token, const char* authority_selector) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_IAM) { + grpc_mdelem md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), + grpc_slice_from_copied_string(token)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); + md = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), + grpc_slice_from_copied_string(authority_selector)); + grpc_credentials_mdelem_array_add(&md_array_, md); + GRPC_MDELEM_UNREF(md); +} grpc_call_credentials* grpc_google_iam_credentials_create( const char* token, const char* authority_selector, void* reserved) { @@ -66,21 +72,7 @@ grpc_call_credentials* grpc_google_iam_credentials_create( GPR_ASSERT(reserved == nullptr); GPR_ASSERT(token != nullptr); GPR_ASSERT(authority_selector != nullptr); - grpc_google_iam_credentials* c = - static_cast(gpr_zalloc(sizeof(*c))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_IAM; - c->base.vtable = &iam_vtable; - gpr_ref_init(&c->base.refcount, 1); - grpc_mdelem md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORIZATION_TOKEN_METADATA_KEY), - grpc_slice_from_copied_string(token)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_IAM_AUTHORITY_SELECTOR_METADATA_KEY), - grpc_slice_from_copied_string(authority_selector)); - grpc_credentials_mdelem_array_add(&c->md_array, md); - GRPC_MDELEM_UNREF(md); - - return &c->base; + return grpc_core::MakeRefCounted( + token, authority_selector) + .release(); } diff --git a/src/core/lib/security/credentials/iam/iam_credentials.h b/src/core/lib/security/credentials/iam/iam_credentials.h index a45710fe0f..36f5ee8930 100644 --- a/src/core/lib/security/credentials/iam/iam_credentials.h +++ b/src/core/lib/security/credentials/iam/iam_credentials.h @@ -23,9 +23,23 @@ #include "src/core/lib/security/credentials/credentials.h" -typedef struct { - grpc_call_credentials base; - grpc_credentials_mdelem_array md_array; -} grpc_google_iam_credentials; +class grpc_google_iam_credentials : public grpc_call_credentials { + public: + grpc_google_iam_credentials(const char* token, + const char* authority_selector); + ~grpc_google_iam_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_credentials_mdelem_array md_array_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_IAM_IAM_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.cc b/src/core/lib/security/credentials/jwt/jwt_credentials.cc index 05c08a68b0..f2591a1ea5 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.cc +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.cc @@ -23,6 +23,8 @@ #include #include +#include "src/core/lib/gprpp/ref_counted.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/surface/api_trace.h" #include @@ -30,71 +32,66 @@ #include #include -static void jwt_reset_cache(grpc_service_account_jwt_access_credentials* c) { - GRPC_MDELEM_UNREF(c->cached.jwt_md); - c->cached.jwt_md = GRPC_MDNULL; - if (c->cached.service_url != nullptr) { - gpr_free(c->cached.service_url); - c->cached.service_url = nullptr; +void grpc_service_account_jwt_access_credentials::reset_cache() { + GRPC_MDELEM_UNREF(cached_.jwt_md); + cached_.jwt_md = GRPC_MDNULL; + if (cached_.service_url != nullptr) { + gpr_free(cached_.service_url); + cached_.service_url = nullptr; } - c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); + cached_.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); } -static void jwt_destruct(grpc_call_credentials* creds) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast(creds); - grpc_auth_json_key_destruct(&c->key); - jwt_reset_cache(c); - gpr_mu_destroy(&c->cache_mu); +grpc_service_account_jwt_access_credentials:: + ~grpc_service_account_jwt_access_credentials() { + grpc_auth_json_key_destruct(&key_); + reset_cache(); + gpr_mu_destroy(&cache_mu_); } -static bool jwt_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_service_account_jwt_access_credentials* c = - reinterpret_cast(creds); +bool grpc_service_account_jwt_access_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { gpr_timespec refresh_threshold = gpr_time_from_seconds( GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN); /* See if we can return a cached jwt. */ grpc_mdelem jwt_md = GRPC_MDNULL; { - gpr_mu_lock(&c->cache_mu); - if (c->cached.service_url != nullptr && - strcmp(c->cached.service_url, context.service_url) == 0 && - !GRPC_MDISNULL(c->cached.jwt_md) && - (gpr_time_cmp(gpr_time_sub(c->cached.jwt_expiration, - gpr_now(GPR_CLOCK_REALTIME)), - refresh_threshold) > 0)) { - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + gpr_mu_lock(&cache_mu_); + if (cached_.service_url != nullptr && + strcmp(cached_.service_url, context.service_url) == 0 && + !GRPC_MDISNULL(cached_.jwt_md) && + (gpr_time_cmp( + gpr_time_sub(cached_.jwt_expiration, gpr_now(GPR_CLOCK_REALTIME)), + refresh_threshold) > 0)) { + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (GRPC_MDISNULL(jwt_md)) { char* jwt = nullptr; /* Generate a new jwt. */ - gpr_mu_lock(&c->cache_mu); - jwt_reset_cache(c); - jwt = grpc_jwt_encode_and_sign(&c->key, context.service_url, - c->jwt_lifetime, nullptr); + gpr_mu_lock(&cache_mu_); + reset_cache(); + jwt = grpc_jwt_encode_and_sign(&key_, context.service_url, jwt_lifetime_, + nullptr); if (jwt != nullptr) { char* md_value; gpr_asprintf(&md_value, "Bearer %s", jwt); gpr_free(jwt); - c->cached.jwt_expiration = - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), c->jwt_lifetime); - c->cached.service_url = gpr_strdup(context.service_url); - c->cached.jwt_md = grpc_mdelem_from_slices( + cached_.jwt_expiration = + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), jwt_lifetime_); + cached_.service_url = gpr_strdup(context.service_url); + cached_.jwt_md = grpc_mdelem_from_slices( grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), grpc_slice_from_copied_string(md_value)); gpr_free(md_value); - jwt_md = GRPC_MDELEM_REF(c->cached.jwt_md); + jwt_md = GRPC_MDELEM_REF(cached_.jwt_md); } - gpr_mu_unlock(&c->cache_mu); + gpr_mu_unlock(&cache_mu_); } if (!GRPC_MDISNULL(jwt_md)) { @@ -106,29 +103,15 @@ static bool jwt_get_request_metadata(grpc_call_credentials* creds, return true; } -static void jwt_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_service_account_jwt_access_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable jwt_vtable = { - jwt_destruct, jwt_get_request_metadata, jwt_cancel_get_request_metadata}; - -grpc_call_credentials* -grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key key, gpr_timespec token_lifetime) { - grpc_service_account_jwt_access_credentials* c; - if (!grpc_auth_json_key_is_valid(&key)) { - gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); - return nullptr; - } - c = static_cast( - gpr_zalloc(sizeof(grpc_service_account_jwt_access_credentials))); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_JWT; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &jwt_vtable; - c->key = key; +grpc_service_account_jwt_access_credentials:: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_JWT), key_(key) { gpr_timespec max_token_lifetime = grpc_max_auth_token_lifetime(); if (gpr_time_cmp(token_lifetime, max_token_lifetime) > 0) { gpr_log(GPR_INFO, @@ -136,10 +119,20 @@ grpc_service_account_jwt_access_credentials_create_from_auth_json_key( static_cast(max_token_lifetime.tv_sec)); token_lifetime = grpc_max_auth_token_lifetime(); } - c->jwt_lifetime = token_lifetime; - gpr_mu_init(&c->cache_mu); - jwt_reset_cache(c); - return &c->base; + jwt_lifetime_ = token_lifetime; + gpr_mu_init(&cache_mu_); + reset_cache(); +} + +grpc_core::RefCountedPtr +grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key key, gpr_timespec token_lifetime) { + if (!grpc_auth_json_key_is_valid(&key)) { + gpr_log(GPR_ERROR, "Invalid input for jwt credentials creation"); + return nullptr; + } + return grpc_core::MakeRefCounted( + key, token_lifetime); } static char* redact_private_key(const char* json_key) { @@ -182,9 +175,7 @@ grpc_call_credentials* grpc_service_account_jwt_access_credentials_create( } GPR_ASSERT(reserved == nullptr); grpc_core::ExecCtx exec_ctx; - grpc_call_credentials* creds = - grpc_service_account_jwt_access_credentials_create_from_auth_json_key( - grpc_auth_json_key_create_from_string(json_key), token_lifetime); - - return creds; + return grpc_service_account_jwt_access_credentials_create_from_auth_json_key( + grpc_auth_json_key_create_from_string(json_key), token_lifetime) + .release(); } diff --git a/src/core/lib/security/credentials/jwt/jwt_credentials.h b/src/core/lib/security/credentials/jwt/jwt_credentials.h index 5c3d34aa56..5af909f44d 100644 --- a/src/core/lib/security/credentials/jwt/jwt_credentials.h +++ b/src/core/lib/security/credentials/jwt/jwt_credentials.h @@ -24,25 +24,44 @@ #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/jwt/json_token.h" -typedef struct { - grpc_call_credentials base; +class grpc_service_account_jwt_access_credentials + : public grpc_call_credentials { + public: + grpc_service_account_jwt_access_credentials(grpc_auth_json_key key, + gpr_timespec token_lifetime); + ~grpc_service_account_jwt_access_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + const gpr_timespec& jwt_lifetime() const { return jwt_lifetime_; } + const grpc_auth_json_key& key() const { return key_; } + + private: + void reset_cache(); // Have a simple cache for now with just 1 entry. We could have a map based on // the service_url for a more sophisticated one. - gpr_mu cache_mu; + gpr_mu cache_mu_; struct { - grpc_mdelem jwt_md; - char* service_url; + grpc_mdelem jwt_md = GRPC_MDNULL; + char* service_url = nullptr; gpr_timespec jwt_expiration; - } cached; + } cached_; - grpc_auth_json_key key; - gpr_timespec jwt_lifetime; -} grpc_service_account_jwt_access_credentials; + grpc_auth_json_key key_; + gpr_timespec jwt_lifetime_; +}; // Private constructor for jwt credentials from an already parsed json key. // Takes ownership of the key. -grpc_call_credentials* +grpc_core::RefCountedPtr grpc_service_account_jwt_access_credentials_create_from_auth_json_key( grpc_auth_json_key key, gpr_timespec token_lifetime); diff --git a/src/core/lib/security/credentials/local/local_credentials.cc b/src/core/lib/security/credentials/local/local_credentials.cc index 3ccfa2b908..6f6f95a34a 100644 --- a/src/core/lib/security/credentials/local/local_credentials.cc +++ b/src/core/lib/security/credentials/local/local_credentials.cc @@ -29,49 +29,36 @@ #define GRPC_CREDENTIALS_TYPE_LOCAL "Local" -static void local_credentials_destruct(grpc_channel_credentials* creds) {} - -static void local_server_credentials_destruct(grpc_server_credentials* creds) {} - -static grpc_security_status local_create_security_connector( - grpc_channel_credentials* creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - const grpc_channel_args* args, grpc_channel_security_connector** sc, +grpc_core::RefCountedPtr +grpc_local_credentials::create_security_connector( + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name, const grpc_channel_args* args, grpc_channel_args** new_args) { return grpc_local_channel_security_connector_create( - creds, request_metadata_creds, args, target_name, sc); + this->Ref(), std::move(request_metadata_creds), args, target_name); } -static grpc_security_status local_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_local_server_security_connector_create(creds, sc); +grpc_core::RefCountedPtr +grpc_local_server_credentials::create_security_connector() { + return grpc_local_server_security_connector_create(this->Ref()); } -static const grpc_channel_credentials_vtable local_credentials_vtable = { - local_credentials_destruct, local_create_security_connector, - /*duplicate_without_call_credentials=*/nullptr}; - -static const grpc_server_credentials_vtable local_server_credentials_vtable = { - local_server_credentials_destruct, local_server_create_security_connector}; +grpc_local_credentials::grpc_local_credentials( + grpc_local_connect_type connect_type) + : grpc_channel_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} grpc_channel_credentials* grpc_local_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast( - gpr_zalloc(sizeof(grpc_local_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New(connect_type); } +grpc_local_server_credentials::grpc_local_server_credentials( + grpc_local_connect_type connect_type) + : grpc_server_credentials(GRPC_CREDENTIALS_TYPE_LOCAL), + connect_type_(connect_type) {} + grpc_server_credentials* grpc_local_server_credentials_create( grpc_local_connect_type connect_type) { - auto creds = static_cast( - gpr_zalloc(sizeof(grpc_local_server_credentials))); - creds->connect_type = connect_type; - creds->base.type = GRPC_CREDENTIALS_TYPE_LOCAL; - creds->base.vtable = &local_server_credentials_vtable; - gpr_ref_init(&creds->base.refcount, 1); - return &creds->base; + return grpc_core::New(connect_type); } diff --git a/src/core/lib/security/credentials/local/local_credentials.h b/src/core/lib/security/credentials/local/local_credentials.h index 47358b04bc..60a8a4f64c 100644 --- a/src/core/lib/security/credentials/local/local_credentials.h +++ b/src/core/lib/security/credentials/local/local_credentials.h @@ -25,16 +25,37 @@ #include "src/core/lib/security/credentials/credentials.h" -/* Main struct for grpc local channel credential. */ -typedef struct grpc_local_credentials { - grpc_channel_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_credentials; - -/* Main struct for grpc local server credential. */ -typedef struct grpc_local_server_credentials { - grpc_server_credentials base; - grpc_local_connect_type connect_type; -} grpc_local_server_credentials; +/* Main class for grpc local channel credential. */ +class grpc_local_credentials final : public grpc_channel_credentials { + public: + explicit grpc_local_credentials(grpc_local_connect_type connect_type); + ~grpc_local_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; + +/* Main class for grpc local server credential. */ +class grpc_local_server_credentials final : public grpc_server_credentials { + public: + explicit grpc_local_server_credentials(grpc_local_connect_type connect_type); + ~grpc_local_server_credentials() override = default; + + grpc_core::RefCountedPtr + create_security_connector() override; + + grpc_local_connect_type connect_type() const { return connect_type_; } + + private: + grpc_local_connect_type connect_type_; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_LOCAL_LOCAL_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc index 44b093557f..ad63b01e75 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.cc @@ -22,6 +22,7 @@ #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/util/json_util.h" #include "src/core/lib/surface/api_trace.h" @@ -105,13 +106,12 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token) { // Oauth2 Token Fetcher credentials. // -static void oauth2_token_fetcher_destruct(grpc_call_credentials* creds) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast(creds); - GRPC_MDELEM_UNREF(c->access_token_md); - gpr_mu_destroy(&c->mu); - grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&c->pollent)); - grpc_httpcli_context_destroy(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials:: + ~grpc_oauth2_token_fetcher_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); + gpr_mu_destroy(&mu_); + grpc_pollset_set_destroy(grpc_polling_entity_pollset_set(&pollent_)); + grpc_httpcli_context_destroy(&httpcli_context_); } grpc_credentials_status @@ -209,25 +209,29 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, grpc_credentials_metadata_request* r = static_cast(user_data); grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast(r->creds); + reinterpret_cast(r->creds.get()); + c->on_http_response(r, error); +} + +void grpc_oauth2_token_fetcher_credentials::on_http_response( + grpc_credentials_metadata_request* r, grpc_error* error) { grpc_mdelem access_token_md = GRPC_MDNULL; grpc_millis token_lifetime; grpc_credentials_status status = grpc_oauth2_token_fetcher_credentials_parse_server_response( &r->response, &access_token_md, &token_lifetime); // Update cache and grab list of pending requests. - gpr_mu_lock(&c->mu); - c->token_fetch_pending = false; - c->access_token_md = GRPC_MDELEM_REF(access_token_md); - c->token_expiration = + gpr_mu_lock(&mu_); + token_fetch_pending_ = false; + access_token_md_ = GRPC_MDELEM_REF(access_token_md); + token_expiration_ = status == GRPC_CREDENTIALS_OK ? gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_millis(token_lifetime, GPR_TIMESPAN)) : gpr_inf_past(GPR_CLOCK_MONOTONIC); - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; - c->pending_requests = nullptr; - gpr_mu_unlock(&c->mu); + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; + pending_requests_ = nullptr; + gpr_mu_unlock(&mu_); // Invoke callbacks for all pending requests. while (pending_request != nullptr) { if (status == GRPC_CREDENTIALS_OK) { @@ -239,42 +243,40 @@ static void on_oauth2_token_fetcher_http_response(void* user_data, } GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, error); grpc_polling_entity_del_from_pollset_set( - pending_request->pollent, grpc_polling_entity_pollset_set(&c->pollent)); + pending_request->pollent, grpc_polling_entity_pollset_set(&pollent_)); grpc_oauth2_pending_get_request_metadata* prev = pending_request; pending_request = pending_request->next; gpr_free(prev); } GRPC_MDELEM_UNREF(access_token_md); - grpc_call_credentials_unref(r->creds); + Unref(); grpc_credentials_metadata_request_destroy(r); } -static bool oauth2_token_fetcher_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast(creds); +bool grpc_oauth2_token_fetcher_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { // Check if we can use the cached token. grpc_millis refresh_threshold = GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS * GPR_MS_PER_SEC; grpc_mdelem cached_access_token_md = GRPC_MDNULL; - gpr_mu_lock(&c->mu); - if (!GRPC_MDISNULL(c->access_token_md) && + gpr_mu_lock(&mu_); + if (!GRPC_MDISNULL(access_token_md_) && gpr_time_cmp( - gpr_time_sub(c->token_expiration, gpr_now(GPR_CLOCK_MONOTONIC)), + gpr_time_sub(token_expiration_, gpr_now(GPR_CLOCK_MONOTONIC)), gpr_time_from_seconds(GRPC_SECURE_TOKEN_REFRESH_THRESHOLD_SECS, GPR_TIMESPAN)) > 0) { - cached_access_token_md = GRPC_MDELEM_REF(c->access_token_md); + cached_access_token_md = GRPC_MDELEM_REF(access_token_md_); } if (!GRPC_MDISNULL(cached_access_token_md)) { - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); grpc_credentials_mdelem_array_add(md_array, cached_access_token_md); GRPC_MDELEM_UNREF(cached_access_token_md); return true; } // Couldn't get the token from the cache. - // Add request to c->pending_requests and start a new fetch if needed. + // Add request to pending_requests_ and start a new fetch if needed. grpc_oauth2_pending_get_request_metadata* pending_request = static_cast( gpr_malloc(sizeof(*pending_request))); @@ -282,41 +284,37 @@ static bool oauth2_token_fetcher_get_request_metadata( pending_request->on_request_metadata = on_request_metadata; pending_request->pollent = pollent; grpc_polling_entity_add_to_pollset_set( - pollent, grpc_polling_entity_pollset_set(&c->pollent)); - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; + pollent, grpc_polling_entity_pollset_set(&pollent_)); + pending_request->next = pending_requests_; + pending_requests_ = pending_request; bool start_fetch = false; - if (!c->token_fetch_pending) { - c->token_fetch_pending = true; + if (!token_fetch_pending_) { + token_fetch_pending_ = true; start_fetch = true; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); if (start_fetch) { - grpc_call_credentials_ref(creds); - c->fetch_func(grpc_credentials_metadata_request_create(creds), - &c->httpcli_context, &c->pollent, - on_oauth2_token_fetcher_http_response, - grpc_core::ExecCtx::Get()->Now() + refresh_threshold); + Ref().release(); + fetch_oauth2(grpc_credentials_metadata_request_create(this->Ref()), + &httpcli_context_, &pollent_, + on_oauth2_token_fetcher_http_response, + grpc_core::ExecCtx::Get()->Now() + refresh_threshold); } return false; } -static void oauth2_token_fetcher_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_oauth2_token_fetcher_credentials* c = - reinterpret_cast(creds); - gpr_mu_lock(&c->mu); +void grpc_oauth2_token_fetcher_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); grpc_oauth2_pending_get_request_metadata* prev = nullptr; - grpc_oauth2_pending_get_request_metadata* pending_request = - c->pending_requests; + grpc_oauth2_pending_get_request_metadata* pending_request = pending_requests_; while (pending_request != nullptr) { if (pending_request->md_array == md_array) { // Remove matching pending request from the list. if (prev != nullptr) { prev->next = pending_request->next; } else { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } // Invoke the callback immediately with an error. GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, @@ -327,96 +325,89 @@ static void oauth2_token_fetcher_cancel_get_request_metadata( prev = pending_request; pending_request = pending_request->next; } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials* c, - grpc_fetch_oauth2_func fetch_func) { - memset(c, 0, sizeof(grpc_oauth2_token_fetcher_credentials)); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - gpr_ref_init(&c->base.refcount, 1); - gpr_mu_init(&c->mu); - c->token_expiration = gpr_inf_past(GPR_CLOCK_MONOTONIC); - c->fetch_func = fetch_func; - c->pollent = - grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create()); - grpc_httpcli_context_init(&c->httpcli_context); +grpc_oauth2_token_fetcher_credentials::grpc_oauth2_token_fetcher_credentials() + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2), + token_expiration_(gpr_inf_past(GPR_CLOCK_MONOTONIC)), + pollent_(grpc_polling_entity_create_from_pollset_set( + grpc_pollset_set_create())) { + gpr_mu_init(&mu_); + grpc_httpcli_context_init(&httpcli_context_); } // // Google Compute Engine credentials. // -static grpc_call_credentials_vtable compute_engine_vtable = { - oauth2_token_fetcher_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; +namespace { + +class grpc_compute_engine_token_fetcher_credentials + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_compute_engine_token_fetcher_credentials() = default; + ~grpc_compute_engine_token_fetcher_credentials() override = default; + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* metadata_req, + grpc_httpcli_context* http_context, + grpc_polling_entity* pollent, + grpc_iomgr_cb_func response_cb, + grpc_millis deadline) override { + grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; + grpc_httpcli_request request; + memset(&request, 0, sizeof(grpc_httpcli_request)); + request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; + request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; + request.http.hdr_count = 1; + request.http.hdrs = &header; + /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host + channel. This would allow us to cancel an authentication query when under + extreme memory pressure. */ + grpc_resource_quota* resource_quota = + grpc_resource_quota_create("oauth2_credentials"); + grpc_httpcli_get(http_context, pollent, resource_quota, &request, deadline, + GRPC_CLOSURE_CREATE(response_cb, metadata_req, + grpc_schedule_on_exec_ctx), + &metadata_req->response); + grpc_resource_quota_unref_internal(resource_quota); + } +}; -static void compute_engine_fetch_oauth2( - grpc_credentials_metadata_request* metadata_req, - grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, - grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_http_header header = {(char*)"Metadata-Flavor", (char*)"Google"}; - grpc_httpcli_request request; - memset(&request, 0, sizeof(grpc_httpcli_request)); - request.host = (char*)GRPC_COMPUTE_ENGINE_METADATA_HOST; - request.http.path = (char*)GRPC_COMPUTE_ENGINE_METADATA_TOKEN_PATH; - request.http.hdr_count = 1; - request.http.hdrs = &header; - /* TODO(ctiller): Carry the resource_quota in ctx and share it with the host - channel. This would allow us to cancel an authentication query when under - extreme memory pressure. */ - grpc_resource_quota* resource_quota = - grpc_resource_quota_create("oauth2_credentials"); - grpc_httpcli_get( - httpcli_context, pollent, resource_quota, &request, deadline, - GRPC_CLOSURE_CREATE(response_cb, metadata_req, grpc_schedule_on_exec_ctx), - &metadata_req->response); - grpc_resource_quota_unref_internal(resource_quota); -} +} // namespace grpc_call_credentials* grpc_google_compute_engine_credentials_create( void* reserved) { - grpc_oauth2_token_fetcher_credentials* c = - static_cast( - gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials))); GRPC_API_TRACE("grpc_compute_engine_credentials_create(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - init_oauth2_token_fetcher(c, compute_engine_fetch_oauth2); - c->base.vtable = &compute_engine_vtable; - return &c->base; + return grpc_core::MakeRefCounted< + grpc_compute_engine_token_fetcher_credentials>() + .release(); } // // Google Refresh Token credentials. // -static void refresh_token_destruct(grpc_call_credentials* creds) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast(creds); - grpc_auth_refresh_token_destruct(&c->refresh_token); - oauth2_token_fetcher_destruct(&c->base.base); +grpc_google_refresh_token_credentials:: + ~grpc_google_refresh_token_credentials() { + grpc_auth_refresh_token_destruct(&refresh_token_); } -static grpc_call_credentials_vtable refresh_token_vtable = { - refresh_token_destruct, oauth2_token_fetcher_get_request_metadata, - oauth2_token_fetcher_cancel_get_request_metadata}; - -static void refresh_token_fetch_oauth2( +void grpc_google_refresh_token_credentials::fetch_oauth2( grpc_credentials_metadata_request* metadata_req, grpc_httpcli_context* httpcli_context, grpc_polling_entity* pollent, grpc_iomgr_cb_func response_cb, grpc_millis deadline) { - grpc_google_refresh_token_credentials* c = - reinterpret_cast( - metadata_req->creds); grpc_http_header header = {(char*)"Content-Type", (char*)"application/x-www-form-urlencoded"}; grpc_httpcli_request request; char* body = nullptr; gpr_asprintf(&body, GRPC_REFRESH_TOKEN_POST_BODY_FORMAT_STRING, - c->refresh_token.client_id, c->refresh_token.client_secret, - c->refresh_token.refresh_token); + refresh_token_.client_id, refresh_token_.client_secret, + refresh_token_.refresh_token); memset(&request, 0, sizeof(grpc_httpcli_request)); request.host = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_HOST; request.http.path = (char*)GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH; @@ -437,20 +428,19 @@ static void refresh_token_fetch_oauth2( gpr_free(body); } -grpc_call_credentials* +grpc_google_refresh_token_credentials::grpc_google_refresh_token_credentials( + grpc_auth_refresh_token refresh_token) + : refresh_token_(refresh_token) {} + +grpc_core::RefCountedPtr grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token refresh_token) { - grpc_google_refresh_token_credentials* c; if (!grpc_auth_refresh_token_is_valid(&refresh_token)) { gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation"); return nullptr; } - c = static_cast( - gpr_zalloc(sizeof(grpc_google_refresh_token_credentials))); - init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2); - c->base.base.vtable = &refresh_token_vtable; - c->refresh_token = refresh_token; - return &c->base.base; + return grpc_core::MakeRefCounted( + refresh_token); } static char* create_loggable_refresh_token(grpc_auth_refresh_token* token) { @@ -478,59 +468,50 @@ grpc_call_credentials* grpc_google_refresh_token_credentials_create( gpr_free(loggable_token); } GPR_ASSERT(reserved == nullptr); - return grpc_refresh_token_credentials_create_from_auth_refresh_token(token); + return grpc_refresh_token_credentials_create_from_auth_refresh_token(token) + .release(); } // // Oauth2 Access Token credentials. // -static void access_token_destruct(grpc_call_credentials* creds) { - grpc_access_token_credentials* c = - reinterpret_cast(creds); - GRPC_MDELEM_UNREF(c->access_token_md); +grpc_access_token_credentials::~grpc_access_token_credentials() { + GRPC_MDELEM_UNREF(access_token_md_); } -static bool access_token_get_request_metadata( - grpc_call_credentials* creds, grpc_polling_entity* pollent, - grpc_auth_metadata_context context, grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, grpc_error** error) { - grpc_access_token_credentials* c = - reinterpret_cast(creds); - grpc_credentials_mdelem_array_add(md_array, c->access_token_md); +bool grpc_access_token_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { + grpc_credentials_mdelem_array_add(md_array, access_token_md_); return true; } -static void access_token_cancel_get_request_metadata( - grpc_call_credentials* c, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { +void grpc_access_token_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable access_token_vtable = { - access_token_destruct, access_token_get_request_metadata, - access_token_cancel_get_request_metadata}; +grpc_access_token_credentials::grpc_access_token_credentials( + const char* access_token) + : grpc_call_credentials(GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) { + char* token_md_value; + gpr_asprintf(&token_md_value, "Bearer %s", access_token); + grpc_core::ExecCtx exec_ctx; + access_token_md_ = grpc_mdelem_from_slices( + grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), + grpc_slice_from_copied_string(token_md_value)); + gpr_free(token_md_value); +} grpc_call_credentials* grpc_access_token_credentials_create( const char* access_token, void* reserved) { - grpc_access_token_credentials* c = - static_cast( - gpr_zalloc(sizeof(grpc_access_token_credentials))); GRPC_API_TRACE( "grpc_access_token_credentials_create(access_token=, " "reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CALL_CREDENTIALS_TYPE_OAUTH2; - c->base.vtable = &access_token_vtable; - gpr_ref_init(&c->base.refcount, 1); - char* token_md_value; - gpr_asprintf(&token_md_value, "Bearer %s", access_token); - grpc_core::ExecCtx exec_ctx; - c->access_token_md = grpc_mdelem_from_slices( - grpc_slice_from_static_string(GRPC_AUTHORIZATION_METADATA_KEY), - grpc_slice_from_copied_string(token_md_value)); - - gpr_free(token_md_value); - return &c->base; + return grpc_core::MakeRefCounted(access_token) + .release(); } diff --git a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h index 12a1d4484f..510a78b484 100644 --- a/src/core/lib/security/credentials/oauth2/oauth2_credentials.h +++ b/src/core/lib/security/credentials/oauth2/oauth2_credentials.h @@ -54,46 +54,91 @@ void grpc_auth_refresh_token_destruct(grpc_auth_refresh_token* refresh_token); // This object is a base for credentials that need to acquire an oauth2 token // from an http service. -typedef void (*grpc_fetch_oauth2_func)(grpc_credentials_metadata_request* req, - grpc_httpcli_context* http_context, - grpc_polling_entity* pollent, - grpc_iomgr_cb_func cb, - grpc_millis deadline); - -typedef struct grpc_oauth2_pending_get_request_metadata { +struct grpc_oauth2_pending_get_request_metadata { grpc_credentials_mdelem_array* md_array; grpc_closure* on_request_metadata; grpc_polling_entity* pollent; struct grpc_oauth2_pending_get_request_metadata* next; -} grpc_oauth2_pending_get_request_metadata; - -typedef struct { - grpc_call_credentials base; - gpr_mu mu; - grpc_mdelem access_token_md; - gpr_timespec token_expiration; - bool token_fetch_pending; - grpc_oauth2_pending_get_request_metadata* pending_requests; - grpc_httpcli_context httpcli_context; - grpc_fetch_oauth2_func fetch_func; - grpc_polling_entity pollent; -} grpc_oauth2_token_fetcher_credentials; +}; + +class grpc_oauth2_token_fetcher_credentials : public grpc_call_credentials { + public: + grpc_oauth2_token_fetcher_credentials(); + ~grpc_oauth2_token_fetcher_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + void on_http_response(grpc_credentials_metadata_request* r, + grpc_error* error); + + GRPC_ABSTRACT_BASE_CLASS + + protected: + virtual void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) GRPC_ABSTRACT; + + private: + gpr_mu mu_; + grpc_mdelem access_token_md_ = GRPC_MDNULL; + gpr_timespec token_expiration_; + bool token_fetch_pending_ = false; + grpc_oauth2_pending_get_request_metadata* pending_requests_ = nullptr; + grpc_httpcli_context httpcli_context_; + grpc_polling_entity pollent_; +}; // Google refresh token credentials. -typedef struct { - grpc_oauth2_token_fetcher_credentials base; - grpc_auth_refresh_token refresh_token; -} grpc_google_refresh_token_credentials; +class grpc_google_refresh_token_credentials final + : public grpc_oauth2_token_fetcher_credentials { + public: + grpc_google_refresh_token_credentials(grpc_auth_refresh_token refresh_token); + ~grpc_google_refresh_token_credentials() override; + + const grpc_auth_refresh_token& refresh_token() const { + return refresh_token_; + } + + protected: + void fetch_oauth2(grpc_credentials_metadata_request* req, + grpc_httpcli_context* httpcli_context, + grpc_polling_entity* pollent, grpc_iomgr_cb_func cb, + grpc_millis deadline) override; + + private: + grpc_auth_refresh_token refresh_token_; +}; // Access token credentials. -typedef struct { - grpc_call_credentials base; - grpc_mdelem access_token_md; -} grpc_access_token_credentials; +class grpc_access_token_credentials final : public grpc_call_credentials { + public: + grpc_access_token_credentials(const char* access_token); + ~grpc_access_token_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + private: + grpc_mdelem access_token_md_; +}; // Private constructor for refresh token credentials from an already parsed // refresh token. Takes ownership of the refresh token. -grpc_call_credentials* +grpc_core::RefCountedPtr grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token token); diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.cc b/src/core/lib/security/credentials/plugin/plugin_credentials.cc index 4015124298..52982fdb8f 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.cc +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.cc @@ -35,20 +35,17 @@ grpc_core::TraceFlag grpc_plugin_credentials_trace(false, "plugin_credentials"); -static void plugin_destruct(grpc_call_credentials* creds) { - grpc_plugin_credentials* c = - reinterpret_cast(creds); - gpr_mu_destroy(&c->mu); - if (c->plugin.state != nullptr && c->plugin.destroy != nullptr) { - c->plugin.destroy(c->plugin.state); +grpc_plugin_credentials::~grpc_plugin_credentials() { + gpr_mu_destroy(&mu_); + if (plugin_.state != nullptr && plugin_.destroy != nullptr) { + plugin_.destroy(plugin_.state); } } -static void pending_request_remove_locked( - grpc_plugin_credentials* c, - grpc_plugin_credentials_pending_request* pending_request) { +void grpc_plugin_credentials::pending_request_remove_locked( + pending_request* pending_request) { if (pending_request->prev == nullptr) { - c->pending_requests = pending_request->next; + pending_requests_ = pending_request->next; } else { pending_request->prev->next = pending_request->next; } @@ -62,17 +59,17 @@ static void pending_request_remove_locked( // cancelled out from under us. // When this returns, r->cancelled indicates whether the request was // cancelled before completion. -static void pending_request_complete( - grpc_plugin_credentials_pending_request* r) { - gpr_mu_lock(&r->creds->mu); - if (!r->cancelled) pending_request_remove_locked(r->creds, r); - gpr_mu_unlock(&r->creds->mu); +void grpc_plugin_credentials::pending_request_complete(pending_request* r) { + GPR_DEBUG_ASSERT(r->creds == this); + gpr_mu_lock(&mu_); + if (!r->cancelled) pending_request_remove_locked(r); + gpr_mu_unlock(&mu_); // Ref to credentials not needed anymore. - grpc_call_credentials_unref(&r->creds->base); + Unref(); } static grpc_error* process_plugin_result( - grpc_plugin_credentials_pending_request* r, const grpc_metadata* md, + grpc_plugin_credentials::pending_request* r, const grpc_metadata* md, size_t num_md, grpc_status_code status, const char* error_details) { grpc_error* error = GRPC_ERROR_NONE; if (status != GRPC_STATUS_OK) { @@ -119,8 +116,8 @@ static void plugin_md_request_metadata_ready(void* request, /* called from application code */ grpc_core::ExecCtx exec_ctx(GRPC_EXEC_CTX_FLAG_IS_FINISHED | GRPC_EXEC_CTX_FLAG_THREAD_RESOURCE_LOOP); - grpc_plugin_credentials_pending_request* r = - static_cast(request); + grpc_plugin_credentials::pending_request* r = + static_cast(request); if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " @@ -128,7 +125,7 @@ static void plugin_md_request_metadata_ready(void* request, r->creds, r); } // Remove request from pending list if not previously cancelled. - pending_request_complete(r); + r->creds->pending_request_complete(r); // If it has not been cancelled, process it. if (!r->cancelled) { grpc_error* error = @@ -143,65 +140,59 @@ static void plugin_md_request_metadata_ready(void* request, gpr_free(r); } -static bool plugin_get_request_metadata(grpc_call_credentials* creds, - grpc_polling_entity* pollent, - grpc_auth_metadata_context context, - grpc_credentials_mdelem_array* md_array, - grpc_closure* on_request_metadata, - grpc_error** error) { - grpc_plugin_credentials* c = - reinterpret_cast(creds); +bool grpc_plugin_credentials::get_request_metadata( + grpc_polling_entity* pollent, grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata, + grpc_error** error) { bool retval = true; // Synchronous return. - if (c->plugin.get_metadata != nullptr) { + if (plugin_.get_metadata != nullptr) { // Create pending_request object. - grpc_plugin_credentials_pending_request* pending_request = - static_cast( - gpr_zalloc(sizeof(*pending_request))); - pending_request->creds = c; - pending_request->md_array = md_array; - pending_request->on_request_metadata = on_request_metadata; + pending_request* request = + static_cast(gpr_zalloc(sizeof(*request))); + request->creds = this; + request->md_array = md_array; + request->on_request_metadata = on_request_metadata; // Add it to the pending list. - gpr_mu_lock(&c->mu); - if (c->pending_requests != nullptr) { - c->pending_requests->prev = pending_request; + gpr_mu_lock(&mu_); + if (pending_requests_ != nullptr) { + pending_requests_->prev = request; } - pending_request->next = c->pending_requests; - c->pending_requests = pending_request; - gpr_mu_unlock(&c->mu); + request->next = pending_requests_; + pending_requests_ = request; + gpr_mu_unlock(&mu_); // Invoke the plugin. The callback holds a ref to us. if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: invoking plugin", - c, pending_request); + this, request); } - grpc_call_credentials_ref(creds); + Ref().release(); grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX]; size_t num_creds_md = 0; grpc_status_code status = GRPC_STATUS_OK; const char* error_details = nullptr; - if (!c->plugin.get_metadata(c->plugin.state, context, - plugin_md_request_metadata_ready, - pending_request, creds_md, &num_creds_md, - &status, &error_details)) { + if (!plugin_.get_metadata( + plugin_.state, context, plugin_md_request_metadata_ready, request, + creds_md, &num_creds_md, &status, &error_details)) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin will return " "asynchronously", - c, pending_request); + this, request); } return false; // Asynchronous return. } // Returned synchronously. // Remove request from pending list if not previously cancelled. - pending_request_complete(pending_request); + request->creds->pending_request_complete(request); // If the request was cancelled, the error will have been returned // asynchronously by plugin_cancel_get_request_metadata(), so return // false. Otherwise, process the result. - if (pending_request->cancelled) { + if (request->cancelled) { if (grpc_plugin_credentials_trace.enabled()) { gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p was cancelled, error " "will be returned asynchronously", - c, pending_request); + this, request); } retval = false; } else { @@ -209,10 +200,10 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, gpr_log(GPR_INFO, "plugin_credentials[%p]: request %p: plugin returned " "synchronously", - c, pending_request); + this, request); } - *error = process_plugin_result(pending_request, creds_md, num_creds_md, - status, error_details); + *error = process_plugin_result(request, creds_md, num_creds_md, status, + error_details); } // Clean up. for (size_t i = 0; i < num_creds_md; ++i) { @@ -220,51 +211,42 @@ static bool plugin_get_request_metadata(grpc_call_credentials* creds, grpc_slice_unref_internal(creds_md[i].value); } gpr_free((void*)error_details); - gpr_free(pending_request); + gpr_free(request); } return retval; } -static void plugin_cancel_get_request_metadata( - grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array, - grpc_error* error) { - grpc_plugin_credentials* c = - reinterpret_cast(creds); - gpr_mu_lock(&c->mu); - for (grpc_plugin_credentials_pending_request* pending_request = - c->pending_requests; +void grpc_plugin_credentials::cancel_get_request_metadata( + grpc_credentials_mdelem_array* md_array, grpc_error* error) { + gpr_mu_lock(&mu_); + for (pending_request* pending_request = pending_requests_; pending_request != nullptr; pending_request = pending_request->next) { if (pending_request->md_array == md_array) { if (grpc_plugin_credentials_trace.enabled()) { - gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", c, + gpr_log(GPR_INFO, "plugin_credentials[%p]: cancelling request %p", this, pending_request); } pending_request->cancelled = true; GRPC_CLOSURE_SCHED(pending_request->on_request_metadata, GRPC_ERROR_REF(error)); - pending_request_remove_locked(c, pending_request); + pending_request_remove_locked(pending_request); break; } } - gpr_mu_unlock(&c->mu); + gpr_mu_unlock(&mu_); GRPC_ERROR_UNREF(error); } -static grpc_call_credentials_vtable plugin_vtable = { - plugin_destruct, plugin_get_request_metadata, - plugin_cancel_get_request_metadata}; +grpc_plugin_credentials::grpc_plugin_credentials( + grpc_metadata_credentials_plugin plugin) + : grpc_call_credentials(plugin.type), plugin_(plugin) { + gpr_mu_init(&mu_); +} grpc_call_credentials* grpc_metadata_credentials_create_from_plugin( grpc_metadata_credentials_plugin plugin, void* reserved) { - grpc_plugin_credentials* c = - static_cast(gpr_zalloc(sizeof(*c))); GRPC_API_TRACE("grpc_metadata_credentials_create_from_plugin(reserved=%p)", 1, (reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = plugin.type; - c->base.vtable = &plugin_vtable; - gpr_ref_init(&c->base.refcount, 1); - c->plugin = plugin; - gpr_mu_init(&c->mu); - return &c->base; + return grpc_core::New(plugin); } diff --git a/src/core/lib/security/credentials/plugin/plugin_credentials.h b/src/core/lib/security/credentials/plugin/plugin_credentials.h index caf990efa1..77a957e513 100644 --- a/src/core/lib/security/credentials/plugin/plugin_credentials.h +++ b/src/core/lib/security/credentials/plugin/plugin_credentials.h @@ -25,22 +25,45 @@ extern grpc_core::TraceFlag grpc_plugin_credentials_trace; -struct grpc_plugin_credentials; - -typedef struct grpc_plugin_credentials_pending_request { - bool cancelled; - struct grpc_plugin_credentials* creds; - grpc_credentials_mdelem_array* md_array; - grpc_closure* on_request_metadata; - struct grpc_plugin_credentials_pending_request* prev; - struct grpc_plugin_credentials_pending_request* next; -} grpc_plugin_credentials_pending_request; - -typedef struct grpc_plugin_credentials { - grpc_call_credentials base; - grpc_metadata_credentials_plugin plugin; - gpr_mu mu; - grpc_plugin_credentials_pending_request* pending_requests; -} grpc_plugin_credentials; +// This type is forward declared as a C struct and we cannot define it as a +// class. Otherwise, compiler will complain about type mismatch due to +// -Wmismatched-tags. +struct grpc_plugin_credentials final : public grpc_call_credentials { + public: + struct pending_request { + bool cancelled; + struct grpc_plugin_credentials* creds; + grpc_credentials_mdelem_array* md_array; + grpc_closure* on_request_metadata; + struct pending_request* prev; + struct pending_request* next; + }; + + explicit grpc_plugin_credentials(grpc_metadata_credentials_plugin plugin); + ~grpc_plugin_credentials() override; + + bool get_request_metadata(grpc_polling_entity* pollent, + grpc_auth_metadata_context context, + grpc_credentials_mdelem_array* md_array, + grpc_closure* on_request_metadata, + grpc_error** error) override; + + void cancel_get_request_metadata(grpc_credentials_mdelem_array* md_array, + grpc_error* error) override; + + // Checks if the request has been cancelled. + // If not, removes it from the pending list, so that it cannot be + // cancelled out from under us. + // When this returns, r->cancelled indicates whether the request was + // cancelled before completion. + void pending_request_complete(pending_request* r); + + private: + void pending_request_remove_locked(pending_request* pending_request); + + grpc_metadata_credentials_plugin plugin_; + gpr_mu mu_; + pending_request* pending_requests_ = nullptr; +}; #endif /* GRPC_CORE_LIB_SECURITY_CREDENTIALS_PLUGIN_PLUGIN_CREDENTIALS_H */ diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.cc b/src/core/lib/security/credentials/ssl/ssl_credentials.cc index 3d6f2f200a..83db86f1ea 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.cc +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.cc @@ -44,22 +44,27 @@ void grpc_tsi_ssl_pem_key_cert_pairs_destroy(tsi_ssl_pem_key_cert_pair* kp, gpr_free(kp); } -static void ssl_destruct(grpc_channel_credentials* creds) { - grpc_ssl_credentials* c = reinterpret_cast(creds); - gpr_free(c->config.pem_root_certs); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pair, 1); - if (c->config.verify_options.verify_peer_destruct != nullptr) { - c->config.verify_options.verify_peer_destruct( - c->config.verify_options.verify_peer_callback_userdata); +grpc_ssl_credentials::grpc_ssl_credentials( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) + : grpc_channel_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + build_config(pem_root_certs, pem_key_cert_pair, verify_options); +} + +grpc_ssl_credentials::~grpc_ssl_credentials() { + gpr_free(config_.pem_root_certs); + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pair, 1); + if (config_.verify_options.verify_peer_destruct != nullptr) { + config_.verify_options.verify_peer_destruct( + config_.verify_options.verify_peer_callback_userdata); } } -static grpc_security_status ssl_create_security_connector( - grpc_channel_credentials* creds, grpc_call_credentials* call_creds, +grpc_core::RefCountedPtr +grpc_ssl_credentials::create_security_connector( + grpc_core::RefCountedPtr call_creds, const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - grpc_ssl_credentials* c = reinterpret_cast(creds); - grpc_security_status status = GRPC_SECURITY_OK; + grpc_channel_args** new_args) { const char* overridden_target_name = nullptr; tsi_ssl_session_cache* ssl_session_cache = nullptr; for (size_t i = 0; args && i < args->num_args; i++) { @@ -74,52 +79,47 @@ static grpc_security_status ssl_create_security_connector( static_cast(arg->value.pointer.p); } } - status = grpc_ssl_channel_security_connector_create( - creds, call_creds, &c->config, target, overridden_target_name, - ssl_session_cache, sc); - if (status != GRPC_SECURITY_OK) { - return status; + grpc_core::RefCountedPtr sc = + grpc_ssl_channel_security_connector_create( + this->Ref(), std::move(call_creds), &config_, target, + overridden_target_name, ssl_session_cache); + if (sc == nullptr) { + return sc; } grpc_arg new_arg = grpc_channel_arg_string_create( (char*)GRPC_ARG_HTTP2_SCHEME, (char*)"https"); *new_args = grpc_channel_args_copy_and_add(args, &new_arg, 1); - return status; + return sc; } -static grpc_channel_credentials_vtable ssl_vtable = { - ssl_destruct, ssl_create_security_connector, nullptr}; - -static void ssl_build_config(const char* pem_root_certs, - grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, - const verify_peer_options* verify_options, - grpc_ssl_config* config) { - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } +void grpc_ssl_credentials::build_config( + const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options) { + config_.pem_root_certs = gpr_strdup(pem_root_certs); if (pem_key_cert_pair != nullptr) { GPR_ASSERT(pem_key_cert_pair->private_key != nullptr); GPR_ASSERT(pem_key_cert_pair->cert_chain != nullptr); - config->pem_key_cert_pair = static_cast( + config_.pem_key_cert_pair = static_cast( gpr_zalloc(sizeof(tsi_ssl_pem_key_cert_pair))); - config->pem_key_cert_pair->cert_chain = + config_.pem_key_cert_pair->cert_chain = gpr_strdup(pem_key_cert_pair->cert_chain); - config->pem_key_cert_pair->private_key = + config_.pem_key_cert_pair->private_key = gpr_strdup(pem_key_cert_pair->private_key); + } else { + config_.pem_key_cert_pair = nullptr; } if (verify_options != nullptr) { - memcpy(&config->verify_options, verify_options, + memcpy(&config_.verify_options, verify_options, sizeof(verify_peer_options)); } else { // Otherwise set all options to default values - memset(&config->verify_options, 0, sizeof(verify_peer_options)); + memset(&config_.verify_options, 0, sizeof(verify_peer_options)); } } grpc_channel_credentials* grpc_ssl_credentials_create( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, const verify_peer_options* verify_options, void* reserved) { - grpc_ssl_credentials* c = static_cast( - gpr_zalloc(sizeof(grpc_ssl_credentials))); GRPC_API_TRACE( "grpc_ssl_credentials_create(pem_root_certs=%s, " "pem_key_cert_pair=%p, " @@ -127,12 +127,9 @@ grpc_channel_credentials* grpc_ssl_credentials_create( "reserved=%p)", 4, (pem_root_certs, pem_key_cert_pair, verify_options, reserved)); GPR_ASSERT(reserved == nullptr); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - c->base.vtable = &ssl_vtable; - gpr_ref_init(&c->base.refcount, 1); - ssl_build_config(pem_root_certs, pem_key_cert_pair, verify_options, - &c->config); - return &c->base; + + return grpc_core::New(pem_root_certs, pem_key_cert_pair, + verify_options); } // @@ -145,21 +142,29 @@ struct grpc_ssl_server_credentials_options { grpc_ssl_server_certificate_config_fetcher* certificate_config_fetcher; }; -static void ssl_server_destruct(grpc_server_credentials* creds) { - grpc_ssl_server_credentials* c = - reinterpret_cast(creds); - grpc_tsi_ssl_pem_key_cert_pairs_destroy(c->config.pem_key_cert_pairs, - c->config.num_key_cert_pairs); - gpr_free(c->config.pem_root_certs); +grpc_ssl_server_credentials::grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options) + : grpc_server_credentials(GRPC_CHANNEL_CREDENTIALS_TYPE_SSL) { + if (options.certificate_config_fetcher != nullptr) { + config_.client_certificate_request = options.client_certificate_request; + certificate_config_fetcher_ = *options.certificate_config_fetcher; + } else { + build_config(options.certificate_config->pem_root_certs, + options.certificate_config->pem_key_cert_pairs, + options.certificate_config->num_key_cert_pairs, + options.client_certificate_request); + } } -static grpc_security_status ssl_server_create_security_connector( - grpc_server_credentials* creds, grpc_server_security_connector** sc) { - return grpc_ssl_server_security_connector_create(creds, sc); +grpc_ssl_server_credentials::~grpc_ssl_server_credentials() { + grpc_tsi_ssl_pem_key_cert_pairs_destroy(config_.pem_key_cert_pairs, + config_.num_key_cert_pairs); + gpr_free(config_.pem_root_certs); +} +grpc_core::RefCountedPtr +grpc_ssl_server_credentials::create_security_connector() { + return grpc_ssl_server_security_connector_create(this->Ref()); } - -static grpc_server_credentials_vtable ssl_server_vtable = { - ssl_server_destruct, ssl_server_create_security_connector}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, @@ -179,18 +184,15 @@ tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( return tsi_pairs; } -static void ssl_build_server_config( +void grpc_ssl_server_credentials::build_config( const char* pem_root_certs, grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, - grpc_ssl_client_certificate_request_type client_certificate_request, - grpc_ssl_server_config* config) { - config->client_certificate_request = client_certificate_request; - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } - config->pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + grpc_ssl_client_certificate_request_type client_certificate_request) { + config_.client_certificate_request = client_certificate_request; + config_.pem_root_certs = gpr_strdup(pem_root_certs); + config_.pem_key_cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( pem_key_cert_pairs, num_key_cert_pairs); - config->num_key_cert_pairs = num_key_cert_pairs; + config_.num_key_cert_pairs = num_key_cert_pairs; } grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( @@ -200,9 +202,7 @@ grpc_ssl_server_certificate_config* grpc_ssl_server_certificate_config_create( grpc_ssl_server_certificate_config* config = static_cast( gpr_zalloc(sizeof(grpc_ssl_server_certificate_config))); - if (pem_root_certs != nullptr) { - config->pem_root_certs = gpr_strdup(pem_root_certs); - } + config->pem_root_certs = gpr_strdup(pem_root_certs); if (num_key_cert_pairs > 0) { GPR_ASSERT(pem_key_cert_pairs != nullptr); config->pem_key_cert_pairs = static_cast( @@ -311,7 +311,6 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_ex( grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( grpc_ssl_server_credentials_options* options) { grpc_server_credentials* retval = nullptr; - grpc_ssl_server_credentials* c = nullptr; if (options == nullptr) { gpr_log(GPR_ERROR, @@ -331,23 +330,7 @@ grpc_server_credentials* grpc_ssl_server_credentials_create_with_options( goto done; } - c = static_cast( - gpr_zalloc(sizeof(grpc_ssl_server_credentials))); - c->base.type = GRPC_CHANNEL_CREDENTIALS_TYPE_SSL; - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &ssl_server_vtable; - - if (options->certificate_config_fetcher != nullptr) { - c->config.client_certificate_request = options->client_certificate_request; - c->certificate_config_fetcher = *options->certificate_config_fetcher; - } else { - ssl_build_server_config(options->certificate_config->pem_root_certs, - options->certificate_config->pem_key_cert_pairs, - options->certificate_config->num_key_cert_pairs, - options->client_certificate_request, &c->config); - } - - retval = &c->base; + retval = grpc_core::New(*options); done: grpc_ssl_server_credentials_options_destroy(options); diff --git a/src/core/lib/security/credentials/ssl/ssl_credentials.h b/src/core/lib/security/credentials/ssl/ssl_credentials.h index 0fba413876..e1174327b3 100644 --- a/src/core/lib/security/credentials/ssl/ssl_credentials.h +++ b/src/core/lib/security/credentials/ssl/ssl_credentials.h @@ -24,27 +24,70 @@ #include "src/core/lib/security/security_connector/ssl/ssl_security_connector.h" -typedef struct { - grpc_channel_credentials base; - grpc_ssl_config config; -} grpc_ssl_credentials; +class grpc_ssl_credentials : public grpc_channel_credentials { + public: + grpc_ssl_credentials(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + ~grpc_ssl_credentials() override; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override; + + private: + void build_config(const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pair, + const verify_peer_options* verify_options); + + grpc_ssl_config config_; +}; struct grpc_ssl_server_certificate_config { - grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; }; -typedef struct { - grpc_ssl_server_certificate_config_callback cb; +struct grpc_ssl_server_certificate_config_fetcher { + grpc_ssl_server_certificate_config_callback cb = nullptr; void* user_data; -} grpc_ssl_server_certificate_config_fetcher; +}; + +class grpc_ssl_server_credentials final : public grpc_server_credentials { + public: + grpc_ssl_server_credentials( + const grpc_ssl_server_credentials_options& options); + ~grpc_ssl_server_credentials() override; -typedef struct { - grpc_server_credentials base; - grpc_ssl_server_config config; - grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher; -} grpc_ssl_server_credentials; + grpc_core::RefCountedPtr + create_security_connector() override; + + bool has_cert_config_fetcher() const { + return certificate_config_fetcher_.cb != nullptr; + } + + grpc_ssl_certificate_config_reload_status FetchCertConfig( + grpc_ssl_server_certificate_config** config) { + GPR_DEBUG_ASSERT(has_cert_config_fetcher()); + return certificate_config_fetcher_.cb(certificate_config_fetcher_.user_data, + config); + } + + const grpc_ssl_server_config& config() const { return config_; } + + private: + void build_config( + const char* pem_root_certs, + grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, size_t num_key_cert_pairs, + grpc_ssl_client_certificate_request_type client_certificate_request); + + grpc_ssl_server_config config_; + grpc_ssl_server_certificate_config_fetcher certificate_config_fetcher_; +}; tsi_ssl_pem_key_cert_pair* grpc_convert_grpc_to_tsi_cert_pairs( const grpc_ssl_pem_key_cert_pair* pem_key_cert_pairs, diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.cc b/src/core/lib/security/security_connector/alts/alts_security_connector.cc index dd71c8bc60..6db70ef172 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.cc +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.cc @@ -28,6 +28,7 @@ #include #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/credentials/alts/alts_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" #include "src/core/lib/slice/slice_internal.h" @@ -35,64 +36,9 @@ #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_alts_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_alts_server_security_connector; - -static void alts_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void alts_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void alts_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast(sc); - grpc_alts_credentials* creds = - reinterpret_cast(c->base.channel_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, c->target_name, creds->handshaker_service_url, - true, interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - auto c = reinterpret_cast(sc); - grpc_alts_server_credentials* creds = - reinterpret_cast(c->base.server_creds); - GPR_ASSERT(alts_tsi_handshaker_create( - creds->options, nullptr, creds->handshaker_service_url, false, - interested_parties, &handshaker) == TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void alts_set_rpc_protocol_versions( +void alts_set_rpc_protocol_versions( grpc_gcp_rpc_protocol_versions* rpc_versions) { grpc_gcp_rpc_protocol_versions_set_max(rpc_versions, GRPC_PROTOCOL_VERSION_MAX_MAJOR, @@ -102,17 +48,131 @@ static void alts_set_rpc_protocol_versions( GRPC_PROTOCOL_VERSION_MIN_MINOR); } +void atls_check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + *auth_context = + grpc_core::internal::grpc_alts_auth_context_from_tsi_peer(&peer); + tsi_peer_destruct(&peer); + grpc_error* error = + *auth_context != nullptr + ? GRPC_ERROR_NONE + : GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Could not get ALTS auth context from TSI peer"); + GRPC_CLOSURE_SCHED(on_peer_checked, error); +} + +class grpc_alts_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_alts_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(/*url_scheme=*/nullptr, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) { + grpc_alts_credentials* creds = + static_cast(mutable_channel_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + + ~grpc_alts_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_credentials* creds = + static_cast(channel_creds()); + GPR_ASSERT(alts_tsi_handshaker_create(creds->options(), target_name_, + creds->handshaker_service_url(), true, + interested_parties, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + atls_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "ALTS call host does not match target name"); + } + return true; + } + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + private: + char* target_name_; +}; + +class grpc_alts_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_alts_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(/*url_scheme=*/nullptr, + std::move(server_creds)) { + grpc_alts_server_credentials* creds = + reinterpret_cast(mutable_server_creds()); + alts_set_rpc_protocol_versions(&creds->mutable_options()->rpc_versions); + } + ~grpc_alts_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + const grpc_alts_server_credentials* creds = + static_cast(server_creds()); + GPR_ASSERT(alts_tsi_handshaker_create( + creds->options(), nullptr, creds->handshaker_service_url(), + false, interested_parties, &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + atls_check_peer(peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } +}; +} // namespace + namespace grpc_core { namespace internal { - -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx) { - if (peer == nullptr || ctx == nullptr) { +grpc_core::RefCountedPtr +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer) { + if (peer == nullptr) { gpr_log(GPR_ERROR, "Invalid arguments to grpc_alts_auth_context_from_tsi_peer()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - *ctx = nullptr; /* Validate certificate type. */ const tsi_peer_property* cert_type_prop = tsi_peer_get_property_by_name(peer, TSI_CERTIFICATE_TYPE_PEER_PROPERTY); @@ -120,14 +180,14 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( strncmp(cert_type_prop->value.data, TSI_ALTS_CERTIFICATE_TYPE, cert_type_prop->value.length) != 0) { gpr_log(GPR_ERROR, "Invalid or missing certificate type property."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Validate RPC protocol versions. */ const tsi_peer_property* rpc_versions_prop = tsi_peer_get_property_by_name(peer, TSI_ALTS_RPC_VERSIONS); if (rpc_versions_prop == nullptr) { gpr_log(GPR_ERROR, "Missing rpc protocol versions property."); - return GRPC_SECURITY_ERROR; + return nullptr; } grpc_gcp_rpc_protocol_versions local_versions, peer_versions; alts_set_rpc_protocol_versions(&local_versions); @@ -138,19 +198,19 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( grpc_slice_unref_internal(slice); if (!decode_result) { gpr_log(GPR_ERROR, "Invalid peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* TODO: Pass highest common rpc protocol version to grpc caller. */ bool check_result = grpc_gcp_rpc_protocol_versions_check( &local_versions, &peer_versions, nullptr); if (!check_result) { gpr_log(GPR_ERROR, "Mismatch of local and peer rpc protocol versions."); - return GRPC_SECURITY_ERROR; + return nullptr; } /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + auto ctx = grpc_core::MakeRefCounted(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_ALTS_TRANSPORT_SECURITY_TYPE); size_t i = 0; for (i = 0; i < peer->property_count; i++) { @@ -158,132 +218,47 @@ grpc_security_status grpc_alts_auth_context_from_tsi_peer( /* Add service account to auth context. */ if (strcmp(tsi_prop->name, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 0) { grpc_auth_context_add_property( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, tsi_prop->value.data, - tsi_prop->value.length); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + tsi_prop->value.data, tsi_prop->value.length); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); + ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY) == 1); } } - if (!grpc_auth_context_peer_is_authenticated(*ctx)) { + if (!grpc_auth_context_peer_is_authenticated(ctx.get())) { gpr_log(GPR_ERROR, "Invalid unauthenticated peer."); - GRPC_AUTH_CONTEXT_UNREF(*ctx, "test"); - *ctx = nullptr; - return GRPC_SECURITY_ERROR; + ctx.reset(DEBUG_LOCATION, "test"); + return nullptr; } - return GRPC_SECURITY_OK; + return ctx; } } // namespace internal } // namespace grpc_core -static void alts_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; - status = grpc_core::internal::grpc_alts_auth_context_from_tsi_peer( - &peer, auth_context); - tsi_peer_destruct(&peer); - grpc_error* error = - status == GRPC_SECURITY_OK - ? GRPC_ERROR_NONE - : GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Could not get ALTS auth context from TSI peer"); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} - -static int alts_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_channel_security_connector* c1 = - reinterpret_cast(sc1); - grpc_alts_channel_security_connector* c2 = - reinterpret_cast(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int alts_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_alts_server_security_connector* c1 = - reinterpret_cast(sc1); - grpc_alts_server_security_connector* c2 = - reinterpret_cast(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_connector_vtable alts_channel_vtable = { - alts_channel_destroy, alts_check_peer, alts_channel_cmp}; - -static grpc_security_connector_vtable alts_server_vtable = { - alts_server_destroy, alts_check_peer, alts_server_cmp}; - -static bool alts_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_alts_channel_security_connector* alts_sc = - reinterpret_cast(sc); - if (host == nullptr || alts_sc == nullptr || - strcmp(host, alts_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "ALTS call host does not match target name"); - } - return true; -} - -static void alts_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { +grpc_core::RefCountedPtr +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast( - gpr_zalloc(sizeof(grpc_alts_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_channel_vtable; - c->base.add_handshakers = alts_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = alts_check_call_host; - c->base.cancel_check_call_host = alts_cancel_check_call_host; - grpc_alts_credentials* creds = - reinterpret_cast(c->base.channel_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), target_name); } -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_alts_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast( - gpr_zalloc(sizeof(grpc_alts_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &alts_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.add_handshakers = alts_server_add_handshakers; - grpc_alts_server_credentials* creds = - reinterpret_cast(c->base.server_creds); - alts_set_rpc_protocol_versions(&creds->options->rpc_versions); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/alts/alts_security_connector.h b/src/core/lib/security/security_connector/alts/alts_security_connector.h index d2e057a76a..b96dc36b30 100644 --- a/src/core/lib/security/security_connector/alts/alts_security_connector.h +++ b/src/core/lib/security/security_connector/alts/alts_security_connector.h @@ -36,12 +36,13 @@ * - sc: address of ALTS channel security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr +grpc_alts_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name); /** * This method creates an ALTS server security connector. @@ -50,17 +51,18 @@ grpc_security_status grpc_alts_channel_security_connector_create( * - sc: address of ALTS server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_alts_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr +grpc_alts_server_security_connector_create( + grpc_core::RefCountedPtr server_creds); namespace grpc_core { namespace internal { /* Exposed only for testing. */ -grpc_security_status grpc_alts_auth_context_from_tsi_peer( - const tsi_peer* peer, grpc_auth_context** ctx); +grpc_core::RefCountedPtr +grpc_alts_auth_context_from_tsi_peer(const tsi_peer* peer); } // namespace internal } // namespace grpc_core diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.cc b/src/core/lib/security/security_connector/fake/fake_security_connector.cc index 5c0c89b88f..d2cdaaac77 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.cc +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.cc @@ -31,6 +31,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/fake/fake_credentials.h" @@ -38,91 +39,183 @@ #include "src/core/lib/security/transport/target_authority_table.h" #include "src/core/tsi/fake_transport_security.h" -typedef struct { - grpc_channel_security_connector base; - char* target; - char* expected_targets; - bool is_lb_channel; - char* target_name_override; -} grpc_fake_channel_security_connector; +namespace { +class grpc_fake_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_fake_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target, const grpc_channel_args* args) + : grpc_channel_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_(gpr_strdup(target)), + expected_targets_( + gpr_strdup(grpc_fake_transport_get_expected_targets(args))), + is_lb_channel_(grpc_core::FindTargetAuthorityTableInArgs(args) != + nullptr) { + const grpc_arg* target_name_override_arg = + grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); + if (target_name_override_arg != nullptr) { + target_name_override_ = + gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + } else { + target_name_override_ = nullptr; + } + } -static void fake_channel_destroy(grpc_security_connector* sc) { - grpc_fake_channel_security_connector* c = - reinterpret_cast(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - gpr_free(c->target); - gpr_free(c->expected_targets); - gpr_free(c->target_name_override); - gpr_free(c); -} + ~grpc_fake_channel_security_connector() override { + gpr_free(target_); + gpr_free(expected_targets_); + if (target_name_override_ != nullptr) gpr_free(target_name_override_); + } -static void fake_server_destroy(grpc_security_connector* sc) { gpr_free(sc); } + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override; -static bool fake_check_target(const char* target_type, const char* target, - const char* set_str) { - GPR_ASSERT(target_type != nullptr); - GPR_ASSERT(target != nullptr); - char** set = nullptr; - size_t set_size = 0; - gpr_string_split(set_str, ",", &set, &set_size); - bool found = false; - for (size_t i = 0; i < set_size; ++i) { - if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_, other->target_); + if (c != 0) return c; + if (expected_targets_ == nullptr || other->expected_targets_ == nullptr) { + c = GPR_ICMP(expected_targets_, other->expected_targets_); + } else { + c = strcmp(expected_targets_, other->expected_targets_); + } + if (c != 0) return c; + return GPR_ICMP(is_lb_channel_, other->is_lb_channel_); } - for (size_t i = 0; i < set_size; ++i) { - gpr_free(set[i]); + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*is_client=*/true), this)); } - gpr_free(set); - return found; -} -static void fake_secure_name_check(const char* target, - const char* expected_targets, - bool is_lb_channel) { - if (expected_targets == nullptr) return; - char** lbs_and_backends = nullptr; - size_t lbs_and_backends_size = 0; - bool success = false; - gpr_string_split(expected_targets, ";", &lbs_and_backends, - &lbs_and_backends_size); - if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { - gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", - expected_targets); - goto done; + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + char* authority_hostname = nullptr; + char* authority_ignored_port = nullptr; + char* target_hostname = nullptr; + char* target_ignored_port = nullptr; + gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); + gpr_split_host_port(target_, &target_hostname, &target_ignored_port); + if (target_name_override_ != nullptr) { + char* fake_security_target_name_override_hostname = nullptr; + char* fake_security_target_name_override_ignored_port = nullptr; + gpr_split_host_port(target_name_override_, + &fake_security_target_name_override_hostname, + &fake_security_target_name_override_ignored_port); + if (strcmp(authority_hostname, + fake_security_target_name_override_hostname) != 0) { + gpr_log(GPR_ERROR, + "Authority (host) '%s' != Fake Security Target override '%s'", + host, fake_security_target_name_override_hostname); + abort(); + } + gpr_free(fake_security_target_name_override_hostname); + gpr_free(fake_security_target_name_override_ignored_port); + } else if (strcmp(authority_hostname, target_hostname) != 0) { + gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", + authority_hostname, target_hostname); + abort(); + } + gpr_free(authority_hostname); + gpr_free(authority_ignored_port); + gpr_free(target_hostname); + gpr_free(target_ignored_port); + return true; } - if (is_lb_channel) { - if (lbs_and_backends_size != 2) { - gpr_log(GPR_ERROR, - "Invalid expected targets arg value: '%s'. Expectations for LB " - "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", - expected_targets); - goto done; + + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + char* target() const { return target_; } + char* expected_targets() const { return expected_targets_; } + bool is_lb_channel() const { return is_lb_channel_; } + char* target_name_override() const { return target_name_override_; } + + private: + bool fake_check_target(const char* target_type, const char* target, + const char* set_str) const { + GPR_ASSERT(target_type != nullptr); + GPR_ASSERT(target != nullptr); + char** set = nullptr; + size_t set_size = 0; + gpr_string_split(set_str, ",", &set, &set_size); + bool found = false; + for (size_t i = 0; i < set_size; ++i) { + if (set[i] != nullptr && strcmp(target, set[i]) == 0) found = true; } - if (!fake_check_target("LB", target, lbs_and_backends[1])) { - gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", - target, lbs_and_backends[1]); - goto done; + for (size_t i = 0; i < set_size; ++i) { + gpr_free(set[i]); } - success = true; - } else { - if (!fake_check_target("Backend", target, lbs_and_backends[0])) { - gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", - target, lbs_and_backends[0]); + gpr_free(set); + return found; + } + + void fake_secure_name_check() const { + if (expected_targets_ == nullptr) return; + char** lbs_and_backends = nullptr; + size_t lbs_and_backends_size = 0; + bool success = false; + gpr_string_split(expected_targets_, ";", &lbs_and_backends, + &lbs_and_backends_size); + if (lbs_and_backends_size > 2 || lbs_and_backends_size == 0) { + gpr_log(GPR_ERROR, "Invalid expected targets arg value: '%s'", + expected_targets_); goto done; } - success = true; - } -done: - for (size_t i = 0; i < lbs_and_backends_size; ++i) { - gpr_free(lbs_and_backends[i]); + if (is_lb_channel_) { + if (lbs_and_backends_size != 2) { + gpr_log(GPR_ERROR, + "Invalid expected targets arg value: '%s'. Expectations for LB " + "channels must be of the form 'be1,be2,be3,...;lb1,lb2,...", + expected_targets_); + goto done; + } + if (!fake_check_target("LB", target_, lbs_and_backends[1])) { + gpr_log(GPR_ERROR, "LB target '%s' not found in expected set '%s'", + target_, lbs_and_backends[1]); + goto done; + } + success = true; + } else { + if (!fake_check_target("Backend", target_, lbs_and_backends[0])) { + gpr_log(GPR_ERROR, "Backend target '%s' not found in expected set '%s'", + target_, lbs_and_backends[0]); + goto done; + } + success = true; + } + done: + for (size_t i = 0; i < lbs_and_backends_size; ++i) { + gpr_free(lbs_and_backends[i]); + } + gpr_free(lbs_and_backends); + if (!success) abort(); } - gpr_free(lbs_and_backends); - if (!success) abort(); -} -static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { + char* target_; + char* expected_targets_; + bool is_lb_channel_; + char* target_name_override_; +}; + +static void fake_check_peer( + grpc_security_connector* sc, tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { const char* prop_name; grpc_error* error = GRPC_ERROR_NONE; *auth_context = nullptr; @@ -147,164 +240,65 @@ static void fake_check_peer(grpc_security_connector* sc, tsi_peer peer, "Invalid value for cert type property."); goto end; } - *auth_context = grpc_auth_context_create(nullptr); + *auth_context = grpc_core::MakeRefCounted(nullptr); grpc_auth_context_add_cstring_property( - *auth_context, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + auth_context->get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_FAKE_TRANSPORT_SECURITY_TYPE); end: GRPC_CLOSURE_SCHED(on_peer_checked, error); tsi_peer_destruct(&peer); } -static void fake_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); - grpc_fake_channel_security_connector* c = - reinterpret_cast(sc); - fake_secure_name_check(c->target, c->expected_targets, c->is_lb_channel); +void grpc_fake_channel_security_connector::check_peer( + tsi_peer peer, grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { + fake_check_peer(this, peer, auth_context, on_peer_checked); + fake_secure_name_check(); } -static void fake_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - fake_check_peer(sc, peer, auth_context, on_peer_checked); -} +class grpc_fake_server_security_connector + : public grpc_server_security_connector { + public: + grpc_fake_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_FAKE_SECURITY_URL_SCHEME, + std::move(server_creds)) {} + ~grpc_fake_server_security_connector() override = default; -static int fake_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_fake_channel_security_connector* c1 = - reinterpret_cast(sc1); - grpc_fake_channel_security_connector* c2 = - reinterpret_cast(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target, c2->target); - if (c != 0) return c; - if (c1->expected_targets == nullptr || c2->expected_targets == nullptr) { - c = GPR_ICMP(c1->expected_targets, c2->expected_targets); - } else { - c = strcmp(c1->expected_targets, c2->expected_targets); + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + fake_check_peer(this, peer, auth_context, on_peer_checked); } - if (c != 0) return c; - return GPR_ICMP(c1->is_lb_channel, c2->is_lb_channel); -} -static int fake_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast(sc1), - reinterpret_cast(sc2)); -} - -static bool fake_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_fake_channel_security_connector* c = - reinterpret_cast(sc); - char* authority_hostname = nullptr; - char* authority_ignored_port = nullptr; - char* target_hostname = nullptr; - char* target_ignored_port = nullptr; - gpr_split_host_port(host, &authority_hostname, &authority_ignored_port); - gpr_split_host_port(c->target, &target_hostname, &target_ignored_port); - if (c->target_name_override != nullptr) { - char* fake_security_target_name_override_hostname = nullptr; - char* fake_security_target_name_override_ignored_port = nullptr; - gpr_split_host_port(c->target_name_override, - &fake_security_target_name_override_hostname, - &fake_security_target_name_override_ignored_port); - if (strcmp(authority_hostname, - fake_security_target_name_override_hostname) != 0) { - gpr_log(GPR_ERROR, - "Authority (host) '%s' != Fake Security Target override '%s'", - host, fake_security_target_name_override_hostname); - abort(); - } - gpr_free(fake_security_target_name_override_hostname); - gpr_free(fake_security_target_name_override_ignored_port); - } else if (strcmp(authority_hostname, target_hostname) != 0) { - gpr_log(GPR_ERROR, "Authority (host) '%s' != Target '%s'", - authority_hostname, target_hostname); - abort(); + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + grpc_handshake_manager_add( + handshake_mgr, + grpc_security_handshaker_create( + tsi_create_fake_handshaker(/*=is_client*/ false), this)); } - gpr_free(authority_hostname); - gpr_free(authority_ignored_port); - gpr_free(target_hostname); - gpr_free(target_ignored_port); - return true; -} -static void fake_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} - -static void fake_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(true /* is_client */), &sc->base)); -} - -static void fake_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_handshake_manager_add( - handshake_mgr, - grpc_security_handshaker_create( - tsi_create_fake_handshaker(false /* is_client */), &sc->base)); -} - -static grpc_security_connector_vtable fake_channel_vtable = { - fake_channel_destroy, fake_channel_check_peer, fake_channel_cmp}; - -static grpc_security_connector_vtable fake_server_vtable = { - fake_server_destroy, fake_server_check_peer, fake_server_cmp}; - -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args) { - grpc_fake_channel_security_connector* c = - static_cast( - gpr_zalloc(sizeof(*c))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->base.base.vtable = &fake_channel_vtable; - c->base.channel_creds = channel_creds; - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = fake_channel_check_call_host; - c->base.cancel_check_call_host = fake_channel_cancel_check_call_host; - c->base.add_handshakers = fake_channel_add_handshakers; - c->target = gpr_strdup(target); - const char* expected_targets = grpc_fake_transport_get_expected_targets(args); - c->expected_targets = gpr_strdup(expected_targets); - c->is_lb_channel = grpc_core::FindTargetAuthorityTableInArgs(args) != nullptr; - const grpc_arg* target_name_override_arg = - grpc_channel_args_find(args, GRPC_SSL_TARGET_NAME_OVERRIDE_ARG); - if (target_name_override_arg != nullptr) { - c->target_name_override = - gpr_strdup(grpc_channel_arg_get_string(target_name_override_arg)); + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); } - return &c->base; +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target, const grpc_channel_args* args) { + return grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), target, + args); } -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds) { - grpc_server_security_connector* c = - static_cast( - gpr_zalloc(sizeof(grpc_server_security_connector))); - gpr_ref_init(&c->base.refcount, 1); - c->base.vtable = &fake_server_vtable; - c->base.url_scheme = GRPC_FAKE_SECURITY_URL_SCHEME; - c->server_creds = server_creds; - c->add_handshakers = fake_server_add_handshakers; - return c; +grpc_core::RefCountedPtr +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + return grpc_core::MakeRefCounted( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/fake/fake_security_connector.h b/src/core/lib/security/security_connector/fake/fake_security_connector.h index fdfe048c6e..344a2349a4 100644 --- a/src/core/lib/security/security_connector/fake/fake_security_connector.h +++ b/src/core/lib/security/security_connector/fake/fake_security_connector.h @@ -24,19 +24,22 @@ #include #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/security_connector/security_connector.h" #define GRPC_FAKE_SECURITY_URL_SCHEME "http+fake_security" /* Creates a fake connector that emulates real channel security. */ -grpc_channel_security_connector* grpc_fake_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, const char* target, - const grpc_channel_args* args); +grpc_core::RefCountedPtr +grpc_fake_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target, const grpc_channel_args* args); /* Creates a fake connector that emulates real server security. */ -grpc_server_security_connector* grpc_fake_server_security_connector_create( - grpc_server_credentials* server_creds); +grpc_core::RefCountedPtr +grpc_fake_server_security_connector_create( + grpc_core::RefCountedPtr server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_FAKE_FAKE_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/local/local_security_connector.cc b/src/core/lib/security/security_connector/local/local_security_connector.cc index 008a98df28..7a59e54e9a 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.cc +++ b/src/core/lib/security/security_connector/local/local_security_connector.cc @@ -30,6 +30,7 @@ #include "src/core/ext/filters/client_channel/client_channel.h" #include "src/core/lib/channel/channel_args.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/security/credentials/local/local_credentials.h" #include "src/core/lib/security/transport/security_handshaker.h" @@ -39,153 +40,145 @@ #define GRPC_UDS_URL_SCHEME "unix" #define GRPC_LOCAL_TRANSPORT_SECURITY_TYPE "local" -typedef struct { - grpc_channel_security_connector base; - char* target_name; -} grpc_local_channel_security_connector; +namespace { -typedef struct { - grpc_server_security_connector base; -} grpc_local_server_security_connector; - -static void local_channel_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast(sc); - grpc_call_credentials_unref(c->base.request_metadata_creds); - grpc_channel_credentials_unref(c->base.channel_creds); - gpr_free(c->target_name); - gpr_free(sc); -} - -static void local_server_destroy(grpc_security_connector* sc) { - if (sc == nullptr) { - return; - } - auto c = reinterpret_cast(sc); - grpc_server_credentials_unref(c->base.server_creds); - gpr_free(sc); -} - -static void local_channel_add_handshakers( - grpc_channel_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static void local_server_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_manager) { - tsi_handshaker* handshaker = nullptr; - GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, &handshaker) == - TSI_OK); - grpc_handshake_manager_add(handshake_manager, grpc_security_handshaker_create( - handshaker, &sc->base)); -} - -static int local_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_channel_security_connector* c1 = - reinterpret_cast(sc1); - grpc_local_channel_security_connector* c2 = - reinterpret_cast(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - return strcmp(c1->target_name, c2->target_name); -} - -static int local_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_local_server_security_connector* c1 = - reinterpret_cast(sc1); - grpc_local_server_security_connector* c2 = - reinterpret_cast(sc2); - return grpc_server_security_connector_cmp(&c1->base, &c2->base); -} - -static grpc_security_status local_auth_context_create(grpc_auth_context** ctx) { - if (ctx == nullptr) { - gpr_log(GPR_ERROR, "Invalid arguments to local_auth_context_create()"); - return GRPC_SECURITY_ERROR; - } +grpc_core::RefCountedPtr local_auth_context_create() { /* Create auth context. */ - *ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); grpc_auth_context_add_cstring_property( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_LOCAL_TRANSPORT_SECURITY_TYPE); GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - *ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); - return GRPC_SECURITY_OK; + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME) == 1); + return ctx; } -static void local_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_security_status status; +void local_check_peer(grpc_security_connector* sc, tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) { /* Create an auth context which is necessary to pass the santiy check in - * {client, server}_auth_filter that verifies if the peer's auth context is + * {client, server}_auth_filter that verifies if the pepp's auth context is * obtained during handshakes. The auth context is only checked for its * existence and not actually used. */ - status = local_auth_context_create(auth_context); - grpc_error* error = status == GRPC_SECURITY_OK + *auth_context = local_auth_context_create(); + grpc_error* error = *auth_context != nullptr ? GRPC_ERROR_NONE : GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Could not create local auth context"); GRPC_CLOSURE_SCHED(on_peer_checked, error); } -static grpc_security_connector_vtable local_channel_vtable = { - local_channel_destroy, local_check_peer, local_channel_cmp}; - -static grpc_security_connector_vtable local_server_vtable = { - local_server_destroy, local_check_peer, local_server_cmp}; - -static bool local_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_local_channel_security_connector* local_sc = - reinterpret_cast(sc); - if (host == nullptr || local_sc == nullptr || - strcmp(host, local_sc->target_name) != 0) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "local call host does not match target name"); +class grpc_local_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_local_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const char* target_name) + : grpc_channel_security_connector(GRPC_UDS_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + target_name_(gpr_strdup(target_name)) {} + + ~grpc_local_channel_security_connector() override { gpr_free(target_name_); } + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(true /* is_client */, &handshaker) == + TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); } - return true; -} -static void local_cancel_check_call_host(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast( + other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + return strcmp(target_name_, other->target_name_); + } + + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + local_check_peer(this, peer, auth_context, on_peer_checked); + } + + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + if (host == nullptr || strcmp(host, target_name_) != 0) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "local call host does not match target name"); + } + return true; + } -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc) { - if (channel_creds == nullptr || sc == nullptr || target_name == nullptr) { + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } + + const char* target_name() const { return target_name_; } + + private: + char* target_name_; +}; + +class grpc_local_server_security_connector final + : public grpc_server_security_connector { + public: + grpc_local_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_UDS_URL_SCHEME, + std::move(server_creds)) {} + ~grpc_local_server_security_connector() override = default; + + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_manager) override { + tsi_handshaker* handshaker = nullptr; + GPR_ASSERT(local_tsi_handshaker_create(false /* is_client */, + &handshaker) == TSI_OK); + grpc_handshake_manager_add( + handshake_manager, grpc_security_handshaker_create(handshaker, this)); + } + + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + local_check_peer(this, peer, auth_context, on_peer_checked); + } + + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_channel_args* args, const char* target_name) { + if (channel_creds == nullptr || target_name == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } // Check if local_connect_type is UDS. Only UDS is supported for now. grpc_local_credentials* creds = - reinterpret_cast(channel_creds); - if (creds->connect_type != UDS) { + static_cast(channel_creds.get()); + if (creds->connect_type() != UDS) { gpr_log(GPR_ERROR, "Invalid local channel type to " "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } // Check if target_name is a valid UDS address. const grpc_arg* server_uri_arg = @@ -196,51 +189,30 @@ grpc_security_status grpc_local_channel_security_connector_create( gpr_log(GPR_ERROR, "Invalid target_name to " "grpc_local_channel_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast( - gpr_zalloc(sizeof(grpc_local_channel_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_channel_vtable; - c->base.add_handshakers = local_channel_add_handshakers; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = local_check_call_host; - c->base.cancel_check_call_host = local_cancel_check_call_host; - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->target_name = gpr_strdup(target_name); - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted( + channel_creds, request_metadata_creds, target_name); } -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, - grpc_server_security_connector** sc) { - if (server_creds == nullptr || sc == nullptr) { +grpc_core::RefCountedPtr +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr server_creds) { + if (server_creds == nullptr) { gpr_log( GPR_ERROR, "Invalid arguments to grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } // Check if local_connect_type is UDS. Only UDS is supported for now. - grpc_local_server_credentials* creds = - reinterpret_cast(server_creds); - if (creds->connect_type != UDS) { + const grpc_local_server_credentials* creds = + static_cast(server_creds.get()); + if (creds->connect_type() != UDS) { gpr_log(GPR_ERROR, "Invalid local server type to " "grpc_local_server_security_connector_create()"); - return GRPC_SECURITY_ERROR; + return nullptr; } - auto c = static_cast( - gpr_zalloc(sizeof(grpc_local_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &local_server_vtable; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - c->base.base.url_scheme = - creds->connect_type == UDS ? GRPC_UDS_URL_SCHEME : nullptr; - c->base.add_handshakers = local_server_add_handshakers; - *sc = &c->base; - return GRPC_SECURITY_OK; + return grpc_core::MakeRefCounted( + std::move(server_creds)); } diff --git a/src/core/lib/security/security_connector/local/local_security_connector.h b/src/core/lib/security/security_connector/local/local_security_connector.h index 5369a2127a..6eee0ca9a6 100644 --- a/src/core/lib/security/security_connector/local/local_security_connector.h +++ b/src/core/lib/security/security_connector/local/local_security_connector.h @@ -34,13 +34,13 @@ * - sc: address of local channel security connector instance to be returned * from the method. * - * It returns GRPC_SECURITY_OK on success, and an error stauts code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_channel_args* args, const char* target_name, - grpc_channel_security_connector** sc); +grpc_core::RefCountedPtr +grpc_local_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_channel_args* args, const char* target_name); /** * This method creates a local server security connector. @@ -49,10 +49,11 @@ grpc_security_status grpc_local_channel_security_connector_create( * - sc: address of local server security connector instance to be returned from * the method. * - * It returns GRPC_SECURITY_OK on success, and an error status code on failure. + * It returns nullptr on failure. */ -grpc_security_status grpc_local_server_security_connector_create( - grpc_server_credentials* server_creds, grpc_server_security_connector** sc); +grpc_core::RefCountedPtr +grpc_local_server_security_connector_create( + grpc_core::RefCountedPtr server_creds); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_LOCAL_LOCAL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/security_connector.cc b/src/core/lib/security/security_connector/security_connector.cc index 02cecb0eb1..96a1960546 100644 --- a/src/core/lib/security/security_connector/security_connector.cc +++ b/src/core/lib/security/security_connector/security_connector.cc @@ -35,150 +35,67 @@ #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/security_connector/load_system_roots.h" +#include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/security/transport/security_handshaker.h" grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount( false, "security_connector_refcount"); -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - if (connector != nullptr) { - connector->add_handshakers(connector, interested_parties, handshake_mgr); - } -} - -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - if (sc == nullptr) { - GRPC_CLOSURE_SCHED(on_peer_checked, - GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check peer -- no security connector")); - tsi_peer_destruct(&peer); - } else { - sc->vtable->check_peer(sc, peer, auth_context, on_peer_checked); - } -} - -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other) { +grpc_server_security_connector::grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr server_creds) + : grpc_security_connector(url_scheme), + server_creds_(std::move(server_creds)) {} + +grpc_channel_security_connector::grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds) + : grpc_security_connector(url_scheme), + channel_creds_(std::move(channel_creds)), + request_metadata_creds_(std::move(request_metadata_creds)) {} +grpc_channel_security_connector::~grpc_channel_security_connector() {} + +int grpc_security_connector_cmp(const grpc_security_connector* sc, + const grpc_security_connector* other) { if (sc == nullptr || other == nullptr) return GPR_ICMP(sc, other); - int c = GPR_ICMP(sc->vtable, other->vtable); - if (c != 0) return c; - return sc->vtable->cmp(sc, other); + return sc->cmp(other); } -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2) { - GPR_ASSERT(sc1->channel_creds != nullptr); - GPR_ASSERT(sc2->channel_creds != nullptr); - int c = GPR_ICMP(sc1->channel_creds, sc2->channel_creds); - if (c != 0) return c; - c = GPR_ICMP(sc1->request_metadata_creds, sc2->request_metadata_creds); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->check_call_host, (void*)sc2->check_call_host); - if (c != 0) return c; - c = GPR_ICMP((void*)sc1->cancel_check_call_host, - (void*)sc2->cancel_check_call_host); +int grpc_channel_security_connector::channel_security_connector_cmp( + const grpc_channel_security_connector* other) const { + const grpc_channel_security_connector* other_sc = + static_cast(other); + GPR_ASSERT(channel_creds() != nullptr); + GPR_ASSERT(other_sc->channel_creds() != nullptr); + int c = GPR_ICMP(channel_creds(), other_sc->channel_creds()); if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); + return GPR_ICMP(request_metadata_creds(), other_sc->request_metadata_creds()); } -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2) { - GPR_ASSERT(sc1->server_creds != nullptr); - GPR_ASSERT(sc2->server_creds != nullptr); - int c = GPR_ICMP(sc1->server_creds, sc2->server_creds); - if (c != 0) return c; - return GPR_ICMP((void*)sc1->add_handshakers, (void*)sc2->add_handshakers); -} - -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error) { - if (sc == nullptr || sc->check_call_host == nullptr) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "cannot check call host -- no security connector"); - return true; - } - return sc->check_call_host(sc, host, auth_context, on_call_host_checked, - error); -} - -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - if (sc == nullptr || sc->cancel_check_call_host == nullptr) { - GRPC_ERROR_UNREF(error); - return; - } - sc->cancel_check_call_host(sc, on_call_host_checked, error); -} - -#ifndef NDEBUG -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc, const char* file, int line, - const char* reason) { - if (sc == nullptr) return nullptr; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p ref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val + 1, reason); - } -#else -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* sc) { - if (sc == nullptr) return nullptr; -#endif - gpr_ref(&sc->refcount); - return sc; -} - -#ifndef NDEBUG -void grpc_security_connector_unref(grpc_security_connector* sc, - const char* file, int line, - const char* reason) { - if (sc == nullptr) return; - if (grpc_trace_security_connector_refcount.enabled()) { - gpr_atm val = gpr_atm_no_barrier_load(&sc->refcount.count); - gpr_log(file, line, GPR_LOG_SEVERITY_DEBUG, - "SECURITY_CONNECTOR:%p unref %" PRIdPTR " -> %" PRIdPTR " %s", sc, - val, val - 1, reason); - } -#else -void grpc_security_connector_unref(grpc_security_connector* sc) { - if (sc == nullptr) return; -#endif - if (gpr_unref(&sc->refcount)) sc->vtable->destroy(sc); +int grpc_server_security_connector::server_security_connector_cmp( + const grpc_server_security_connector* other) const { + const grpc_server_security_connector* other_sc = + static_cast(other); + GPR_ASSERT(server_creds() != nullptr); + GPR_ASSERT(other_sc->server_creds() != nullptr); + return GPR_ICMP(server_creds(), other_sc->server_creds()); } static void connector_arg_destroy(void* p) { - GRPC_SECURITY_CONNECTOR_UNREF((grpc_security_connector*)p, - "connector_arg_destroy"); + static_cast(p)->Unref(DEBUG_LOCATION, + "connector_arg_destroy"); } static void* connector_arg_copy(void* p) { - return GRPC_SECURITY_CONNECTOR_REF((grpc_security_connector*)p, - "connector_arg_copy"); + return static_cast(p) + ->Ref(DEBUG_LOCATION, "connector_arg_copy") + .release(); } static int connector_cmp(void* a, void* b) { - return grpc_security_connector_cmp(static_cast(a), - static_cast(b)); + return static_cast(a)->cmp( + static_cast(b)); } static const grpc_arg_pointer_vtable connector_arg_vtable = { diff --git a/src/core/lib/security/security_connector/security_connector.h b/src/core/lib/security/security_connector/security_connector.h index 4c921a8793..d90aa8c4da 100644 --- a/src/core/lib/security/security_connector/security_connector.h +++ b/src/core/lib/security/security_connector/security_connector.h @@ -26,6 +26,7 @@ #include #include "src/core/lib/channel/handshaker.h" +#include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/iomgr/endpoint.h" #include "src/core/lib/iomgr/pollset.h" #include "src/core/lib/iomgr/tcp_server.h" @@ -34,8 +35,6 @@ extern grpc_core::DebugOnlyTraceFlag grpc_trace_security_connector_refcount; -/* --- status enum. --- */ - typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; /* --- security_connector object. --- @@ -43,54 +42,33 @@ typedef enum { GRPC_SECURITY_OK = 0, GRPC_SECURITY_ERROR } grpc_security_status; A security connector object represents away to configure the underlying transport security mechanism and check the resulting trusted peer. */ -typedef struct grpc_security_connector grpc_security_connector; - #define GRPC_ARG_SECURITY_CONNECTOR "grpc.security_connector" -typedef struct { - void (*destroy)(grpc_security_connector* sc); - void (*check_peer)(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - int (*cmp)(grpc_security_connector* sc, grpc_security_connector* other); -} grpc_security_connector_vtable; - -struct grpc_security_connector { - const grpc_security_connector_vtable* vtable; - gpr_refcount refcount; - const char* url_scheme; -}; +class grpc_security_connector + : public grpc_core::RefCounted { + public: + explicit grpc_security_connector(const char* url_scheme) + : grpc_core::RefCounted( + &grpc_trace_security_connector_refcount), + url_scheme_(url_scheme) {} + virtual ~grpc_security_connector() = default; + + /* Check the peer. Callee takes ownership of the peer object. + When done, sets *auth_context and invokes on_peer_checked. */ + virtual void check_peer( + tsi_peer peer, grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) GRPC_ABSTRACT; + + /* Compares two security connectors. */ + virtual int cmp(const grpc_security_connector* other) const GRPC_ABSTRACT; + + const char* url_scheme() const { return url_scheme_; } -/* Refcounting. */ -#ifndef NDEBUG -#define GRPC_SECURITY_CONNECTOR_REF(p, r) \ - grpc_security_connector_ref((p), __FILE__, __LINE__, (r)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) \ - grpc_security_connector_unref((p), __FILE__, __LINE__, (r)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy, const char* file, int line, - const char* reason); -void grpc_security_connector_unref(grpc_security_connector* policy, - const char* file, int line, - const char* reason); -#else -#define GRPC_SECURITY_CONNECTOR_REF(p, r) grpc_security_connector_ref((p)) -#define GRPC_SECURITY_CONNECTOR_UNREF(p, r) grpc_security_connector_unref((p)) -grpc_security_connector* grpc_security_connector_ref( - grpc_security_connector* policy); -void grpc_security_connector_unref(grpc_security_connector* policy); -#endif - -/* Check the peer. Callee takes ownership of the peer object. - When done, sets *auth_context and invokes on_peer_checked. */ -void grpc_security_connector_check_peer(grpc_security_connector* sc, - tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked); - -/* Compares two security connectors. */ -int grpc_security_connector_cmp(grpc_security_connector* sc, - grpc_security_connector* other); + GRPC_ABSTRACT_BASE_CLASS + + private: + const char* url_scheme_; +}; /* Util to encapsulate the connector in a channel arg. */ grpc_arg grpc_security_connector_to_arg(grpc_security_connector* sc); @@ -107,71 +85,89 @@ grpc_security_connector* grpc_security_connector_find_in_args( A channel security connector object represents a way to configure the underlying transport security mechanism on the client side. */ -typedef struct grpc_channel_security_connector grpc_channel_security_connector; - -struct grpc_channel_security_connector { - grpc_security_connector base; - grpc_channel_credentials* channel_creds; - grpc_call_credentials* request_metadata_creds; - bool (*check_call_host)(grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error); - void (*cancel_check_call_host)(grpc_channel_security_connector* sc, - grpc_closure* on_call_host_checked, - grpc_error* error); - void (*add_handshakers)(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_channel_security_connector : public grpc_security_connector { + public: + grpc_channel_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds); + ~grpc_channel_security_connector() override; + + /// Checks that the host that will be set for a call is acceptable. + /// Returns true if completed synchronously, in which case \a error will + /// be set to indicate the result. Otherwise, \a on_call_host_checked + /// will be invoked when complete. + virtual bool check_call_host(const char* host, + grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) GRPC_ABSTRACT; + /// Cancels a pending asychronous call to + /// grpc_channel_security_connector_check_call_host() with + /// \a on_call_host_checked as its callback. + virtual void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) GRPC_ABSTRACT; + /// Registers handshakers with \a handshake_mgr. + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_channel_credentials* channel_creds() const { + return channel_creds_.get(); + } + grpc_channel_credentials* mutable_channel_creds() { + return channel_creds_.get(); + } + const grpc_call_credentials* request_metadata_creds() const { + return request_metadata_creds_.get(); + } + grpc_call_credentials* mutable_request_metadata_creds() { + return request_metadata_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int channel_security_connector_cmp( + const grpc_channel_security_connector* other) const; + + private: + grpc_core::RefCountedPtr channel_creds_; + grpc_core::RefCountedPtr request_metadata_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_channel_security_connector_cmp(grpc_channel_security_connector* sc1, - grpc_channel_security_connector* sc2); - -/// Checks that the host that will be set for a call is acceptable. -/// Returns true if completed synchronously, in which case \a error will -/// be set to indicate the result. Otherwise, \a on_call_host_checked -/// will be invoked when complete. -bool grpc_channel_security_connector_check_call_host( - grpc_channel_security_connector* sc, const char* host, - grpc_auth_context* auth_context, grpc_closure* on_call_host_checked, - grpc_error** error); - -/// Cancels a pending asychronous call to -/// grpc_channel_security_connector_check_call_host() with -/// \a on_call_host_checked as its callback. -void grpc_channel_security_connector_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error); - -/* Registers handshakers with \a handshake_mgr. */ -void grpc_channel_security_connector_add_handshakers( - grpc_channel_security_connector* connector, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - /* --- server_security_connector object. --- A server security connector object represents a way to configure the underlying transport security mechanism on the server side. */ -typedef struct grpc_server_security_connector grpc_server_security_connector; - -struct grpc_server_security_connector { - grpc_security_connector base; - grpc_server_credentials* server_creds; - void (*add_handshakers)(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); +class grpc_server_security_connector : public grpc_security_connector { + public: + grpc_server_security_connector( + const char* url_scheme, + grpc_core::RefCountedPtr server_creds); + ~grpc_server_security_connector() override = default; + + virtual void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) + GRPC_ABSTRACT; + + const grpc_server_credentials* server_creds() const { + return server_creds_.get(); + } + grpc_server_credentials* mutable_server_creds() { + return server_creds_.get(); + } + + GRPC_ABSTRACT_BASE_CLASS + + protected: + // Helper methods to be used in subclasses. + int server_security_connector_cmp( + const grpc_server_security_connector* other) const; + + private: + grpc_core::RefCountedPtr server_creds_; }; -/// A helper function for use in grpc_security_connector_cmp() implementations. -int grpc_server_security_connector_cmp(grpc_server_security_connector* sc1, - grpc_server_security_connector* sc2); - -void grpc_server_security_connector_add_handshakers( - grpc_server_security_connector* sc, grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr); - #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SECURITY_CONNECTOR_H */ diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc index 20a9533dd1..14b2c4030f 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/ssl/ssl_credentials.h" @@ -39,172 +40,10 @@ #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security.h" -typedef struct { - grpc_channel_security_connector base; - tsi_ssl_client_handshaker_factory* client_handshaker_factory; - char* target_name; - char* overridden_target_name; - const verify_peer_options* verify_options; -} grpc_ssl_channel_security_connector; - -typedef struct { - grpc_server_security_connector base; - tsi_ssl_server_handshaker_factory* server_handshaker_factory; -} grpc_ssl_server_security_connector; - -static bool server_connector_has_cert_config_fetcher( - grpc_ssl_server_security_connector* c) { - GPR_ASSERT(c != nullptr); - grpc_ssl_server_credentials* server_creds = - reinterpret_cast(c->base.server_creds); - GPR_ASSERT(server_creds != nullptr); - return server_creds->certificate_config_fetcher.cb != nullptr; -} - -static void ssl_channel_destroy(grpc_security_connector* sc) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast(sc); - grpc_channel_credentials_unref(c->base.channel_creds); - grpc_call_credentials_unref(c->base.request_metadata_creds); - tsi_ssl_client_handshaker_factory_unref(c->client_handshaker_factory); - c->client_handshaker_factory = nullptr; - if (c->target_name != nullptr) gpr_free(c->target_name); - if (c->overridden_target_name != nullptr) gpr_free(c->overridden_target_name); - gpr_free(sc); -} - -static void ssl_server_destroy(grpc_security_connector* sc) { - grpc_ssl_server_security_connector* c = - reinterpret_cast(sc); - grpc_server_credentials_unref(c->base.server_creds); - tsi_ssl_server_handshaker_factory_unref(c->server_handshaker_factory); - c->server_handshaker_factory = nullptr; - gpr_free(sc); -} - -static void ssl_channel_add_handshakers(grpc_channel_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast(sc); - // Instantiate TSI handshaker. - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( - c->client_handshaker_factory, - c->overridden_target_name != nullptr ? c->overridden_target_name - : c->target_name, - &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -/* Attempts to replace the server_handshaker_factory with a new factory using - * the provided grpc_ssl_server_certificate_config. Should new factory creation - * fail, the existing factory will not be replaced. Returns true on success (new - * factory created). */ -static bool try_replace_server_handshaker_factory( - grpc_ssl_server_security_connector* sc, - const grpc_ssl_server_certificate_config* config) { - if (config == nullptr) { - gpr_log(GPR_ERROR, - "Server certificate config callback returned invalid (NULL) " - "config."); - return false; - } - gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); - - size_t num_alpn_protocols = 0; - const char** alpn_protocol_strings = - grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( - config->pem_key_cert_pairs, config->num_key_cert_pairs); - tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; - grpc_ssl_server_credentials* server_creds = - reinterpret_cast(sc->base.server_creds); - tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( - cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, - grpc_get_tsi_client_certificate_request_type( - server_creds->config.client_certificate_request), - grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast(num_alpn_protocols), &new_handshaker_factory); - gpr_free(cert_pairs); - gpr_free((void*)alpn_protocol_strings); - - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - return false; - } - tsi_ssl_server_handshaker_factory_unref(sc->server_handshaker_factory); - sc->server_handshaker_factory = new_handshaker_factory; - return true; -} - -/* Attempts to fetch the server certificate config if a callback is available. - * Current certificate config will continue to be used if the callback returns - * an error. Returns true if new credentials were sucessfully loaded. */ -static bool try_fetch_ssl_server_credentials( - grpc_ssl_server_security_connector* sc) { - grpc_ssl_server_certificate_config* certificate_config = nullptr; - bool status; - - GPR_ASSERT(sc != nullptr); - if (!server_connector_has_cert_config_fetcher(sc)) return false; - - grpc_ssl_server_credentials* server_creds = - reinterpret_cast(sc->base.server_creds); - grpc_ssl_certificate_config_reload_status cb_result = - server_creds->certificate_config_fetcher.cb( - server_creds->certificate_config_fetcher.user_data, - &certificate_config); - if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { - gpr_log(GPR_DEBUG, "No change in SSL server credentials."); - status = false; - } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { - status = try_replace_server_handshaker_factory(sc, certificate_config); - } else { - // Log error, continue using previously-loaded credentials. - gpr_log(GPR_ERROR, - "Failed fetching new server credentials, continuing to " - "use previously-loaded credentials."); - status = false; - } - - if (certificate_config != nullptr) { - grpc_ssl_server_certificate_config_destroy(certificate_config); - } - return status; -} - -static void ssl_server_add_handshakers(grpc_server_security_connector* sc, - grpc_pollset_set* interested_parties, - grpc_handshake_manager* handshake_mgr) { - grpc_ssl_server_security_connector* c = - reinterpret_cast(sc); - // Instantiate TSI handshaker. - try_fetch_ssl_server_credentials(c); - tsi_handshaker* tsi_hs = nullptr; - tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( - c->server_handshaker_factory, &tsi_hs); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", - tsi_result_to_string(result)); - return; - } - // Create handshakers. - grpc_handshake_manager_add( - handshake_mgr, grpc_security_handshaker_create(tsi_hs, &sc->base)); -} - -static grpc_error* ssl_check_peer(grpc_security_connector* sc, - const char* peer_name, const tsi_peer* peer, - grpc_auth_context** auth_context) { +namespace { +grpc_error* ssl_check_peer( + const char* peer_name, const tsi_peer* peer, + grpc_core::RefCountedPtr* auth_context) { #if TSI_OPENSSL_ALPN_SUPPORT /* Check the ALPN if ALPN is supported. */ const tsi_peer_property* p = @@ -230,245 +69,384 @@ static grpc_error* ssl_check_peer(grpc_security_connector* sc, return GRPC_ERROR_NONE; } -static void ssl_channel_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast(sc); - const char* target_name = c->overridden_target_name != nullptr - ? c->overridden_target_name - : c->target_name; - grpc_error* error = ssl_check_peer(sc, target_name, &peer, auth_context); - if (error == GRPC_ERROR_NONE && - c->verify_options->verify_peer_callback != nullptr) { - const tsi_peer_property* p = - tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); - if (p == nullptr) { - error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "Cannot check peer: missing pem cert property."); - } else { - char* peer_pem = static_cast(gpr_malloc(p->value.length + 1)); - memcpy(peer_pem, p->value.data, p->value.length); - peer_pem[p->value.length] = '\0'; - int callback_status = c->verify_options->verify_peer_callback( - target_name, peer_pem, - c->verify_options->verify_peer_callback_userdata); - gpr_free(peer_pem); - if (callback_status) { - char* msg; - gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", - callback_status); - error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); - gpr_free(msg); - } - } +class grpc_ssl_channel_security_connector final + : public grpc_channel_security_connector { + public: + grpc_ssl_channel_security_connector( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name) + : grpc_channel_security_connector(GRPC_SSL_URL_SCHEME, + std::move(channel_creds), + std::move(request_metadata_creds)), + overridden_target_name_(overridden_target_name == nullptr + ? nullptr + : gpr_strdup(overridden_target_name)), + verify_options_(&config->verify_options) { + char* port; + gpr_split_host_port(target_name, &target_name_, &port); + gpr_free(port); } - GRPC_CLOSURE_SCHED(on_peer_checked, error); - tsi_peer_destruct(&peer); -} -static void ssl_server_check_peer(grpc_security_connector* sc, tsi_peer peer, - grpc_auth_context** auth_context, - grpc_closure* on_peer_checked) { - grpc_error* error = ssl_check_peer(sc, nullptr, &peer, auth_context); - tsi_peer_destruct(&peer); - GRPC_CLOSURE_SCHED(on_peer_checked, error); -} + ~grpc_ssl_channel_security_connector() override { + tsi_ssl_client_handshaker_factory_unref(client_handshaker_factory_); + if (target_name_ != nullptr) gpr_free(target_name_); + if (overridden_target_name_ != nullptr) gpr_free(overridden_target_name_); + } -static int ssl_channel_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - grpc_ssl_channel_security_connector* c1 = - reinterpret_cast(sc1); - grpc_ssl_channel_security_connector* c2 = - reinterpret_cast(sc2); - int c = grpc_channel_security_connector_cmp(&c1->base, &c2->base); - if (c != 0) return c; - c = strcmp(c1->target_name, c2->target_name); - if (c != 0) return c; - return (c1->overridden_target_name == nullptr || - c2->overridden_target_name == nullptr) - ? GPR_ICMP(c1->overridden_target_name, c2->overridden_target_name) - : strcmp(c1->overridden_target_name, c2->overridden_target_name); -} + grpc_security_status InitializeHandshakerFactory( + const grpc_ssl_config* config, const char* pem_root_certs, + const tsi_ssl_root_certs_store* root_store, + tsi_ssl_session_cache* ssl_session_cache) { + bool has_key_cert_pair = + config->pem_key_cert_pair != nullptr && + config->pem_key_cert_pair->private_key != nullptr && + config->pem_key_cert_pair->cert_chain != nullptr; + tsi_ssl_client_handshaker_options options; + memset(&options, 0, sizeof(options)); + GPR_DEBUG_ASSERT(pem_root_certs != nullptr); + options.pem_root_certs = pem_root_certs; + options.root_store = root_store; + options.alpn_protocols = + grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + if (has_key_cert_pair) { + options.pem_key_cert_pair = config->pem_key_cert_pair; + } + options.cipher_suites = grpc_get_ssl_cipher_suites(); + options.session_cache = ssl_session_cache; + const tsi_result result = + tsi_create_ssl_client_handshaker_factory_with_options( + &options, &client_handshaker_factory_); + gpr_free((void*)options.alpn_protocols); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } + return GRPC_SECURITY_OK; + } -static int ssl_server_cmp(grpc_security_connector* sc1, - grpc_security_connector* sc2) { - return grpc_server_security_connector_cmp( - reinterpret_cast(sc1), - reinterpret_cast(sc2)); -} + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_client_handshaker_factory_create_handshaker( + client_handshaker_factory_, + overridden_target_name_ != nullptr ? overridden_target_name_ + : target_name_, + &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); + } -static bool ssl_channel_check_call_host(grpc_channel_security_connector* sc, - const char* host, - grpc_auth_context* auth_context, - grpc_closure* on_call_host_checked, - grpc_error** error) { - grpc_ssl_channel_security_connector* c = - reinterpret_cast(sc); - grpc_security_status status = GRPC_SECURITY_ERROR; - tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); - if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; - /* If the target name was overridden, then the original target_name was - 'checked' transitively during the previous peer check at the end of the - handshake. */ - if (c->overridden_target_name != nullptr && - strcmp(host, c->target_name) == 0) { - status = GRPC_SECURITY_OK; + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + const char* target_name = overridden_target_name_ != nullptr + ? overridden_target_name_ + : target_name_; + grpc_error* error = ssl_check_peer(target_name, &peer, auth_context); + if (error == GRPC_ERROR_NONE && + verify_options_->verify_peer_callback != nullptr) { + const tsi_peer_property* p = + tsi_peer_get_property_by_name(&peer, TSI_X509_PEM_CERT_PROPERTY); + if (p == nullptr) { + error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "Cannot check peer: missing pem cert property."); + } else { + char* peer_pem = static_cast(gpr_malloc(p->value.length + 1)); + memcpy(peer_pem, p->value.data, p->value.length); + peer_pem[p->value.length] = '\0'; + int callback_status = verify_options_->verify_peer_callback( + target_name, peer_pem, + verify_options_->verify_peer_callback_userdata); + gpr_free(peer_pem); + if (callback_status) { + char* msg; + gpr_asprintf(&msg, "Verify peer callback returned a failure (%d)", + callback_status); + error = GRPC_ERROR_CREATE_FROM_COPIED_STRING(msg); + gpr_free(msg); + } + } + } + GRPC_CLOSURE_SCHED(on_peer_checked, error); + tsi_peer_destruct(&peer); } - if (status != GRPC_SECURITY_OK) { - *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( - "call host does not match SSL server name"); + + int cmp(const grpc_security_connector* other_sc) const override { + auto* other = + reinterpret_cast(other_sc); + int c = channel_security_connector_cmp(other); + if (c != 0) return c; + c = strcmp(target_name_, other->target_name_); + if (c != 0) return c; + return (overridden_target_name_ == nullptr || + other->overridden_target_name_ == nullptr) + ? GPR_ICMP(overridden_target_name_, + other->overridden_target_name_) + : strcmp(overridden_target_name_, + other->overridden_target_name_); } - grpc_shallow_peer_destruct(&peer); - return true; -} -static void ssl_channel_cancel_check_call_host( - grpc_channel_security_connector* sc, grpc_closure* on_call_host_checked, - grpc_error* error) { - GRPC_ERROR_UNREF(error); -} + bool check_call_host(const char* host, grpc_auth_context* auth_context, + grpc_closure* on_call_host_checked, + grpc_error** error) override { + grpc_security_status status = GRPC_SECURITY_ERROR; + tsi_peer peer = grpc_shallow_peer_from_ssl_auth_context(auth_context); + if (grpc_ssl_host_matches_name(&peer, host)) status = GRPC_SECURITY_OK; + /* If the target name was overridden, then the original target_name was + 'checked' transitively during the previous peer check at the end of the + handshake. */ + if (overridden_target_name_ != nullptr && strcmp(host, target_name_) == 0) { + status = GRPC_SECURITY_OK; + } + if (status != GRPC_SECURITY_OK) { + *error = GRPC_ERROR_CREATE_FROM_STATIC_STRING( + "call host does not match SSL server name"); + } + grpc_shallow_peer_destruct(&peer); + return true; + } -static grpc_security_connector_vtable ssl_channel_vtable = { - ssl_channel_destroy, ssl_channel_check_peer, ssl_channel_cmp}; + void cancel_check_call_host(grpc_closure* on_call_host_checked, + grpc_error* error) override { + GRPC_ERROR_UNREF(error); + } -static grpc_security_connector_vtable ssl_server_vtable = { - ssl_server_destroy, ssl_server_check_peer, ssl_server_cmp}; + private: + tsi_ssl_client_handshaker_factory* client_handshaker_factory_; + char* target_name_; + char* overridden_target_name_; + const verify_peer_options* verify_options_; +}; + +class grpc_ssl_server_security_connector + : public grpc_server_security_connector { + public: + grpc_ssl_server_security_connector( + grpc_core::RefCountedPtr server_creds) + : grpc_server_security_connector(GRPC_SSL_URL_SCHEME, + std::move(server_creds)) {} + + ~grpc_ssl_server_security_connector() override { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, - const grpc_ssl_config* config, const char* target_name, - const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_channel_security_connector* c; - char* port; - bool has_key_cert_pair; - tsi_ssl_client_handshaker_options options; - memset(&options, 0, sizeof(options)); - options.alpn_protocols = - grpc_fill_alpn_protocol_strings(&options.num_alpn_protocols); + bool has_cert_config_fetcher() const { + return static_cast(server_creds()) + ->has_cert_config_fetcher(); + } - if (config == nullptr || target_name == nullptr) { - gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); - goto error; + const tsi_ssl_server_handshaker_factory* server_handshaker_factory() const { + return server_handshaker_factory_; } - if (config->pem_root_certs == nullptr) { - // Use default root certificates. - options.pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); - options.root_store = grpc_core::DefaultSslRootStore::GetRootStore(); - if (options.pem_root_certs == nullptr) { - gpr_log(GPR_ERROR, "Could not get default pem root certs."); - goto error; + + grpc_security_status InitializeHandshakerFactory() { + if (has_cert_config_fetcher()) { + // Load initial credentials from certificate_config_fetcher: + if (!try_fetch_ssl_server_credentials()) { + gpr_log(GPR_ERROR, + "Failed loading SSL server credentials from fetcher."); + return GRPC_SECURITY_ERROR; + } + } else { + auto* server_credentials = + static_cast(server_creds()); + size_t num_alpn_protocols = 0; + const char** alpn_protocol_strings = + grpc_fill_alpn_protocol_strings(&num_alpn_protocols); + const tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + server_credentials->config().pem_key_cert_pairs, + server_credentials->config().num_key_cert_pairs, + server_credentials->config().pem_root_certs, + grpc_get_tsi_client_certificate_request_type( + server_credentials->config().client_certificate_request), + grpc_get_ssl_cipher_suites(), alpn_protocol_strings, + static_cast(num_alpn_protocols), + &server_handshaker_factory_); + gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", + tsi_result_to_string(result)); + return GRPC_SECURITY_ERROR; + } } - } else { - options.pem_root_certs = config->pem_root_certs; - } - c = static_cast( - gpr_zalloc(sizeof(grpc_ssl_channel_security_connector))); - - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.vtable = &ssl_channel_vtable; - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.channel_creds = grpc_channel_credentials_ref(channel_creds); - c->base.request_metadata_creds = - grpc_call_credentials_ref(request_metadata_creds); - c->base.check_call_host = ssl_channel_check_call_host; - c->base.cancel_check_call_host = ssl_channel_cancel_check_call_host; - c->base.add_handshakers = ssl_channel_add_handshakers; - gpr_split_host_port(target_name, &c->target_name, &port); - gpr_free(port); - if (overridden_target_name != nullptr) { - c->overridden_target_name = gpr_strdup(overridden_target_name); + return GRPC_SECURITY_OK; } - c->verify_options = &config->verify_options; - has_key_cert_pair = config->pem_key_cert_pair != nullptr && - config->pem_key_cert_pair->private_key != nullptr && - config->pem_key_cert_pair->cert_chain != nullptr; - if (has_key_cert_pair) { - options.pem_key_cert_pair = config->pem_key_cert_pair; + void add_handshakers(grpc_pollset_set* interested_parties, + grpc_handshake_manager* handshake_mgr) override { + // Instantiate TSI handshaker. + try_fetch_ssl_server_credentials(); + tsi_handshaker* tsi_hs = nullptr; + tsi_result result = tsi_ssl_server_handshaker_factory_create_handshaker( + server_handshaker_factory_, &tsi_hs); + if (result != TSI_OK) { + gpr_log(GPR_ERROR, "Handshaker creation failed with error %s.", + tsi_result_to_string(result)); + return; + } + // Create handshakers. + grpc_handshake_manager_add(handshake_mgr, + grpc_security_handshaker_create(tsi_hs, this)); } - options.cipher_suites = grpc_get_ssl_cipher_suites(); - options.session_cache = ssl_session_cache; - result = tsi_create_ssl_client_handshaker_factory_with_options( - &options, &c->client_handshaker_factory); - if (result != TSI_OK) { - gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", - tsi_result_to_string(result)); - ssl_channel_destroy(&c->base.base); - *sc = nullptr; - goto error; + + void check_peer(tsi_peer peer, + grpc_core::RefCountedPtr* auth_context, + grpc_closure* on_peer_checked) override { + grpc_error* error = ssl_check_peer(nullptr, &peer, auth_context); + tsi_peer_destruct(&peer); + GRPC_CLOSURE_SCHED(on_peer_checked, error); } - *sc = &c->base; - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_OK; -error: - gpr_free((void*)options.alpn_protocols); - return GRPC_SECURITY_ERROR; -} + int cmp(const grpc_security_connector* other) const override { + return server_security_connector_cmp( + static_cast(other)); + } -static grpc_ssl_server_security_connector* -grpc_ssl_server_security_connector_initialize( - grpc_server_credentials* server_creds) { - grpc_ssl_server_security_connector* c = - static_cast( - gpr_zalloc(sizeof(grpc_ssl_server_security_connector))); - gpr_ref_init(&c->base.base.refcount, 1); - c->base.base.url_scheme = GRPC_SSL_URL_SCHEME; - c->base.base.vtable = &ssl_server_vtable; - c->base.add_handshakers = ssl_server_add_handshakers; - c->base.server_creds = grpc_server_credentials_ref(server_creds); - return c; -} + private: + /* Attempts to fetch the server certificate config if a callback is available. + * Current certificate config will continue to be used if the callback returns + * an error. Returns true if new credentials were sucessfully loaded. */ + bool try_fetch_ssl_server_credentials() { + grpc_ssl_server_certificate_config* certificate_config = nullptr; + bool status; + + if (!has_cert_config_fetcher()) return false; + + grpc_ssl_server_credentials* server_creds = + static_cast(this->mutable_server_creds()); + grpc_ssl_certificate_config_reload_status cb_result = + server_creds->FetchCertConfig(&certificate_config); + if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_UNCHANGED) { + gpr_log(GPR_DEBUG, "No change in SSL server credentials."); + status = false; + } else if (cb_result == GRPC_SSL_CERTIFICATE_CONFIG_RELOAD_NEW) { + status = try_replace_server_handshaker_factory(certificate_config); + } else { + // Log error, continue using previously-loaded credentials. + gpr_log(GPR_ERROR, + "Failed fetching new server credentials, continuing to " + "use previously-loaded credentials."); + status = false; + } -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* gsc, grpc_server_security_connector** sc) { - tsi_result result = TSI_OK; - grpc_ssl_server_credentials* server_credentials = - reinterpret_cast(gsc); - grpc_security_status retval = GRPC_SECURITY_OK; + if (certificate_config != nullptr) { + grpc_ssl_server_certificate_config_destroy(certificate_config); + } + return status; + } - GPR_ASSERT(server_credentials != nullptr); - GPR_ASSERT(sc != nullptr); - - grpc_ssl_server_security_connector* c = - grpc_ssl_server_security_connector_initialize(gsc); - if (server_connector_has_cert_config_fetcher(c)) { - // Load initial credentials from certificate_config_fetcher: - if (!try_fetch_ssl_server_credentials(c)) { - gpr_log(GPR_ERROR, "Failed loading SSL server credentials from fetcher."); - retval = GRPC_SECURITY_ERROR; + /* Attempts to replace the server_handshaker_factory with a new factory using + * the provided grpc_ssl_server_certificate_config. Should new factory + * creation fail, the existing factory will not be replaced. Returns true on + * success (new factory created). */ + bool try_replace_server_handshaker_factory( + const grpc_ssl_server_certificate_config* config) { + if (config == nullptr) { + gpr_log(GPR_ERROR, + "Server certificate config callback returned invalid (NULL) " + "config."); + return false; } - } else { + gpr_log(GPR_DEBUG, "Using new server certificate config (%p).", config); + size_t num_alpn_protocols = 0; const char** alpn_protocol_strings = grpc_fill_alpn_protocol_strings(&num_alpn_protocols); - result = tsi_create_ssl_server_handshaker_factory_ex( - server_credentials->config.pem_key_cert_pairs, - server_credentials->config.num_key_cert_pairs, - server_credentials->config.pem_root_certs, + tsi_ssl_pem_key_cert_pair* cert_pairs = grpc_convert_grpc_to_tsi_cert_pairs( + config->pem_key_cert_pairs, config->num_key_cert_pairs); + tsi_ssl_server_handshaker_factory* new_handshaker_factory = nullptr; + const grpc_ssl_server_credentials* server_creds = + static_cast(this->server_creds()); + GPR_DEBUG_ASSERT(config->pem_root_certs != nullptr); + tsi_result result = tsi_create_ssl_server_handshaker_factory_ex( + cert_pairs, config->num_key_cert_pairs, config->pem_root_certs, grpc_get_tsi_client_certificate_request_type( - server_credentials->config.client_certificate_request), + server_creds->config().client_certificate_request), grpc_get_ssl_cipher_suites(), alpn_protocol_strings, - static_cast(num_alpn_protocols), - &c->server_handshaker_factory); + static_cast(num_alpn_protocols), &new_handshaker_factory); + gpr_free(cert_pairs); gpr_free((void*)alpn_protocol_strings); + if (result != TSI_OK) { gpr_log(GPR_ERROR, "Handshaker factory creation failed with %s.", tsi_result_to_string(result)); - retval = GRPC_SECURITY_ERROR; + return false; } + set_server_handshaker_factory(new_handshaker_factory); + return true; + } + + void set_server_handshaker_factory( + tsi_ssl_server_handshaker_factory* new_factory) { + if (server_handshaker_factory_) { + tsi_ssl_server_handshaker_factory_unref(server_handshaker_factory_); + } + server_handshaker_factory_ = new_factory; + } + + tsi_ssl_server_handshaker_factory* server_handshaker_factory_ = nullptr; +}; +} // namespace + +grpc_core::RefCountedPtr +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, + const grpc_ssl_config* config, const char* target_name, + const char* overridden_target_name, + tsi_ssl_session_cache* ssl_session_cache) { + if (config == nullptr || target_name == nullptr) { + gpr_log(GPR_ERROR, "An ssl channel needs a config and a target name."); + return nullptr; } - if (retval == GRPC_SECURITY_OK) { - *sc = &c->base; + const char* pem_root_certs; + const tsi_ssl_root_certs_store* root_store; + if (config->pem_root_certs == nullptr) { + // Use default root certificates. + pem_root_certs = grpc_core::DefaultSslRootStore::GetPemRootCerts(); + if (pem_root_certs == nullptr) { + gpr_log(GPR_ERROR, "Could not get default pem root certs."); + return nullptr; + } + root_store = grpc_core::DefaultSslRootStore::GetRootStore(); } else { - if (c != nullptr) ssl_server_destroy(&c->base.base); - if (sc != nullptr) *sc = nullptr; + pem_root_certs = config->pem_root_certs; + root_store = nullptr; + } + + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + std::move(channel_creds), std::move(request_metadata_creds), config, + target_name, overridden_target_name); + const grpc_security_status result = c->InitializeHandshakerFactory( + config, pem_root_certs, root_store, ssl_session_cache); + if (result != GRPC_SECURITY_OK) { + return nullptr; } - return retval; + return c; +} + +grpc_core::RefCountedPtr +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr server_credentials) { + GPR_ASSERT(server_credentials != nullptr); + grpc_core::RefCountedPtr c = + grpc_core::MakeRefCounted( + std::move(server_credentials)); + const grpc_security_status retval = c->InitializeHandshakerFactory(); + if (retval != GRPC_SECURITY_OK) { + return nullptr; + } + return c; } diff --git a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h index 9b80590606..70e26e338a 100644 --- a/src/core/lib/security/security_connector/ssl/ssl_security_connector.h +++ b/src/core/lib/security/security_connector/ssl/ssl_security_connector.h @@ -25,6 +25,7 @@ #include "src/core/lib/security/security_connector/security_connector.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,20 +48,21 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_channel_security_connector_create( - grpc_channel_credentials* channel_creds, - grpc_call_credentials* request_metadata_creds, +grpc_core::RefCountedPtr +grpc_ssl_channel_security_connector_create( + grpc_core::RefCountedPtr channel_creds, + grpc_core::RefCountedPtr request_metadata_creds, const grpc_ssl_config* config, const char* target_name, const char* overridden_target_name, - tsi_ssl_session_cache* ssl_session_cache, - grpc_channel_security_connector** sc); + tsi_ssl_session_cache* ssl_session_cache); /* Config for ssl servers. */ typedef struct { - tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs; - size_t num_key_cert_pairs; - char* pem_root_certs; - grpc_ssl_client_certificate_request_type client_certificate_request; + tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs = nullptr; + size_t num_key_cert_pairs = 0; + char* pem_root_certs = nullptr; + grpc_ssl_client_certificate_request_type client_certificate_request = + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE; } grpc_ssl_server_config; /* Creates an SSL server_security_connector. @@ -69,9 +71,9 @@ typedef struct { This function returns GRPC_SECURITY_OK in case of success or a specific error code otherwise. */ -grpc_security_status grpc_ssl_server_security_connector_create( - grpc_server_credentials* server_credentials, - grpc_server_security_connector** sc); +grpc_core::RefCountedPtr +grpc_ssl_server_security_connector_create( + grpc_core::RefCountedPtr server_credentials); #endif /* GRPC_CORE_LIB_SECURITY_SECURITY_CONNECTOR_SSL_SSL_SECURITY_CONNECTOR_H \ */ diff --git a/src/core/lib/security/security_connector/ssl_utils.cc b/src/core/lib/security/security_connector/ssl_utils.cc index fbf41cfbc7..29030f07ad 100644 --- a/src/core/lib/security/security_connector/ssl_utils.cc +++ b/src/core/lib/security/security_connector/ssl_utils.cc @@ -30,6 +30,7 @@ #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/host_port.h" #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/iomgr/load_file.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/security_connector/load_system_roots.h" @@ -141,16 +142,17 @@ int grpc_ssl_host_matches_name(const tsi_peer* peer, const char* peer_name) { return r; } -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { +grpc_core::RefCountedPtr grpc_ssl_peer_to_auth_context( + const tsi_peer* peer) { size_t i; - grpc_auth_context* ctx = nullptr; const char* peer_identity_property_name = nullptr; /* The caller has checked the certificate type property. */ GPR_ASSERT(peer->property_count >= 1); - ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); grpc_auth_context_add_cstring_property( - ctx, GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, + ctx.get(), GRPC_TRANSPORT_SECURITY_TYPE_PROPERTY_NAME, GRPC_SSL_TRANSPORT_SECURITY_TYPE); for (i = 0; i < peer->property_count; i++) { const tsi_peer_property* prop = &peer->properties[i]; @@ -160,24 +162,26 @@ grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer) { if (peer_identity_property_name == nullptr) { peer_identity_property_name = GRPC_X509_CN_PROPERTY_NAME; } - grpc_auth_context_add_property(ctx, GRPC_X509_CN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) { peer_identity_property_name = GRPC_X509_SAN_PROPERTY_NAME; - grpc_auth_context_add_property(ctx, GRPC_X509_SAN_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_X509_PEM_CERT_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_X509_PEM_CERT_PROPERTY_NAME, + grpc_auth_context_add_property(ctx.get(), + GRPC_X509_PEM_CERT_PROPERTY_NAME, prop->value.data, prop->value.length); } else if (strcmp(prop->name, TSI_SSL_SESSION_REUSED_PEER_PROPERTY) == 0) { - grpc_auth_context_add_property(ctx, GRPC_SSL_SESSION_REUSED_PROPERTY, + grpc_auth_context_add_property(ctx.get(), + GRPC_SSL_SESSION_REUSED_PROPERTY, prop->value.data, prop->value.length); } } if (peer_identity_property_name != nullptr) { GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name( - ctx, peer_identity_property_name) == 1); + ctx.get(), peer_identity_property_name) == 1); } return ctx; } diff --git a/src/core/lib/security/security_connector/ssl_utils.h b/src/core/lib/security/security_connector/ssl_utils.h index 6f6d473311..c9cd1a1d9c 100644 --- a/src/core/lib/security/security_connector/ssl_utils.h +++ b/src/core/lib/security/security_connector/ssl_utils.h @@ -26,6 +26,7 @@ #include #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/tsi/ssl_transport_security.h" #include "src/core/tsi/transport_security_interface.h" @@ -47,7 +48,8 @@ grpc_get_tsi_client_certificate_request_type( const char** grpc_fill_alpn_protocol_strings(size_t* num_alpn_protocols); /* Exposed for testing only. */ -grpc_auth_context* grpc_ssl_peer_to_auth_context(const tsi_peer* peer); +grpc_core::RefCountedPtr grpc_ssl_peer_to_auth_context( + const tsi_peer* peer); tsi_peer grpc_shallow_peer_from_ssl_auth_context( const grpc_auth_context* auth_context); void grpc_shallow_peer_destruct(tsi_peer* peer); diff --git a/src/core/lib/security/transport/client_auth_filter.cc b/src/core/lib/security/transport/client_auth_filter.cc index 6955e8698e..66f86b8bc5 100644 --- a/src/core/lib/security/transport/client_auth_filter.cc +++ b/src/core/lib/security/transport/client_auth_filter.cc @@ -55,7 +55,7 @@ struct call_data { // that the memory is not initialized. void destroy() { grpc_credentials_mdelem_array_destroy(&md_array); - grpc_call_credentials_unref(creds); + creds.reset(); grpc_slice_unref_internal(host); grpc_slice_unref_internal(method); grpc_auth_metadata_context_reset(&auth_md_context); @@ -64,7 +64,7 @@ struct call_data { gpr_arena* arena; grpc_call_stack* owning_call; grpc_call_combiner* call_combiner; - grpc_call_credentials* creds = nullptr; + grpc_core::RefCountedPtr creds; grpc_slice host = grpc_empty_slice(); grpc_slice method = grpc_empty_slice(); /* pollset{_set} bound to this call; if we need to make external @@ -83,8 +83,18 @@ struct call_data { /* We can have a per-channel credentials. */ struct channel_data { - grpc_channel_security_connector* security_connector; - grpc_auth_context* auth_context; + channel_data(grpc_channel_security_connector* security_connector, + grpc_auth_context* auth_context) + : security_connector( + security_connector->Ref(DEBUG_LOCATION, "client_auth_filter")), + auth_context(auth_context->Ref(DEBUG_LOCATION, "client_auth_filter")) {} + ~channel_data() { + security_connector.reset(DEBUG_LOCATION, "client_auth_filter"); + auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); + } + + grpc_core::RefCountedPtr security_connector; + grpc_core::RefCountedPtr auth_context; }; } // namespace @@ -98,10 +108,11 @@ void grpc_auth_metadata_context_reset( gpr_free(const_cast(auth_md_context->method_name)); auth_md_context->method_name = nullptr; } - GRPC_AUTH_CONTEXT_UNREF( - (grpc_auth_context*)auth_md_context->channel_auth_context, - "grpc_auth_metadata_context"); - auth_md_context->channel_auth_context = nullptr; + if (auth_md_context->channel_auth_context != nullptr) { + const_cast(auth_md_context->channel_auth_context) + ->Unref(DEBUG_LOCATION, "grpc_auth_metadata_context"); + auth_md_context->channel_auth_context = nullptr; + } } static void add_error(grpc_error** combined, grpc_error* error) { @@ -175,7 +186,10 @@ void grpc_auth_metadata_context_build( auth_md_context->service_url = service_url; auth_md_context->method_name = method_name; auth_md_context->channel_auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "grpc_auth_metadata_context"); + auth_context == nullptr + ? nullptr + : auth_context->Ref(DEBUG_LOCATION, "grpc_auth_metadata_context") + .release(); gpr_free(service); gpr_free(host_and_port); } @@ -184,8 +198,8 @@ static void cancel_get_request_metadata(void* arg, grpc_error* error) { grpc_call_element* elem = static_cast(arg); call_data* calld = static_cast(elem->call_data); if (error != GRPC_ERROR_NONE) { - grpc_call_credentials_cancel_get_request_metadata( - calld->creds, &calld->md_array, GRPC_ERROR_REF(error)); + calld->creds->cancel_get_request_metadata(&calld->md_array, + GRPC_ERROR_REF(error)); } } @@ -197,7 +211,7 @@ static void send_security_metadata(grpc_call_element* elem, static_cast( batch->payload->context[GRPC_CONTEXT_SECURITY].value); grpc_call_credentials* channel_call_creds = - chand->security_connector->request_metadata_creds; + chand->security_connector->mutable_request_metadata_creds(); int call_creds_has_md = (ctx != nullptr) && (ctx->creds != nullptr); if (channel_call_creds == nullptr && !call_creds_has_md) { @@ -207,8 +221,9 @@ static void send_security_metadata(grpc_call_element* elem, } if (channel_call_creds != nullptr && call_creds_has_md) { - calld->creds = grpc_composite_call_credentials_create(channel_call_creds, - ctx->creds, nullptr); + calld->creds = grpc_core::RefCountedPtr( + grpc_composite_call_credentials_create(channel_call_creds, + ctx->creds.get(), nullptr)); if (calld->creds == nullptr) { grpc_transport_stream_op_batch_finish_with_failure( batch, @@ -220,22 +235,22 @@ static void send_security_metadata(grpc_call_element* elem, return; } } else { - calld->creds = grpc_call_credentials_ref( - call_creds_has_md ? ctx->creds : channel_call_creds); + calld->creds = + call_creds_has_md ? ctx->creds->Ref() : channel_call_creds->Ref(); } grpc_auth_metadata_context_build( - chand->security_connector->base.url_scheme, calld->host, calld->method, - chand->auth_context, &calld->auth_md_context); + chand->security_connector->url_scheme(), calld->host, calld->method, + chand->auth_context.get(), &calld->auth_md_context); GPR_ASSERT(calld->pollent != nullptr); GRPC_CALL_STACK_REF(calld->owning_call, "get_request_metadata"); GRPC_CLOSURE_INIT(&calld->async_result_closure, on_credentials_metadata, batch, grpc_schedule_on_exec_ctx); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - calld->creds, calld->pollent, calld->auth_md_context, - &calld->md_array, &calld->async_result_closure, &error)) { + if (calld->creds->get_request_metadata( + calld->pollent, calld->auth_md_context, &calld->md_array, + &calld->async_result_closure, &error)) { // Synchronous return; invoke on_credentials_metadata() directly. on_credentials_metadata(batch, error); GRPC_ERROR_UNREF(error); @@ -279,9 +294,8 @@ static void cancel_check_call_host(void* arg, grpc_error* error) { call_data* calld = static_cast(elem->call_data); channel_data* chand = static_cast(elem->channel_data); if (error != GRPC_ERROR_NONE) { - grpc_channel_security_connector_cancel_check_call_host( - chand->security_connector, &calld->async_result_closure, - GRPC_ERROR_REF(error)); + chand->security_connector->cancel_check_call_host( + &calld->async_result_closure, GRPC_ERROR_REF(error)); } } @@ -299,16 +313,16 @@ static void auth_start_transport_stream_op_batch( GPR_ASSERT(batch->payload->context != nullptr); if (batch->payload->context[GRPC_CONTEXT_SECURITY].value == nullptr) { batch->payload->context[GRPC_CONTEXT_SECURITY].value = - grpc_client_security_context_create(calld->arena); + grpc_client_security_context_create(calld->arena, /*creds=*/nullptr); batch->payload->context[GRPC_CONTEXT_SECURITY].destroy = grpc_client_security_context_destroy; } grpc_client_security_context* sec_ctx = static_cast( batch->payload->context[GRPC_CONTEXT_SECURITY].value); - GRPC_AUTH_CONTEXT_UNREF(sec_ctx->auth_context, "client auth filter"); + sec_ctx->auth_context.reset(DEBUG_LOCATION, "client_auth_filter"); sec_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "client_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "client_auth_filter"); } if (batch->send_initial_metadata) { @@ -327,8 +341,8 @@ static void auth_start_transport_stream_op_batch( grpc_schedule_on_exec_ctx); char* call_host = grpc_slice_to_c_string(calld->host); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_channel_security_connector_check_call_host( - chand->security_connector, call_host, chand->auth_context, + if (chand->security_connector->check_call_host( + call_host, chand->auth_context.get(), &calld->async_result_closure, &error)) { // Synchronous return; invoke on_host_checked() directly. on_host_checked(batch, error); @@ -374,6 +388,10 @@ static void destroy_call_elem(grpc_call_element* elem, /* Constructor for channel_data */ static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { + /* The first and the last filters tend to be implemented differently to + handle the case that there's no 'next' filter to call on the up or down + path */ + GPR_ASSERT(!args->is_last); grpc_security_connector* sc = grpc_security_connector_find_in_args(args->channel_args); if (sc == nullptr) { @@ -386,33 +404,15 @@ static grpc_error* init_channel_elem(grpc_channel_element* elem, return GRPC_ERROR_CREATE_FROM_STATIC_STRING( "Auth context missing from client auth filter args"); } - - /* grab pointers to our data from the channel element */ - channel_data* chand = static_cast(elem->channel_data); - - /* The first and the last filters tend to be implemented differently to - handle the case that there's no 'next' filter to call on the up or down - path */ - GPR_ASSERT(!args->is_last); - - /* initialize members */ - chand->security_connector = - reinterpret_cast( - GRPC_SECURITY_CONNECTOR_REF(sc, "client_auth_filter")); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "client_auth_filter"); + new (elem->channel_data) channel_data( + static_cast(sc), auth_context); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { - /* grab pointers to our data from the channel element */ channel_data* chand = static_cast(elem->channel_data); - grpc_channel_security_connector* sc = chand->security_connector; - if (sc != nullptr) { - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "client_auth_filter"); - } - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "client_auth_filter"); + chand->~channel_data(); } const grpc_channel_filter grpc_client_auth_filter = { diff --git a/src/core/lib/security/transport/security_handshaker.cc b/src/core/lib/security/transport/security_handshaker.cc index 854a1c4af9..48d6901e88 100644 --- a/src/core/lib/security/transport/security_handshaker.cc +++ b/src/core/lib/security/transport/security_handshaker.cc @@ -30,6 +30,7 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/channel/handshaker.h" #include "src/core/lib/channel/handshaker_registry.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/transport/secure_endpoint.h" #include "src/core/lib/security/transport/tsi_error.h" @@ -38,34 +39,62 @@ #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256 -typedef struct { +namespace { +struct security_handshaker { + security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector); + ~security_handshaker() { + gpr_mu_destroy(&mu); + tsi_handshaker_destroy(handshaker); + tsi_handshaker_result_destroy(handshaker_result); + if (endpoint_to_destroy != nullptr) { + grpc_endpoint_destroy(endpoint_to_destroy); + } + if (read_buffer_to_destroy != nullptr) { + grpc_slice_buffer_destroy_internal(read_buffer_to_destroy); + gpr_free(read_buffer_to_destroy); + } + gpr_free(handshake_buffer); + grpc_slice_buffer_destroy_internal(&outgoing); + auth_context.reset(DEBUG_LOCATION, "handshake"); + connector.reset(DEBUG_LOCATION, "handshake"); + } + + void Ref() { refs.Ref(); } + void Unref() { + if (refs.Unref()) { + grpc_core::Delete(this); + } + } + grpc_handshaker base; // State set at creation time. tsi_handshaker* handshaker; - grpc_security_connector* connector; + grpc_core::RefCountedPtr connector; gpr_mu mu; - gpr_refcount refs; + grpc_core::RefCount refs; - bool shutdown; + bool shutdown = false; // Endpoint and read buffer to destroy after a shutdown. - grpc_endpoint* endpoint_to_destroy; - grpc_slice_buffer* read_buffer_to_destroy; + grpc_endpoint* endpoint_to_destroy = nullptr; + grpc_slice_buffer* read_buffer_to_destroy = nullptr; // State saved while performing the handshake. - grpc_handshaker_args* args; - grpc_closure* on_handshake_done; + grpc_handshaker_args* args = nullptr; + grpc_closure* on_handshake_done = nullptr; - unsigned char* handshake_buffer; size_t handshake_buffer_size; + unsigned char* handshake_buffer; grpc_slice_buffer outgoing; grpc_closure on_handshake_data_sent_to_peer; grpc_closure on_handshake_data_received_from_peer; grpc_closure on_peer_checked; - grpc_auth_context* auth_context; - tsi_handshaker_result* handshaker_result; -} security_handshaker; + grpc_core::RefCountedPtr auth_context; + tsi_handshaker_result* handshaker_result = nullptr; +}; +} // namespace static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { size_t bytes_in_read_buffer = h->args->read_buffer->length; @@ -85,26 +114,6 @@ static size_t move_read_buffer_into_handshake_buffer(security_handshaker* h) { return bytes_in_read_buffer; } -static void security_handshaker_unref(security_handshaker* h) { - if (gpr_unref(&h->refs)) { - gpr_mu_destroy(&h->mu); - tsi_handshaker_destroy(h->handshaker); - tsi_handshaker_result_destroy(h->handshaker_result); - if (h->endpoint_to_destroy != nullptr) { - grpc_endpoint_destroy(h->endpoint_to_destroy); - } - if (h->read_buffer_to_destroy != nullptr) { - grpc_slice_buffer_destroy_internal(h->read_buffer_to_destroy); - gpr_free(h->read_buffer_to_destroy); - } - gpr_free(h->handshake_buffer); - grpc_slice_buffer_destroy_internal(&h->outgoing); - GRPC_AUTH_CONTEXT_UNREF(h->auth_context, "handshake"); - GRPC_SECURITY_CONNECTOR_UNREF(h->connector, "handshake"); - gpr_free(h); - } -} - // Set args fields to NULL, saving the endpoint and read buffer for // later destruction. static void cleanup_args_for_failure_locked(security_handshaker* h) { @@ -194,7 +203,7 @@ static void on_peer_checked_inner(security_handshaker* h, grpc_error* error) { tsi_handshaker_result_destroy(h->handshaker_result); h->handshaker_result = nullptr; // Add auth context to channel args. - grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context); + grpc_arg auth_context_arg = grpc_auth_context_to_arg(h->auth_context.get()); grpc_channel_args* tmp_args = h->args->args; h->args->args = grpc_channel_args_copy_and_add(tmp_args, &auth_context_arg, 1); @@ -211,7 +220,7 @@ static void on_peer_checked(void* arg, grpc_error* error) { gpr_mu_lock(&h->mu); on_peer_checked_inner(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } static grpc_error* check_peer_locked(security_handshaker* h) { @@ -222,8 +231,7 @@ static grpc_error* check_peer_locked(security_handshaker* h) { return grpc_set_tsi_error_result( GRPC_ERROR_CREATE_FROM_STATIC_STRING("Peer extraction failed"), result); } - grpc_security_connector_check_peer(h->connector, peer, &h->auth_context, - &h->on_peer_checked); + h->connector->check_peer(peer, &h->auth_context, &h->on_peer_checked); return GRPC_ERROR_NONE; } @@ -281,7 +289,7 @@ static void on_handshake_next_done_grpc_wrapper( if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -317,7 +325,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake read failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // Copy all slices received. @@ -329,7 +337,7 @@ static void on_handshake_data_received_from_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); } else { gpr_mu_unlock(&h->mu); } @@ -343,7 +351,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { h, GRPC_ERROR_CREATE_REFERENCING_FROM_STATIC_STRING( "Handshake write failed", &error, 1)); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } // We may be done. @@ -355,7 +363,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } } @@ -368,7 +376,7 @@ static void on_handshake_data_sent_to_peer(void* arg, grpc_error* error) { static void security_handshaker_destroy(grpc_handshaker* handshaker) { security_handshaker* h = reinterpret_cast(handshaker); - security_handshaker_unref(h); + h->Unref(); } static void security_handshaker_shutdown(grpc_handshaker* handshaker, @@ -393,14 +401,14 @@ static void security_handshaker_do_handshake(grpc_handshaker* handshaker, gpr_mu_lock(&h->mu); h->args = args; h->on_handshake_done = on_handshake_done; - gpr_ref(&h->refs); + h->Ref(); size_t bytes_received_size = move_read_buffer_into_handshake_buffer(h); grpc_error* error = do_handshaker_next_locked(h, h->handshake_buffer, bytes_received_size); if (error != GRPC_ERROR_NONE) { security_handshake_failed_locked(h, error); gpr_mu_unlock(&h->mu); - security_handshaker_unref(h); + h->Unref(); return; } gpr_mu_unlock(&h->mu); @@ -410,27 +418,32 @@ static const grpc_handshaker_vtable security_handshaker_vtable = { security_handshaker_destroy, security_handshaker_shutdown, security_handshaker_do_handshake, "security"}; -static grpc_handshaker* security_handshaker_create( - tsi_handshaker* handshaker, grpc_security_connector* connector) { - security_handshaker* h = static_cast( - gpr_zalloc(sizeof(security_handshaker))); - grpc_handshaker_init(&security_handshaker_vtable, &h->base); - h->handshaker = handshaker; - h->connector = GRPC_SECURITY_CONNECTOR_REF(connector, "handshake"); - gpr_mu_init(&h->mu); - gpr_ref_init(&h->refs, 1); - h->handshake_buffer_size = GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE; - h->handshake_buffer = - static_cast(gpr_malloc(h->handshake_buffer_size)); - GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer, - on_handshake_data_sent_to_peer, h, +namespace { +security_handshaker::security_handshaker(tsi_handshaker* handshaker, + grpc_security_connector* connector) + : handshaker(handshaker), + connector(connector->Ref(DEBUG_LOCATION, "handshake")), + handshake_buffer_size(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE), + handshake_buffer( + static_cast(gpr_malloc(handshake_buffer_size))) { + grpc_handshaker_init(&security_handshaker_vtable, &base); + gpr_mu_init(&mu); + grpc_slice_buffer_init(&outgoing); + GRPC_CLOSURE_INIT(&on_handshake_data_sent_to_peer, + ::on_handshake_data_sent_to_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer, - on_handshake_data_received_from_peer, h, + GRPC_CLOSURE_INIT(&on_handshake_data_received_from_peer, + ::on_handshake_data_received_from_peer, this, grpc_schedule_on_exec_ctx); - GRPC_CLOSURE_INIT(&h->on_peer_checked, on_peer_checked, h, + GRPC_CLOSURE_INIT(&on_peer_checked, ::on_peer_checked, this, grpc_schedule_on_exec_ctx); - grpc_slice_buffer_init(&h->outgoing); +} +} // namespace + +static grpc_handshaker* security_handshaker_create( + tsi_handshaker* handshaker, grpc_security_connector* connector) { + security_handshaker* h = + grpc_core::New(handshaker, connector); return &h->base; } @@ -477,8 +490,9 @@ static void client_handshaker_factory_add_handshakers( grpc_channel_security_connector* security_connector = reinterpret_cast( grpc_security_connector_find_in_args(args)); - grpc_channel_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void server_handshaker_factory_add_handshakers( @@ -488,8 +502,9 @@ static void server_handshaker_factory_add_handshakers( grpc_server_security_connector* security_connector = reinterpret_cast( grpc_security_connector_find_in_args(args)); - grpc_server_security_connector_add_handshakers( - security_connector, interested_parties, handshake_mgr); + if (security_connector) { + security_connector->add_handshakers(interested_parties, handshake_mgr); + } } static void handshaker_factory_destroy( diff --git a/src/core/lib/security/transport/server_auth_filter.cc b/src/core/lib/security/transport/server_auth_filter.cc index 362f49a584..f93eb4275e 100644 --- a/src/core/lib/security/transport/server_auth_filter.cc +++ b/src/core/lib/security/transport/server_auth_filter.cc @@ -39,8 +39,12 @@ enum async_state { }; struct channel_data { - grpc_auth_context* auth_context; - grpc_server_credentials* creds; + channel_data(grpc_auth_context* auth_context, grpc_server_credentials* creds) + : auth_context(auth_context->Ref()), creds(creds->Ref()) {} + ~channel_data() { auth_context.reset(DEBUG_LOCATION, "server_auth_filter"); } + + grpc_core::RefCountedPtr auth_context; + grpc_core::RefCountedPtr creds; }; struct call_data { @@ -58,7 +62,7 @@ struct call_data { grpc_server_security_context_create(args.arena); channel_data* chand = static_cast(elem->channel_data); server_ctx->auth_context = - GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter"); + chand->auth_context->Ref(DEBUG_LOCATION, "server_auth_filter"); if (args.context[GRPC_CONTEXT_SECURITY].value != nullptr) { args.context[GRPC_CONTEXT_SECURITY].destroy( args.context[GRPC_CONTEXT_SECURITY].value); @@ -208,7 +212,8 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { call_data* calld = static_cast(elem->call_data); grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch; if (error == GRPC_ERROR_NONE) { - if (chand->creds != nullptr && chand->creds->processor.process != nullptr) { + if (chand->creds != nullptr && + chand->creds->auth_metadata_processor().process != nullptr) { // We're calling out to the application, so we need to make sure // to drop the call combiner early if we get cancelled. GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem, @@ -218,9 +223,10 @@ static void recv_initial_metadata_ready(void* arg, grpc_error* error) { GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata"); calld->md = metadata_batch_to_md_array( batch->payload->recv_initial_metadata.recv_initial_metadata); - chand->creds->processor.process( - chand->creds->processor.state, chand->auth_context, - calld->md.metadata, calld->md.count, on_md_processing_done, elem); + chand->creds->auth_metadata_processor().process( + chand->creds->auth_metadata_processor().state, + chand->auth_context.get(), calld->md.metadata, calld->md.count, + on_md_processing_done, elem); return; } } @@ -290,23 +296,19 @@ static void destroy_call_elem(grpc_call_element* elem, static grpc_error* init_channel_elem(grpc_channel_element* elem, grpc_channel_element_args* args) { GPR_ASSERT(!args->is_last); - channel_data* chand = static_cast(elem->channel_data); grpc_auth_context* auth_context = grpc_find_auth_context_in_args(args->channel_args); GPR_ASSERT(auth_context != nullptr); - chand->auth_context = - GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter"); grpc_server_credentials* creds = grpc_find_server_credentials_in_args(args->channel_args); - chand->creds = grpc_server_credentials_ref(creds); + new (elem->channel_data) channel_data(auth_context, creds); return GRPC_ERROR_NONE; } /* Destructor for channel data */ static void destroy_channel_elem(grpc_channel_element* elem) { channel_data* chand = static_cast(elem->channel_data); - GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter"); - grpc_server_credentials_unref(chand->creds); + chand->~channel_data(); } const grpc_channel_filter grpc_server_auth_filter = { diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index d0abe441a6..4d0ed355ab 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -261,10 +261,10 @@ void MetadataCredentialsPluginWrapper::InvokePlugin( grpc_status_code* status_code, const char** error_details) { std::multimap metadata; - // const_cast is safe since the SecureAuthContext does not take owndership and - // the object is passed as a const ref to plugin_->GetMetadata. + // const_cast is safe since the SecureAuthContext only inc/dec the refcount + // and the object is passed as a const ref to plugin_->GetMetadata. SecureAuthContext cpp_channel_auth_context( - const_cast(context.channel_auth_context), false); + const_cast(context.channel_auth_context)); Status status = plugin_->GetMetadata(context.service_url, context.method_name, cpp_channel_auth_context, &metadata); diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index 613f1d6dc2..4918bd5a4d 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -24,6 +24,7 @@ #include #include +#include "src/core/lib/security/credentials/credentials.h" #include "src/cpp/server/thread_pool_interface.h" namespace grpc { @@ -31,7 +32,9 @@ namespace grpc { class SecureChannelCredentials final : public ChannelCredentials { public: explicit SecureChannelCredentials(grpc_channel_credentials* c_creds); - ~SecureChannelCredentials() { grpc_channel_credentials_release(c_creds_); } + ~SecureChannelCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_channel_credentials* GetRawCreds() { return c_creds_; } std::shared_ptr CreateChannel( @@ -51,7 +54,9 @@ class SecureChannelCredentials final : public ChannelCredentials { class SecureCallCredentials final : public CallCredentials { public: explicit SecureCallCredentials(grpc_call_credentials* c_creds); - ~SecureCallCredentials() { grpc_call_credentials_release(c_creds_); } + ~SecureCallCredentials() { + if (c_creds_ != nullptr) c_creds_->Unref(); + } grpc_call_credentials* GetRawCreds() { return c_creds_; } bool ApplyToCall(grpc_call* call) override; diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc index 1d66dd3d1f..7a2b5afed6 100644 --- a/src/cpp/common/secure_auth_context.cc +++ b/src/cpp/common/secure_auth_context.cc @@ -22,19 +22,12 @@ namespace grpc { -SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx, - bool take_ownership) - : ctx_(ctx), take_ownership_(take_ownership) {} - -SecureAuthContext::~SecureAuthContext() { - if (take_ownership_) grpc_auth_context_release(ctx_); -} - std::vector SecureAuthContext::GetPeerIdentity() const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector(); } - grpc_auth_property_iterator iter = grpc_auth_context_peer_identity(ctx_); + grpc_auth_property_iterator iter = + grpc_auth_context_peer_identity(ctx_.get()); std::vector identity; const grpc_auth_property* property = nullptr; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -45,20 +38,20 @@ std::vector SecureAuthContext::GetPeerIdentity() const { } grpc::string SecureAuthContext::GetPeerIdentityPropertyName() const { - if (!ctx_) { + if (ctx_ == nullptr) { return ""; } - const char* name = grpc_auth_context_peer_identity_property_name(ctx_); + const char* name = grpc_auth_context_peer_identity_property_name(ctx_.get()); return name == nullptr ? "" : name; } std::vector SecureAuthContext::FindPropertyValues( const grpc::string& name) const { - if (!ctx_) { + if (ctx_ == nullptr) { return std::vector(); } grpc_auth_property_iterator iter = - grpc_auth_context_find_properties_by_name(ctx_, name.c_str()); + grpc_auth_context_find_properties_by_name(ctx_.get(), name.c_str()); const grpc_auth_property* property = nullptr; std::vector values; while ((property = grpc_auth_property_iterator_next(&iter))) { @@ -68,9 +61,9 @@ std::vector SecureAuthContext::FindPropertyValues( } AuthPropertyIterator SecureAuthContext::begin() const { - if (ctx_) { + if (ctx_ != nullptr) { grpc_auth_property_iterator iter = - grpc_auth_context_property_iterator(ctx_); + grpc_auth_context_property_iterator(ctx_.get()); const grpc_auth_property* property = grpc_auth_property_iterator_next(&iter); return AuthPropertyIterator(property, &iter); @@ -85,19 +78,20 @@ AuthPropertyIterator SecureAuthContext::end() const { void SecureAuthContext::AddProperty(const grpc::string& key, const grpc::string_ref& value) { - if (!ctx_) return; - grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size()); + if (ctx_ == nullptr) return; + grpc_auth_context_add_property(ctx_.get(), key.c_str(), value.data(), + value.size()); } bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) { - if (!ctx_) return false; - return grpc_auth_context_set_peer_identity_property_name(ctx_, + if (ctx_ == nullptr) return false; + return grpc_auth_context_set_peer_identity_property_name(ctx_.get(), name.c_str()) != 0; } bool SecureAuthContext::IsPeerAuthenticated() const { - if (!ctx_) return false; - return grpc_auth_context_peer_is_authenticated(ctx_) != 0; + if (ctx_ == nullptr) return false; + return grpc_auth_context_peer_is_authenticated(ctx_.get()) != 0; } } // namespace grpc diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h index 142617959c..2e8f793721 100644 --- a/src/cpp/common/secure_auth_context.h +++ b/src/cpp/common/secure_auth_context.h @@ -21,15 +21,17 @@ #include -struct grpc_auth_context; +#include "src/core/lib/gprpp/ref_counted_ptr.h" +#include "src/core/lib/security/context/security_context.h" namespace grpc { class SecureAuthContext final : public AuthContext { public: - SecureAuthContext(grpc_auth_context* ctx, bool take_ownership); + explicit SecureAuthContext(grpc_auth_context* ctx) + : ctx_(ctx != nullptr ? ctx->Ref() : nullptr) {} - ~SecureAuthContext() override; + ~SecureAuthContext() override = default; bool IsPeerAuthenticated() const override; @@ -50,8 +52,7 @@ class SecureAuthContext final : public AuthContext { virtual bool SetPeerIdentityPropertyName(const grpc::string& name) override; private: - grpc_auth_context* ctx_; - bool take_ownership_; + grpc_core::RefCountedPtr ctx_; }; } // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc index bc1387c8d7..908c46629e 100644 --- a/src/cpp/common/secure_create_auth_context.cc +++ b/src/cpp/common/secure_create_auth_context.cc @@ -20,6 +20,7 @@ #include #include #include +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/cpp/common/secure_auth_context.h" namespace grpc { @@ -28,8 +29,8 @@ std::shared_ptr CreateAuthContext(grpc_call* call) { if (call == nullptr) { return std::shared_ptr(); } - return std::shared_ptr( - new SecureAuthContext(grpc_call_auth_context(call), true)); + grpc_core::RefCountedPtr ctx(grpc_call_auth_context(call)); + return std::make_shared(ctx.get()); } } // namespace grpc diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index ebb17def32..453e76eb25 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -61,7 +61,7 @@ void AuthMetadataProcessorAyncWrapper::InvokeProcessor( metadata.insert(std::make_pair(StringRefFromSlice(&md[i].key), StringRefFromSlice(&md[i].value))); } - SecureAuthContext context(ctx, false); + SecureAuthContext context(ctx); AuthMetadataProcessor::OutputMetadata consumed_metadata; AuthMetadataProcessor::OutputMetadata response_metadata; diff --git a/test/core/security/alts_security_connector_test.cc b/test/core/security/alts_security_connector_test.cc index 9378236338..bcba340821 100644 --- a/test/core/security/alts_security_connector_test.cc +++ b/test/core/security/alts_security_connector_test.cc @@ -33,40 +33,34 @@ using grpc_core::internal::grpc_alts_auth_context_from_tsi_peer; /* This file contains unit tests of grpc_alts_auth_context_from_tsi_peer(). */ static void test_invalid_input_failure() { - tsi_peer peer; - grpc_auth_context* ctx; - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(nullptr, &ctx) == - GRPC_SECURITY_ERROR); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, nullptr) == - GRPC_SECURITY_ERROR); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(nullptr); + GPR_ASSERT(ctx == nullptr); } static void test_empty_certificate_type_failure() { tsi_peer peer; - grpc_auth_context* ctx = nullptr; GPR_ASSERT(tsi_construct_peer(0, &peer) == TSI_OK); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) == - GRPC_SECURITY_ERROR); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); GPR_ASSERT(ctx == nullptr); tsi_peer_destruct(&peer); } static void test_empty_peer_property_failure() { tsi_peer peer; - grpc_auth_context* ctx; GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, &peer.properties[0]) == TSI_OK); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) == - GRPC_SECURITY_ERROR); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); GPR_ASSERT(ctx == nullptr); tsi_peer_destruct(&peer); } static void test_missing_rpc_protocol_versions_property_failure() { tsi_peer peer; - grpc_auth_context* ctx; GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, @@ -74,23 +68,22 @@ static void test_missing_rpc_protocol_versions_property_failure() { GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice", &peer.properties[1]) == TSI_OK); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) == - GRPC_SECURITY_ERROR); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); GPR_ASSERT(ctx == nullptr); tsi_peer_destruct(&peer); } static void test_unknown_peer_property_failure() { tsi_peer peer; - grpc_auth_context* ctx; GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, &peer.properties[0]) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( "unknown", "alice", &peer.properties[1]) == TSI_OK); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) == - GRPC_SECURITY_ERROR); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); GPR_ASSERT(ctx == nullptr); tsi_peer_destruct(&peer); } @@ -119,7 +112,6 @@ static bool test_identity(const grpc_auth_context* ctx, static void test_alts_peer_to_auth_context_success() { tsi_peer peer; - grpc_auth_context* ctx; GPR_ASSERT(tsi_construct_peer(kTsiAltsNumOfPeerProperties, &peer) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE, @@ -144,11 +136,12 @@ static void test_alts_peer_to_auth_context_success() { GRPC_SLICE_START_PTR(serialized_peer_versions)), GRPC_SLICE_LENGTH(serialized_peer_versions), &peer.properties[2]) == TSI_OK); - GPR_ASSERT(grpc_alts_auth_context_from_tsi_peer(&peer, &ctx) == - GRPC_SECURITY_OK); - GPR_ASSERT( - test_identity(ctx, TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, "alice")); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + grpc_core::RefCountedPtr ctx = + grpc_alts_auth_context_from_tsi_peer(&peer); + GPR_ASSERT(ctx != nullptr); + GPR_ASSERT(test_identity(ctx.get(), TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, + "alice")); + ctx.reset(DEBUG_LOCATION, "test"); grpc_slice_unref(serialized_peer_versions); tsi_peer_destruct(&peer); } diff --git a/test/core/security/auth_context_test.cc b/test/core/security/auth_context_test.cc index 9a39afb800..e7e0cb2ed9 100644 --- a/test/core/security/auth_context_test.cc +++ b/test/core/security/auth_context_test.cc @@ -19,114 +19,122 @@ #include #include "src/core/lib/gpr/string.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "test/core/util/test_config.h" #include static void test_empty_context(void) { - grpc_auth_context* ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); grpc_auth_property_iterator it; gpr_log(GPR_INFO, "test_empty_context"); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx) == nullptr); - it = grpc_auth_context_peer_identity(ctx); + GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) == + nullptr); + it = grpc_auth_context_peer_identity(ctx.get()); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_property_iterator(ctx); + it = grpc_auth_context_property_iterator(ctx.get()); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_find_properties_by_name(ctx, "foo"); + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "bar") == - 0); - GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx) == nullptr); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + GPR_ASSERT( + grpc_auth_context_set_peer_identity_property_name(ctx.get(), "bar") == 0); + GPR_ASSERT(grpc_auth_context_peer_identity_property_name(ctx.get()) == + nullptr); + ctx.reset(DEBUG_LOCATION, "test"); } static void test_simple_context(void) { - grpc_auth_context* ctx = grpc_auth_context_create(nullptr); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); grpc_auth_property_iterator it; size_t i; gpr_log(GPR_INFO, "test_simple_context"); GPR_ASSERT(ctx != nullptr); - grpc_auth_context_add_cstring_property(ctx, "name", "chapi"); - grpc_auth_context_add_cstring_property(ctx, "name", "chapo"); - grpc_auth_context_add_cstring_property(ctx, "foo", "bar"); - GPR_ASSERT(ctx->properties.count == 3); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "name") == - 1); - - GPR_ASSERT( - strcmp(grpc_auth_context_peer_identity_property_name(ctx), "name") == 0); - it = grpc_auth_context_property_iterator(ctx); - for (i = 0; i < ctx->properties.count; i++) { + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapo"); + grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar"); + GPR_ASSERT(ctx->properties().count == 3); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(), + "name") == 1); + + GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()), + "name") == 0); + it = grpc_auth_context_property_iterator(ctx.get()); + for (i = 0; i < ctx->properties().count; i++) { const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); - GPR_ASSERT(p == &ctx->properties.array[i]); + GPR_ASSERT(p == &ctx->properties().array[i]); } GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_find_properties_by_name(ctx, "foo"); + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[2]); + &ctx->properties().array[2]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_peer_identity(ctx); + it = grpc_auth_context_peer_identity(ctx.get()); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[0]); + &ctx->properties().array[0]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[1]); + &ctx->properties().array[1]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static void test_chained_context(void) { - grpc_auth_context* chained = grpc_auth_context_create(nullptr); - grpc_auth_context* ctx = grpc_auth_context_create(chained); + grpc_core::RefCountedPtr chained = + grpc_core::MakeRefCounted(nullptr); + grpc_auth_context* chained_ptr = chained.get(); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(std::move(chained)); + grpc_auth_property_iterator it; size_t i; gpr_log(GPR_INFO, "test_chained_context"); - GRPC_AUTH_CONTEXT_UNREF(chained, "chained"); - grpc_auth_context_add_cstring_property(chained, "name", "padapo"); - grpc_auth_context_add_cstring_property(chained, "foo", "baz"); - grpc_auth_context_add_cstring_property(ctx, "name", "chapi"); - grpc_auth_context_add_cstring_property(ctx, "name", "chap0"); - grpc_auth_context_add_cstring_property(ctx, "foo", "bar"); - GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx, "name") == - 1); - - GPR_ASSERT( - strcmp(grpc_auth_context_peer_identity_property_name(ctx), "name") == 0); - it = grpc_auth_context_property_iterator(ctx); - for (i = 0; i < ctx->properties.count; i++) { + grpc_auth_context_add_cstring_property(chained_ptr, "name", "padapo"); + grpc_auth_context_add_cstring_property(chained_ptr, "foo", "baz"); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx.get(), "name", "chap0"); + grpc_auth_context_add_cstring_property(ctx.get(), "foo", "bar"); + GPR_ASSERT(grpc_auth_context_set_peer_identity_property_name(ctx.get(), + "name") == 1); + + GPR_ASSERT(strcmp(grpc_auth_context_peer_identity_property_name(ctx.get()), + "name") == 0); + it = grpc_auth_context_property_iterator(ctx.get()); + for (i = 0; i < ctx->properties().count; i++) { const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); - GPR_ASSERT(p == &ctx->properties.array[i]); + GPR_ASSERT(p == &ctx->properties().array[i]); } - for (i = 0; i < chained->properties.count; i++) { + for (i = 0; i < chained_ptr->properties().count; i++) { const grpc_auth_property* p = grpc_auth_property_iterator_next(&it); - GPR_ASSERT(p == &chained->properties.array[i]); + GPR_ASSERT(p == &chained_ptr->properties().array[i]); } GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_find_properties_by_name(ctx, "foo"); + it = grpc_auth_context_find_properties_by_name(ctx.get(), "foo"); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[2]); + &ctx->properties().array[2]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &chained->properties.array[1]); + &chained_ptr->properties().array[1]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - it = grpc_auth_context_peer_identity(ctx); + it = grpc_auth_context_peer_identity(ctx.get()); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[0]); + &ctx->properties().array[0]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &ctx->properties.array[1]); + &ctx->properties().array[1]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == - &chained->properties.array[0]); + &chained_ptr->properties().array[0]); GPR_ASSERT(grpc_auth_property_iterator_next(&it) == nullptr); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } int main(int argc, char** argv) { diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc index a7a6050ec0..b3a8161786 100644 --- a/test/core/security/credentials_test.cc +++ b/test/core/security/credentials_test.cc @@ -46,19 +46,6 @@ using grpc_core::internal::grpc_flush_cached_google_default_credentials; using grpc_core::internal::set_gce_tenancy_checker_for_testing; -/* -- Mock channel credentials. -- */ - -static grpc_channel_credentials* grpc_mock_channel_credentials_create( - const grpc_channel_credentials_vtable* vtable) { - grpc_channel_credentials* c = - static_cast(gpr_malloc(sizeof(*c))); - memset(c, 0, sizeof(*c)); - c->type = "mock"; - c->vtable = vtable; - gpr_ref_init(&c->refcount, 1); - return c; -} - /* -- Constants. -- */ static const char test_google_iam_authorization_token[] = "blahblahblhahb"; @@ -377,9 +364,9 @@ static void run_request_metadata_test(grpc_call_credentials* creds, grpc_auth_metadata_context auth_md_ctx, request_metadata_state* state) { grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - creds, &state->pollent, auth_md_ctx, &state->md_array, - &state->on_request_metadata, &error)) { + if (creds->get_request_metadata(&state->pollent, auth_md_ctx, + &state->md_array, &state->on_request_metadata, + &error)) { // Synchronous result. Invoke the callback directly. check_request_metadata(state, error); GRPC_ERROR_UNREF(error); @@ -400,7 +387,7 @@ static void test_google_iam_creds(void) { grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, nullptr, nullptr}; run_request_metadata_test(creds, auth_md_ctx, state); - grpc_call_credentials_unref(creds); + creds->Unref(); } static void test_access_token_creds(void) { @@ -412,28 +399,36 @@ static void test_access_token_creds(void) { grpc_access_token_credentials_create("blah", nullptr); grpc_auth_metadata_context auth_md_ctx = {test_service_url, test_method, nullptr, nullptr}; - GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); + GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); run_request_metadata_test(creds, auth_md_ctx, state); - grpc_call_credentials_unref(creds); + creds->Unref(); } -static grpc_security_status check_channel_oauth2_create_security_connector( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, - const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - GPR_ASSERT(strcmp(c->type, "mock") == 0); - GPR_ASSERT(call_creds != nullptr); - GPR_ASSERT(strcmp(call_creds->type, GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); - return GRPC_SECURITY_OK; -} +namespace { +class check_channel_oauth2 final : public grpc_channel_credentials { + public: + check_channel_oauth2() : grpc_channel_credentials("mock") {} + ~check_channel_oauth2() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override { + GPR_ASSERT(strcmp(type(), "mock") == 0); + GPR_ASSERT(call_creds != nullptr); + GPR_ASSERT(strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == + 0); + return nullptr; + } +}; +} // namespace static void test_channel_oauth2_composite_creds(void) { grpc_core::ExecCtx exec_ctx; grpc_channel_args* new_args; - grpc_channel_credentials_vtable vtable = { - nullptr, check_channel_oauth2_create_security_connector, nullptr}; grpc_channel_credentials* channel_creds = - grpc_mock_channel_credentials_create(&vtable); + grpc_core::New(); grpc_call_credentials* oauth2_creds = grpc_access_token_credentials_create("blah", nullptr); grpc_channel_credentials* channel_oauth2_creds = @@ -441,9 +436,8 @@ static void test_channel_oauth2_composite_creds(void) { nullptr); grpc_channel_credentials_release(channel_creds); grpc_call_credentials_release(oauth2_creds); - GPR_ASSERT(grpc_channel_credentials_create_security_connector( - channel_oauth2_creds, nullptr, nullptr, nullptr, &new_args) == - GRPC_SECURITY_OK); + channel_oauth2_creds->create_security_connector(nullptr, nullptr, nullptr, + &new_args); grpc_channel_credentials_release(channel_oauth2_creds); } @@ -467,47 +461,54 @@ static void test_oauth2_google_iam_composite_creds(void) { grpc_call_credentials* composite_creds = grpc_composite_call_credentials_create(oauth2_creds, google_iam_creds, nullptr); - grpc_call_credentials_unref(oauth2_creds); - grpc_call_credentials_unref(google_iam_creds); - GPR_ASSERT( - strcmp(composite_creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); - const grpc_call_credentials_array* creds_array = - grpc_composite_call_credentials_get_credentials(composite_creds); - GPR_ASSERT(creds_array->num_creds == 2); - GPR_ASSERT(strcmp(creds_array->creds_array[0]->type, + oauth2_creds->Unref(); + google_iam_creds->Unref(); + GPR_ASSERT(strcmp(composite_creds->type(), + GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); + const grpc_call_credentials_array& creds_array = + static_cast(composite_creds) + ->inner(); + GPR_ASSERT(creds_array.size() == 2); + GPR_ASSERT(strcmp(creds_array.get(0)->type(), GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); - GPR_ASSERT(strcmp(creds_array->creds_array[1]->type, - GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0); + GPR_ASSERT( + strcmp(creds_array.get(1)->type(), GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0); run_request_metadata_test(composite_creds, auth_md_ctx, state); - grpc_call_credentials_unref(composite_creds); + composite_creds->Unref(); } -static grpc_security_status -check_channel_oauth2_google_iam_create_security_connector( - grpc_channel_credentials* c, grpc_call_credentials* call_creds, - const char* target, const grpc_channel_args* args, - grpc_channel_security_connector** sc, grpc_channel_args** new_args) { - const grpc_call_credentials_array* creds_array; - GPR_ASSERT(strcmp(c->type, "mock") == 0); - GPR_ASSERT(call_creds != nullptr); - GPR_ASSERT(strcmp(call_creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == - 0); - creds_array = grpc_composite_call_credentials_get_credentials(call_creds); - GPR_ASSERT(strcmp(creds_array->creds_array[0]->type, - GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); - GPR_ASSERT(strcmp(creds_array->creds_array[1]->type, - GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0); - return GRPC_SECURITY_OK; -} +namespace { +class check_channel_oauth2_google_iam final : public grpc_channel_credentials { + public: + check_channel_oauth2_google_iam() : grpc_channel_credentials("mock") {} + ~check_channel_oauth2_google_iam() override = default; + + grpc_core::RefCountedPtr + create_security_connector( + grpc_core::RefCountedPtr call_creds, + const char* target, const grpc_channel_args* args, + grpc_channel_args** new_args) override { + GPR_ASSERT(strcmp(type(), "mock") == 0); + GPR_ASSERT(call_creds != nullptr); + GPR_ASSERT( + strcmp(call_creds->type(), GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0); + const grpc_call_credentials_array& creds_array = + static_cast(call_creds.get()) + ->inner(); + GPR_ASSERT(strcmp(creds_array.get(0)->type(), + GRPC_CALL_CREDENTIALS_TYPE_OAUTH2) == 0); + GPR_ASSERT(strcmp(creds_array.get(1)->type(), + GRPC_CALL_CREDENTIALS_TYPE_IAM) == 0); + return nullptr; + } +}; +} // namespace static void test_channel_oauth2_google_iam_composite_creds(void) { grpc_core::ExecCtx exec_ctx; grpc_channel_args* new_args; - grpc_channel_credentials_vtable vtable = { - nullptr, check_channel_oauth2_google_iam_create_security_connector, - nullptr}; grpc_channel_credentials* channel_creds = - grpc_mock_channel_credentials_create(&vtable); + grpc_core::New(); grpc_call_credentials* oauth2_creds = grpc_access_token_credentials_create("blah", nullptr); grpc_channel_credentials* channel_oauth2_creds = @@ -524,9 +525,8 @@ static void test_channel_oauth2_google_iam_composite_creds(void) { grpc_channel_credentials_release(channel_oauth2_creds); grpc_call_credentials_release(google_iam_creds); - GPR_ASSERT(grpc_channel_credentials_create_security_connector( - channel_oauth2_iam_creds, nullptr, nullptr, nullptr, - &new_args) == GRPC_SECURITY_OK); + channel_oauth2_iam_creds->create_security_connector(nullptr, nullptr, nullptr, + &new_args); grpc_channel_credentials_release(channel_oauth2_iam_creds); } @@ -578,7 +578,7 @@ static int httpcli_get_should_not_be_called(const grpc_httpcli_request* request, return 1; } -static void test_compute_engine_creds_success(void) { +static void test_compute_engine_creds_success() { grpc_core::ExecCtx exec_ctx; expected_md emd[] = { {"authorization", "Bearer ya29.AHES6ZRN3-HlhAPya30GnW_bHSb_"}}; @@ -603,7 +603,7 @@ static void test_compute_engine_creds_success(void) { run_request_metadata_test(creds, auth_md_ctx, state); grpc_core::ExecCtx::Get()->Flush(); - grpc_call_credentials_unref(creds); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); } @@ -620,7 +620,7 @@ static void test_compute_engine_creds_failure(void) { grpc_httpcli_set_override(compute_engine_httpcli_get_failure_override, httpcli_post_should_not_be_called); run_request_metadata_test(creds, auth_md_ctx, state); - grpc_call_credentials_unref(creds); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); } @@ -692,7 +692,7 @@ static void test_refresh_token_creds_success(void) { run_request_metadata_test(creds, auth_md_ctx, state); grpc_core::ExecCtx::Get()->Flush(); - grpc_call_credentials_unref(creds); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); } @@ -709,7 +709,7 @@ static void test_refresh_token_creds_failure(void) { grpc_httpcli_set_override(httpcli_get_should_not_be_called, refresh_token_httpcli_post_failure); run_request_metadata_test(creds, auth_md_ctx, state); - grpc_call_credentials_unref(creds); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); } @@ -762,7 +762,7 @@ static char* encode_and_sign_jwt_should_not_be_called( static grpc_service_account_jwt_access_credentials* creds_as_jwt( grpc_call_credentials* creds) { GPR_ASSERT(creds != nullptr); - GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_JWT) == 0); + GPR_ASSERT(strcmp(creds->type(), GRPC_CALL_CREDENTIALS_TYPE_JWT) == 0); return reinterpret_cast(creds); } @@ -773,7 +773,7 @@ static void test_jwt_creds_lifetime(void) { grpc_call_credentials* jwt_creds = grpc_service_account_jwt_access_credentials_create( json_key_string, grpc_max_auth_token_lifetime(), nullptr); - GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime, + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), grpc_max_auth_token_lifetime()) == 0); grpc_call_credentials_release(jwt_creds); @@ -782,8 +782,8 @@ static void test_jwt_creds_lifetime(void) { GPR_ASSERT(gpr_time_cmp(grpc_max_auth_token_lifetime(), token_lifetime) > 0); jwt_creds = grpc_service_account_jwt_access_credentials_create( json_key_string, token_lifetime, nullptr); - GPR_ASSERT( - gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime, token_lifetime) == 0); + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), + token_lifetime) == 0); grpc_call_credentials_release(jwt_creds); // Cropped lifetime. @@ -791,7 +791,7 @@ static void test_jwt_creds_lifetime(void) { token_lifetime = gpr_time_add(grpc_max_auth_token_lifetime(), add_to_max); jwt_creds = grpc_service_account_jwt_access_credentials_create( json_key_string, token_lifetime, nullptr); - GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime, + GPR_ASSERT(gpr_time_cmp(creds_as_jwt(jwt_creds)->jwt_lifetime(), grpc_max_auth_token_lifetime()) == 0); grpc_call_credentials_release(jwt_creds); @@ -834,7 +834,7 @@ static void test_jwt_creds_success(void) { run_request_metadata_test(creds, auth_md_ctx, state); grpc_core::ExecCtx::Get()->Flush(); - grpc_call_credentials_unref(creds); + creds->Unref(); gpr_free(json_key_string); gpr_free(expected_md_value); grpc_jwt_encode_and_sign_set_override(nullptr); @@ -856,7 +856,7 @@ static void test_jwt_creds_signing_failure(void) { run_request_metadata_test(creds, auth_md_ctx, state); gpr_free(json_key_string); - grpc_call_credentials_unref(creds); + creds->Unref(); grpc_jwt_encode_and_sign_set_override(nullptr); } @@ -875,8 +875,6 @@ static void set_google_default_creds_env_var_with_file_contents( static void test_google_default_creds_auth_key(void) { grpc_core::ExecCtx exec_ctx; - grpc_service_account_jwt_access_credentials* jwt; - grpc_google_default_channel_credentials* default_creds; grpc_composite_channel_credentials* creds; char* json_key = test_json_key_str(); grpc_flush_cached_google_default_credentials(); @@ -885,37 +883,39 @@ static void test_google_default_creds_auth_key(void) { gpr_free(json_key); creds = reinterpret_cast( grpc_google_default_credentials_create()); - default_creds = reinterpret_cast( - creds->inner_creds); - GPR_ASSERT(default_creds->ssl_creds != nullptr); - jwt = reinterpret_cast( - creds->call_creds); + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* jwt = + reinterpret_cast( + creds->call_creds()); GPR_ASSERT( - strcmp(jwt->key.client_id, + strcmp(jwt->key().client_id, "777-abaslkan11hlb6nmim3bpspl31ud.apps.googleusercontent.com") == 0); - grpc_channel_credentials_unref(&creds->base); + creds->Unref(); gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ } static void test_google_default_creds_refresh_token(void) { grpc_core::ExecCtx exec_ctx; - grpc_google_refresh_token_credentials* refresh; - grpc_google_default_channel_credentials* default_creds; grpc_composite_channel_credentials* creds; grpc_flush_cached_google_default_credentials(); set_google_default_creds_env_var_with_file_contents( "refresh_token_google_default_creds", test_refresh_token_str); creds = reinterpret_cast( grpc_google_default_credentials_create()); - default_creds = reinterpret_cast( - creds->inner_creds); - GPR_ASSERT(default_creds->ssl_creds != nullptr); - refresh = reinterpret_cast( - creds->call_creds); - GPR_ASSERT(strcmp(refresh->refresh_token.client_id, + auto* default_creds = + reinterpret_cast( + creds->inner_creds()); + GPR_ASSERT(default_creds->ssl_creds() != nullptr); + auto* refresh = + reinterpret_cast( + creds->call_creds()); + GPR_ASSERT(strcmp(refresh->refresh_token().client_id, "32555999999.apps.googleusercontent.com") == 0); - grpc_channel_credentials_unref(&creds->base); + creds->Unref(); gpr_setenv(GRPC_GOOGLE_CREDENTIALS_ENV_VAR, ""); /* Reset. */ } @@ -965,16 +965,16 @@ static void test_google_default_creds_gce(void) { /* Verify that the default creds actually embeds a GCE creds. */ GPR_ASSERT(creds != nullptr); - GPR_ASSERT(creds->call_creds != nullptr); + GPR_ASSERT(creds->call_creds() != nullptr); grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, httpcli_post_should_not_be_called); - run_request_metadata_test(creds->call_creds, auth_md_ctx, state); + run_request_metadata_test(creds->mutable_call_creds(), auth_md_ctx, state); grpc_core::ExecCtx::Get()->Flush(); GPR_ASSERT(g_test_gce_tenancy_checker_called == true); /* Cleanup. */ - grpc_channel_credentials_unref(&creds->base); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); grpc_override_well_known_credentials_path_getter(nullptr); } @@ -1003,14 +1003,14 @@ static void test_google_default_creds_non_gce(void) { grpc_google_default_credentials_create()); /* Verify that the default creds actually embeds a GCE creds. */ GPR_ASSERT(creds != nullptr); - GPR_ASSERT(creds->call_creds != nullptr); + GPR_ASSERT(creds->call_creds() != nullptr); grpc_httpcli_set_override(compute_engine_httpcli_get_success_override, httpcli_post_should_not_be_called); - run_request_metadata_test(creds->call_creds, auth_md_ctx, state); + run_request_metadata_test(creds->mutable_call_creds(), auth_md_ctx, state); grpc_core::ExecCtx::Get()->Flush(); GPR_ASSERT(g_test_gce_tenancy_checker_called == true); /* Cleanup. */ - grpc_channel_credentials_unref(&creds->base); + creds->Unref(); grpc_httpcli_set_override(nullptr, nullptr); grpc_override_well_known_credentials_path_getter(nullptr); } @@ -1121,7 +1121,7 @@ static void test_metadata_plugin_success(void) { GPR_ASSERT(state == PLUGIN_INITIAL_STATE); run_request_metadata_test(creds, auth_md_ctx, md_state); GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE); - grpc_call_credentials_unref(creds); + creds->Unref(); GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE); } @@ -1149,7 +1149,7 @@ static void test_metadata_plugin_failure(void) { GPR_ASSERT(state == PLUGIN_INITIAL_STATE); run_request_metadata_test(creds, auth_md_ctx, md_state); GPR_ASSERT(state == PLUGIN_GET_METADATA_CALLED_STATE); - grpc_call_credentials_unref(creds); + creds->Unref(); GPR_ASSERT(state == PLUGIN_DESTROY_CALLED_STATE); } @@ -1176,25 +1176,23 @@ static void test_channel_creds_duplicate_without_call_creds(void) { grpc_channel_credentials* channel_creds = grpc_fake_transport_security_credentials_create(); - grpc_channel_credentials* dup = - grpc_channel_credentials_duplicate_without_call_credentials( - channel_creds); + grpc_core::RefCountedPtr dup = + channel_creds->duplicate_without_call_credentials(); GPR_ASSERT(dup == channel_creds); - grpc_channel_credentials_unref(dup); + dup.reset(); grpc_call_credentials* call_creds = grpc_access_token_credentials_create("blah", nullptr); grpc_channel_credentials* composite_creds = grpc_composite_channel_credentials_create(channel_creds, call_creds, nullptr); - grpc_call_credentials_unref(call_creds); - dup = grpc_channel_credentials_duplicate_without_call_credentials( - composite_creds); + call_creds->Unref(); + dup = composite_creds->duplicate_without_call_credentials(); GPR_ASSERT(dup == channel_creds); - grpc_channel_credentials_unref(dup); + dup.reset(); - grpc_channel_credentials_unref(channel_creds); - grpc_channel_credentials_unref(composite_creds); + channel_creds->Unref(); + composite_creds->Unref(); } typedef struct { diff --git a/test/core/security/oauth2_utils.cc b/test/core/security/oauth2_utils.cc index 469129a6d0..c9e205ab74 100644 --- a/test/core/security/oauth2_utils.cc +++ b/test/core/security/oauth2_utils.cc @@ -86,9 +86,8 @@ char* grpc_test_fetch_oauth2_token_with_credentials( grpc_schedule_on_exec_ctx); grpc_error* error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata(creds, &request.pops, null_ctx, - &request.md_array, - &request.closure, &error)) { + if (creds->get_request_metadata(&request.pops, null_ctx, &request.md_array, + &request.closure, &error)) { // Synchronous result; invoke callback directly. on_oauth2_response(&request, error); GRPC_ERROR_UNREF(error); diff --git a/test/core/security/print_google_default_creds_token.cc b/test/core/security/print_google_default_creds_token.cc index 4d251391ff..398c58c6e1 100644 --- a/test/core/security/print_google_default_creds_token.cc +++ b/test/core/security/print_google_default_creds_token.cc @@ -96,11 +96,10 @@ int main(int argc, char** argv) { grpc_schedule_on_exec_ctx); error = GRPC_ERROR_NONE; - if (grpc_call_credentials_get_request_metadata( - (reinterpret_cast(creds)) - ->call_creds, - &sync.pops, context, &sync.md_array, &sync.on_request_metadata, - &error)) { + if (reinterpret_cast(creds) + ->mutable_call_creds() + ->get_request_metadata(&sync.pops, context, &sync.md_array, + &sync.on_request_metadata, &error)) { // Synchronous response. Invoke callback directly. on_metadata_response(&sync, error); GRPC_ERROR_UNREF(error); diff --git a/test/core/security/security_connector_test.cc b/test/core/security/security_connector_test.cc index e82a8627d4..2a31763c73 100644 --- a/test/core/security/security_connector_test.cc +++ b/test/core/security/security_connector_test.cc @@ -27,6 +27,7 @@ #include "src/core/lib/gpr/env.h" #include "src/core/lib/gpr/string.h" #include "src/core/lib/gpr/tmpfile.h" +#include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/lib/security/context/security_context.h" #include "src/core/lib/security/security_connector/security_connector.h" #include "src/core/lib/security/security_connector/ssl_utils.h" @@ -83,22 +84,22 @@ static int check_ssl_peer_equivalence(const tsi_peer* original, static void test_unauthenticated_ssl_peer(void) { tsi_peer peer; tsi_peer rpeer; - grpc_auth_context* ctx; GPR_ASSERT(tsi_construct_peer(1, &peer) == TSI_OK); GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE, &peer.properties[0]) == TSI_OK); - ctx = grpc_ssl_peer_to_auth_context(&peer); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(!grpc_auth_context_peer_is_authenticated(ctx)); - GPR_ASSERT(check_transport_security_type(ctx)); + GPR_ASSERT(!grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_transport_security_type(ctx.get())); - rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx); + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); grpc_shallow_peer_destruct(&rpeer); tsi_peer_destruct(&peer); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static int check_identity(const grpc_auth_context* ctx, @@ -175,7 +176,6 @@ static int check_x509_pem_cert(const grpc_auth_context* ctx, static void test_cn_only_ssl_peer_to_auth_context(void) { tsi_peer peer; tsi_peer rpeer; - grpc_auth_context* ctx; const char* expected_cn = "cn1"; const char* expected_pem_cert = "pem_cert1"; GPR_ASSERT(tsi_construct_peer(3, &peer) == TSI_OK); @@ -188,26 +188,27 @@ static void test_cn_only_ssl_peer_to_auth_context(void) { GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, &peer.properties[2]) == TSI_OK); - ctx = grpc_ssl_peer_to_auth_context(&peer); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); - GPR_ASSERT(check_identity(ctx, GRPC_X509_CN_PROPERTY_NAME, &expected_cn, 1)); - GPR_ASSERT(check_transport_security_type(ctx)); - GPR_ASSERT(check_x509_cn(ctx, expected_cn)); - GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert)); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT( + check_identity(ctx.get(), GRPC_X509_CN_PROPERTY_NAME, &expected_cn, 1)); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); - rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx); + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); grpc_shallow_peer_destruct(&rpeer); tsi_peer_destruct(&peer); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static void test_cn_and_one_san_ssl_peer_to_auth_context(void) { tsi_peer peer; tsi_peer rpeer; - grpc_auth_context* ctx; const char* expected_cn = "cn1"; const char* expected_san = "san1"; const char* expected_pem_cert = "pem_cert1"; @@ -224,27 +225,28 @@ static void test_cn_and_one_san_ssl_peer_to_auth_context(void) { GPR_ASSERT(tsi_construct_string_peer_property_from_cstring( TSI_X509_PEM_CERT_PROPERTY, expected_pem_cert, &peer.properties[3]) == TSI_OK); - ctx = grpc_ssl_peer_to_auth_context(&peer); + + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); GPR_ASSERT( - check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, &expected_san, 1)); - GPR_ASSERT(check_transport_security_type(ctx)); - GPR_ASSERT(check_x509_cn(ctx, expected_cn)); - GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert)); + check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, &expected_san, 1)); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); - rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx); + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); grpc_shallow_peer_destruct(&rpeer); tsi_peer_destruct(&peer); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static void test_cn_and_multiple_sans_ssl_peer_to_auth_context(void) { tsi_peer peer; tsi_peer rpeer; - grpc_auth_context* ctx; const char* expected_cn = "cn1"; const char* expected_sans[] = {"san1", "san2", "san3"}; const char* expected_pem_cert = "pem_cert1"; @@ -265,28 +267,28 @@ static void test_cn_and_multiple_sans_ssl_peer_to_auth_context(void) { TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, expected_sans[i], &peer.properties[3 + i]) == TSI_OK); } - ctx = grpc_ssl_peer_to_auth_context(&peer); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); - GPR_ASSERT(check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, expected_sans, - GPR_ARRAY_SIZE(expected_sans))); - GPR_ASSERT(check_transport_security_type(ctx)); - GPR_ASSERT(check_x509_cn(ctx, expected_cn)); - GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert)); - - rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, + expected_sans, GPR_ARRAY_SIZE(expected_sans))); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); grpc_shallow_peer_destruct(&rpeer); tsi_peer_destruct(&peer); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static void test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context( void) { tsi_peer peer; tsi_peer rpeer; - grpc_auth_context* ctx; const char* expected_cn = "cn1"; const char* expected_pem_cert = "pem_cert1"; const char* expected_sans[] = {"san1", "san2", "san3"}; @@ -311,21 +313,22 @@ static void test_cn_and_multiple_sans_and_others_ssl_peer_to_auth_context( TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, expected_sans[i], &peer.properties[5 + i]) == TSI_OK); } - ctx = grpc_ssl_peer_to_auth_context(&peer); + grpc_core::RefCountedPtr ctx = + grpc_ssl_peer_to_auth_context(&peer); GPR_ASSERT(ctx != nullptr); - GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx)); - GPR_ASSERT(check_identity(ctx, GRPC_X509_SAN_PROPERTY_NAME, expected_sans, - GPR_ARRAY_SIZE(expected_sans))); - GPR_ASSERT(check_transport_security_type(ctx)); - GPR_ASSERT(check_x509_cn(ctx, expected_cn)); - GPR_ASSERT(check_x509_pem_cert(ctx, expected_pem_cert)); - - rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx); + GPR_ASSERT(grpc_auth_context_peer_is_authenticated(ctx.get())); + GPR_ASSERT(check_identity(ctx.get(), GRPC_X509_SAN_PROPERTY_NAME, + expected_sans, GPR_ARRAY_SIZE(expected_sans))); + GPR_ASSERT(check_transport_security_type(ctx.get())); + GPR_ASSERT(check_x509_cn(ctx.get(), expected_cn)); + GPR_ASSERT(check_x509_pem_cert(ctx.get(), expected_pem_cert)); + + rpeer = grpc_shallow_peer_from_ssl_auth_context(ctx.get()); GPR_ASSERT(check_ssl_peer_equivalence(&peer, &rpeer)); grpc_shallow_peer_destruct(&rpeer); tsi_peer_destruct(&peer); - GRPC_AUTH_CONTEXT_UNREF(ctx, "test"); + ctx.reset(DEBUG_LOCATION, "test"); } static const char* roots_for_override_api = "roots for override api"; diff --git a/test/core/security/ssl_server_fuzzer.cc b/test/core/security/ssl_server_fuzzer.cc index d2bbb7c1c2..c9380126dd 100644 --- a/test/core/security/ssl_server_fuzzer.cc +++ b/test/core/security/ssl_server_fuzzer.cc @@ -82,16 +82,15 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { ca_cert, &pem_key_cert_pair, 1, 0, nullptr); // Create security connector - grpc_server_security_connector* sc = nullptr; - grpc_security_status status = - grpc_server_credentials_create_security_connector(creds, &sc); - GPR_ASSERT(status == GRPC_SECURITY_OK); + grpc_core::RefCountedPtr sc = + creds->create_security_connector(); + GPR_ASSERT(sc != nullptr); grpc_millis deadline = GPR_MS_PER_SEC + grpc_core::ExecCtx::Get()->Now(); struct handshake_state state; state.done_callback_called = false; grpc_handshake_manager* handshake_mgr = grpc_handshake_manager_create(); - grpc_server_security_connector_add_handshakers(sc, nullptr, handshake_mgr); + sc->add_handshakers(nullptr, handshake_mgr); grpc_handshake_manager_do_handshake( handshake_mgr, mock_endpoint, nullptr /* channel_args */, deadline, nullptr /* acceptor */, on_handshake_done, &state); @@ -110,7 +109,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { GPR_ASSERT(state.done_callback_called); grpc_handshake_manager_destroy(handshake_mgr); - GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "test"); + sc.reset(DEBUG_LOCATION, "test"); grpc_server_credentials_release(creds); grpc_slice_unref(cert_slice); grpc_slice_unref(key_slice); diff --git a/test/core/surface/secure_channel_create_test.cc b/test/core/surface/secure_channel_create_test.cc index 5610d1ec4a..e9bb815f6e 100644 --- a/test/core/surface/secure_channel_create_test.cc +++ b/test/core/surface/secure_channel_create_test.cc @@ -39,7 +39,7 @@ void test_unknown_scheme_target(void) { GPR_ASSERT(0 == strcmp(elem->filter->name, "lame-client")); grpc_core::ExecCtx exec_ctx; GRPC_CHANNEL_INTERNAL_UNREF(chan, "test"); - grpc_channel_credentials_unref(creds); + creds->Unref(); } void test_security_connector_already_in_arg(void) { diff --git a/test/cpp/common/auth_property_iterator_test.cc b/test/cpp/common/auth_property_iterator_test.cc index 9634555e4b..5a2844d518 100644 --- a/test/cpp/common/auth_property_iterator_test.cc +++ b/test/cpp/common/auth_property_iterator_test.cc @@ -40,15 +40,14 @@ class TestAuthPropertyIterator : public AuthPropertyIterator { class AuthPropertyIteratorTest : public ::testing::Test { protected: void SetUp() override { - ctx_ = grpc_auth_context_create(nullptr); - grpc_auth_context_add_cstring_property(ctx_, "name", "chapi"); - grpc_auth_context_add_cstring_property(ctx_, "name", "chapo"); - grpc_auth_context_add_cstring_property(ctx_, "foo", "bar"); - EXPECT_EQ(1, - grpc_auth_context_set_peer_identity_property_name(ctx_, "name")); + ctx_ = grpc_core::MakeRefCounted(nullptr); + grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapi"); + grpc_auth_context_add_cstring_property(ctx_.get(), "name", "chapo"); + grpc_auth_context_add_cstring_property(ctx_.get(), "foo", "bar"); + EXPECT_EQ(1, grpc_auth_context_set_peer_identity_property_name(ctx_.get(), + "name")); } - void TearDown() override { grpc_auth_context_release(ctx_); } - grpc_auth_context* ctx_; + grpc_core::RefCountedPtr ctx_; }; TEST_F(AuthPropertyIteratorTest, DefaultCtor) { @@ -59,7 +58,7 @@ TEST_F(AuthPropertyIteratorTest, DefaultCtor) { TEST_F(AuthPropertyIteratorTest, GeneralTest) { grpc_auth_property_iterator c_iter = - grpc_auth_context_property_iterator(ctx_); + grpc_auth_context_property_iterator(ctx_.get()); const grpc_auth_property* property = grpc_auth_property_iterator_next(&c_iter); TestAuthPropertyIterator iter(property, &c_iter); diff --git a/test/cpp/common/secure_auth_context_test.cc b/test/cpp/common/secure_auth_context_test.cc index 6461f49743..03b3a9fdd8 100644 --- a/test/cpp/common/secure_auth_context_test.cc +++ b/test/cpp/common/secure_auth_context_test.cc @@ -33,7 +33,7 @@ class SecureAuthContextTest : public ::testing::Test {}; // Created with nullptr TEST_F(SecureAuthContextTest, EmptyContext) { - SecureAuthContext context(nullptr, true); + SecureAuthContext context(nullptr); EXPECT_TRUE(context.GetPeerIdentity().empty()); EXPECT_TRUE(context.GetPeerIdentityPropertyName().empty()); EXPECT_TRUE(context.FindPropertyValues("").empty()); @@ -42,8 +42,10 @@ TEST_F(SecureAuthContextTest, EmptyContext) { } TEST_F(SecureAuthContextTest, Properties) { - grpc_auth_context* ctx = grpc_auth_context_create(nullptr); - SecureAuthContext context(ctx, true); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + SecureAuthContext context(ctx.get()); + ctx.reset(); context.AddProperty("name", "chapi"); context.AddProperty("name", "chapo"); context.AddProperty("foo", "bar"); @@ -60,8 +62,10 @@ TEST_F(SecureAuthContextTest, Properties) { } TEST_F(SecureAuthContextTest, Iterators) { - grpc_auth_context* ctx = grpc_auth_context_create(nullptr); - SecureAuthContext context(ctx, true); + grpc_core::RefCountedPtr ctx = + grpc_core::MakeRefCounted(nullptr); + SecureAuthContext context(ctx.get()); + ctx.reset(); context.AddProperty("name", "chapi"); context.AddProperty("name", "chapo"); context.AddProperty("foo", "bar"); diff --git a/test/cpp/end2end/grpclb_end2end_test.cc b/test/cpp/end2end/grpclb_end2end_test.cc index bf990a07b5..2eaacd429d 100644 --- a/test/cpp/end2end/grpclb_end2end_test.cc +++ b/test/cpp/end2end/grpclb_end2end_test.cc @@ -414,8 +414,8 @@ class GrpclbEnd2endTest : public ::testing::Test { std::shared_ptr creds( new SecureChannelCredentials(grpc_composite_channel_credentials_create( channel_creds, call_creds, nullptr))); - grpc_call_credentials_unref(call_creds); - grpc_channel_credentials_unref(channel_creds); + call_creds->Unref(); + channel_creds->Unref(); channel_ = CreateCustomChannel(uri.str(), creds, args); stub_ = grpc::testing::EchoTestService::NewStub(channel_); } -- cgit v1.2.3 From 82797771679c93045876f48707c5cb46bfae4f53 Mon Sep 17 00:00:00 2001 From: easy Date: Wed, 19 Dec 2018 12:22:38 +1100 Subject: Destruct CensusContext to avoid leaking memory. Otherwise, the placement-new leaks context -> span_ -> impl_ which is a std::shared_ptr. --- src/cpp/ext/filters/census/context.cc | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'src/cpp') diff --git a/src/cpp/ext/filters/census/context.cc b/src/cpp/ext/filters/census/context.cc index 78fc69a805..160590353a 100644 --- a/src/cpp/ext/filters/census/context.cc +++ b/src/cpp/ext/filters/census/context.cc @@ -28,6 +28,9 @@ using ::opencensus::trace::SpanContext; void GenerateServerContext(absl::string_view tracing, absl::string_view stats, absl::string_view primary_role, absl::string_view method, CensusContext* context) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + context->~CensusContext(); GrpcTraceContext trace_ctxt; if (TraceContextEncoding::Decode(tracing, &trace_ctxt) != TraceContextEncoding::kEncodeDecodeFailure) { @@ -42,6 +45,9 @@ void GenerateServerContext(absl::string_view tracing, absl::string_view stats, void GenerateClientContext(absl::string_view method, CensusContext* ctxt, CensusContext* parent_ctxt) { + // Destruct the current CensusContext to free the Span memory before + // overwriting it below. + ctxt->~CensusContext(); if (parent_ctxt != nullptr) { SpanContext span_ctxt = parent_ctxt->Context(); Span span = parent_ctxt->Span(); -- cgit v1.2.3 From d6dd6f25f4e900d6099098d5d1f0bd52f0581750 Mon Sep 17 00:00:00 2001 From: yang-g Date: Thu, 20 Dec 2018 15:37:30 -0800 Subject: Correctly reference the internal string for socket mutator arg --- src/cpp/common/channel_arguments.cc | 2 ++ test/cpp/common/channel_arguments_test.cc | 3 +++ 2 files changed, 5 insertions(+) (limited to 'src/cpp') diff --git a/src/cpp/common/channel_arguments.cc b/src/cpp/common/channel_arguments.cc index 50ee9d871f..214d72f853 100644 --- a/src/cpp/common/channel_arguments.cc +++ b/src/cpp/common/channel_arguments.cc @@ -106,7 +106,9 @@ void ChannelArguments::SetSocketMutator(grpc_socket_mutator* mutator) { } if (!replaced) { + strings_.push_back(grpc::string(mutator_arg.key)); args_.push_back(mutator_arg); + args_.back().key = const_cast(strings_.back().c_str()); } } diff --git a/test/cpp/common/channel_arguments_test.cc b/test/cpp/common/channel_arguments_test.cc index 183d2afa78..12fd9784f4 100644 --- a/test/cpp/common/channel_arguments_test.cc +++ b/test/cpp/common/channel_arguments_test.cc @@ -209,6 +209,9 @@ TEST_F(ChannelArgumentsTest, SetSocketMutator) { channel_args_.SetSocketMutator(mutator0); EXPECT_TRUE(HasArg(arg0)); + // Exercise the copy constructor because we ran some sanity checks in it. + grpc::ChannelArguments new_args{channel_args_}; + channel_args_.SetSocketMutator(mutator1); EXPECT_TRUE(HasArg(arg1)); // arg0 is replaced by arg1 -- cgit v1.2.3