diff options
author | 2018-10-28 23:36:59 -0700 | |
---|---|---|
committer | 2018-10-28 23:36:59 -0700 | |
commit | 01313976e1a44b5c9625d3a349fffa55471beff4 (patch) | |
tree | 805596796cce33154e0d875c4c1ade918ba958f2 /test | |
parent | ffac9d90b18cb076b1c952faa55ce4e049cbc9a6 (diff) | |
parent | 395edbfa24968b8406a0c157874d3cb473076df5 (diff) |
Merge pull request #16842 from yashykt/interceptors
Experimental API for Client and Server Interception
Diffstat (limited to 'test')
-rw-r--r-- | test/cpp/end2end/BUILD | 53 | ||||
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 606 | ||||
-rw-r--r-- | test/cpp/end2end/interceptors_util.h | 308 | ||||
-rw-r--r-- | test/cpp/end2end/server_interceptors_end2end_test.cc | 623 | ||||
-rw-r--r-- | test/cpp/interop/client_helper.h | 2 |
5 files changed, 1592 insertions, 0 deletions
diff --git a/test/cpp/end2end/BUILD b/test/cpp/end2end/BUILD index 0415efc1ef..235249e8bf 100644 --- a/test/cpp/end2end/BUILD +++ b/test/cpp/end2end/BUILD @@ -35,6 +35,19 @@ grpc_cc_library( ], ) +grpc_cc_library( + name = "interceptors_util", + testonly = True, + hdrs = ["interceptors_util.h"], + external_deps = [ + "gtest", + ], + deps = [ + "//src/proto/grpc/testing:echo_proto", + "//test/cpp/util:test_util", + ], +) + grpc_cc_test( name = "async_end2end_test", srcs = ["async_end2end_test.cc"], @@ -117,6 +130,26 @@ grpc_cc_test( ], ) +grpc_cc_test( + name = "client_interceptors_end2end_test", + srcs = ["client_interceptors_end2end_test.cc"], + external_deps = [ + "gtest", + ], + deps = [ + ":interceptors_util", + ":test_service_impl", + "//:gpr", + "//:grpc", + "//:grpc++", + "//src/proto/grpc/testing:echo_messages_proto", + "//src/proto/grpc/testing:echo_proto", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + "//test/cpp/util:test_util", + ], +) + grpc_cc_library( name = "end2end_test_lib", testonly = True, @@ -470,6 +503,26 @@ grpc_cc_binary( ) grpc_cc_test( + name = "server_interceptors_end2end_test", + srcs = ["server_interceptors_end2end_test.cc"], + external_deps = [ + "gtest", + ], + deps = [ + ":interceptors_util", + ":test_service_impl", + "//:gpr", + "//:grpc", + "//:grpc++", + "//src/proto/grpc/testing:echo_messages_proto", + "//src/proto/grpc/testing:echo_proto", + "//test/core/util:gpr_test_util", + "//test/core/util:grpc_test_util", + "//test/cpp/util:test_util", + ], +) + +grpc_cc_test( name = "server_load_reporting_end2end_test", srcs = ["server_load_reporting_end2end_test.cc"], external_deps = [ diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc new file mode 100644 index 0000000000..5720e87478 --- /dev/null +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -0,0 +1,606 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/client_interceptor.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsStreamingEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr<Server> server_; +}; + +class ClientInterceptorsEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor(experimental::ClientRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA)) { + num_times_run_reverse_++; + } + methods->Proceed(); + } + + static void Reset() { + num_times_run_.store(0); + num_times_run_reverse_.store(0); + } + + static int GetNumTimesRun() { + EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); + return num_times_run_.load(); + } + + private: + static std::atomic<int> num_times_run_; + static std::atomic<int> num_times_run_reverse_; +}; + +std::atomic<int> DummyInterceptor::num_times_run_; +std::atomic<int> DummyInterceptor::num_times_run_reverse_; + +class DummyInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +/* Hijacks Echo RPC and fills in the expected values */ +class HijackingInterceptor : public experimental::Interceptor { + public: + HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + req_ = req; + stub_ = grpc::testing::EchoTestService::NewStub( + methods->GetInterceptedChannel()); + ctx_.AddMetadata(metadata_map_.begin()->first, + metadata_map_.begin()->second); + stub_->experimental_async()->Echo(&ctx_, &req_, &resp_, + [this, methods](Status s) { + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp_.message(), "Hello"); + methods->Hijack(); + }); + // There isn't going to be any other interesting operation in this batch, + // so it is fine to return + return; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message(resp_.message()); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap<grpc::string, grpc::string> metadata_map_; + ClientContext ctx_; + EchoRequest req_; + EchoResponse resp_; + std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_; +}; + +class HijackingInterceptorMakesAnotherCallFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptorMakesAnotherCall(info); + } +}; + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + methods->Proceed(); + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + // Add 20 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators->push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + // Add 20 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + MakeCall(channel); + // Make sure only 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { + ChannelArguments args; + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + creators->push_back(std::unique_ptr<HijackingInterceptorFactory>( + new HijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + MakeCall(channel); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorHijackingMakesAnotherCallTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + // Add 5 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 5; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + creators->push_back( + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>( + new HijackingInterceptorMakesAnotherCallFactory())); + // Add 7 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 7; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + + MakeCall(channel); + // Make sure all interceptors were run once, since the hijacking interceptor + // makes an RPC on the intercepted channel + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12); +} + +TEST_F(ClientInterceptorsEnd2endTest, + ClientInterceptorLoggingTestWithCallback) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = server_->experimental().InProcessChannelWithInterceptors( + args, std::move(creators)); + MakeCallbackCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeClientStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr<std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>( + new std::vector< + std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>()); + creators->push_back(std::unique_ptr<LoggingInterceptorFactory>( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc_test_init(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h new file mode 100644 index 0000000000..5f0aa37dc0 --- /dev/null +++ b/test/cpp/end2end/interceptors_util.h @@ -0,0 +1,308 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/cpp/util/string_ref_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +class EchoTestServiceStreamingImpl : public EchoTestService::Service { + public: + ~EchoTestServiceStreamingImpl() override {} + + Status BidiStream( + ServerContext* context, + grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest req; + EchoResponse resp; + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + while (stream->Read(&req)) { + resp.set_message(req.message()); + EXPECT_TRUE(stream->Write(resp, grpc::WriteOptions())); + } + return Status::OK; + } + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* resp) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoRequest req; + string response_str = ""; + while (reader->Read(&req)) { + response_str += req.message(); + } + resp->set_message(response_str); + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* req, + ServerWriter<EchoResponse>* writer) override { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + + EchoResponse resp; + resp.set_message(req->message()); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(resp)); + } + return Status::OK; + } +}; + +void MakeCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + +void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + writer->Write(req); + expected_resp += "Hello"; + } + writer->WritesDone(); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), expected_resp); +} + +void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + string expected_resp = ""; + auto reader = stub->ResponseStream(&ctx, req); + int count = 0; + while (reader->Read(&resp)) { + EXPECT_EQ(resp.message(), "Hello"); + count++; + } + ASSERT_EQ(count, 10); + Status s = reader->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < 10; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +void MakeCallbackCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + std::mutex mu; + std::condition_variable cv; + bool done = false; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + stub->experimental_async()->Echo(&ctx, &req, &resp, + [&resp, &mu, &done, &cv](Status s) { + // gpr_log(GPR_ERROR, "got the callback"); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + std::lock_guard<std::mutex> l(mu); + done = true; + cv.notify_one(); + }); + std::unique_lock<std::mutex> l(mu); + while (!done) { + cv.wait(l); + } +} + +bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first.starts_with(key) && pair.second.starts_with(value)) { + return true; + } + } + return false; +} + +void* tag(int i) { return (void*)static_cast<intptr_t>(i); } +int detag(void* p) { return static_cast<int>(reinterpret_cast<intptr_t>(p)); } + +class Verifier { + public: + Verifier() : lambda_run_(false) {} + // Expect sets the expected ok value for a specific tag + Verifier& Expect(int i, bool expect_ok) { + return ExpectUnless(i, expect_ok, false); + } + // ExpectUnless sets the expected ok value for a specific tag + // unless the tag was already marked seen (as a result of ExpectMaybe) + Verifier& ExpectUnless(int i, bool expect_ok, bool seen) { + if (!seen) { + expectations_[tag(i)] = expect_ok; + } + return *this; + } + // ExpectMaybe sets the expected ok value for a specific tag, but does not + // require it to appear + // If it does, sets *seen to true + Verifier& ExpectMaybe(int i, bool expect_ok, bool* seen) { + if (!*seen) { + maybe_expectations_[tag(i)] = MaybeExpect{expect_ok, seen}; + } + return *this; + } + + // Next waits for 1 async tag to complete, checks its + // expectations, and returns the tag + int Next(CompletionQueue* cq, bool ignore_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + GotTag(got_tag, ok, ignore_ok); + return detag(got_tag); + } + + template <typename T> + CompletionQueue::NextStatus DoOnceThenAsyncNext( + CompletionQueue* cq, void** got_tag, bool* ok, T deadline, + std::function<void(void)> lambda) { + if (lambda_run_) { + return cq->AsyncNext(got_tag, ok, deadline); + } else { + lambda_run_ = true; + return cq->DoThenAsyncNext(lambda, got_tag, ok, deadline); + } + } + + // Verify keeps calling Next until all currently set + // expected tags are complete + void Verify(CompletionQueue* cq) { Verify(cq, false); } + + // This version of Verify allows optionally ignoring the + // outcome of the expectation + void Verify(CompletionQueue* cq, bool ignore_ok) { + GPR_ASSERT(!expectations_.empty() || !maybe_expectations_.empty()); + while (!expectations_.empty()) { + Next(cq, ignore_ok); + } + } + + // This version of Verify stops after a certain deadline, and uses the + // DoThenAsyncNext API + // to call the lambda + void Verify(CompletionQueue* cq, + std::chrono::system_clock::time_point deadline, + const std::function<void(void)>& lambda) { + if (expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::TIMEOUT); + } else { + while (!expectations_.empty()) { + bool ok; + void* got_tag; + EXPECT_EQ(DoOnceThenAsyncNext(cq, &got_tag, &ok, deadline, lambda), + CompletionQueue::GOT_EVENT); + GotTag(got_tag, ok, false); + } + } + } + + private: + void GotTag(void* got_tag, bool ok, bool ignore_ok) { + auto it = expectations_.find(got_tag); + if (it != expectations_.end()) { + if (!ignore_ok) { + EXPECT_EQ(it->second, ok); + } + expectations_.erase(it); + } else { + auto it2 = maybe_expectations_.find(got_tag); + if (it2 != maybe_expectations_.end()) { + if (it2->second.seen != nullptr) { + EXPECT_FALSE(*it2->second.seen); + *it2->second.seen = true; + } + if (!ignore_ok) { + EXPECT_EQ(it2->second.ok, ok); + } + } else { + gpr_log(GPR_ERROR, "Unexpected tag: %p", got_tag); + abort(); + } + } + } + + struct MaybeExpect { + bool ok; + bool* seen; + }; + + std::map<void*, bool> expectations_; + std::map<void*, MaybeExpect> maybe_expectations_; + bool lambda_run_; +}; + +} // namespace testing +} // namespace grpc diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc new file mode 100644 index 0000000000..44ba2a6009 --- /dev/null +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -0,0 +1,623 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <memory> +#include <vector> + +#include <grpcpp/channel.h> +#include <grpcpp/client_context.h> +#include <grpcpp/create_channel.h> +#include <grpcpp/generic/generic_stub.h> +#include <grpcpp/impl/codegen/proto_utils.h> +#include <grpcpp/impl/codegen/server_interceptor.h> +#include <grpcpp/server.h> +#include <grpcpp/server_builder.h> +#include <grpcpp/server_context.h> + +#include "src/proto/grpc/testing/echo.grpc.pb.h" +#include "test/core/util/port.h" +#include "test/core/util/test_config.h" +#include "test/cpp/end2end/interceptors_util.h" +#include "test/cpp/end2end/test_service_impl.h" +#include "test/cpp/util/byte_buffer_proto_helper.h" + +#include <gtest/gtest.h> + +namespace grpc { +namespace testing { +namespace { + +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ +class DummyInterceptor : public experimental::Interceptor { + public: + DummyInterceptor(experimental::ServerRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } else if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints:: + POST_RECV_INITIAL_METADATA)) { + num_times_run_reverse_++; + } + methods->Proceed(); + } + + static void Reset() { + num_times_run_.store(0); + num_times_run_reverse_.store(0); + } + + static int GetNumTimesRun() { + EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); + return num_times_run_.load(); + } + + private: + static std::atomic<int> num_times_run_; + static std::atomic<int> num_times_run_reverse_; +}; + +std::atomic<int> DummyInterceptor::num_times_run_; +std::atomic<int> DummyInterceptor::num_times_run_reverse_; + +class DummyInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +class LoggingInterceptor : public experimental::Interceptor { + public: + LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req); + EXPECT_TRUE(req.message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS)) { + auto* map = methods->GetSendTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + auto status = methods->GetSendStatus(); + EXPECT_EQ(status.ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.find("testkey") == 0 && + pair.second.find("testvalue") == 0; + if (found) break; + } + EXPECT_EQ(found, true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + EXPECT_TRUE(resp->message().find("Hello") == 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE)) { + // Got nothing interesting to do here + } + methods->Proceed(); + } + + private: + experimental::ServerRpcInfo* info_; +}; + +class LoggingInterceptorFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new LoggingInterceptor(info); + } +}; + +void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + ctx.AddMetadata("testkey", "testvalue"); + auto stream = stub->BidiStream(&ctx); + for (auto i = 0; i < 10; i++) { + req.set_message("Hello" + std::to_string(i)); + stream->Write(req); + stream->Read(&resp); + EXPECT_EQ(req.message(), resp.message()); + } + ASSERT_TRUE(stream->WritesDone()); + Status s = stream->Finish(); + EXPECT_EQ(s.ok(), true); +} + +class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncUnaryTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncUnaryTest, UnaryTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { + protected: + ServerInterceptorsEnd2endSyncStreamingTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + server_ = builder.BuildAndStart(); + } + std::string server_address_; + EchoTestServiceStreamingImpl service_; + std::unique_ptr<Server> server_; +}; + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ClientStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeClientStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, ServerStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeServerStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ServerInterceptorsEnd2endSyncStreamingTest, BidiStreamingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto channel = CreateChannel(server_address_, InsecureChannelCredentials()); + MakeBidiStreamingCall(channel); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +class ServerInterceptorsAsyncEnd2endTest : public ::testing::Test {}; + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnaryTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, cq.get())); + + service.RequestEcho(&srv_ctx, &recv_request, &response_writer, cq.get(), + cq.get(), tag(2)); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + + Verifier().Expect(2, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + send_response.set_message(recv_request.message()); + response_writer.Finish(send_response, Status::OK, tag(3)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, BidiStreamingTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + EchoTestService::AsyncService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + ServerContext srv_ctx; + grpc::ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr<ClientAsyncReaderWriter<EchoRequest, EchoResponse>> + cli_stream(stub->AsyncBidiStream(&cli_ctx, cq.get(), tag(1))); + + service.RequestBidiStream(&srv_ctx, &srv_stream, cq.get(), cq.get(), tag(2)); + + Verifier().Expect(1, true).Expect(2, true).Verify(cq.get()); + + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + cli_stream->Write(send_request, tag(3)); + srv_stream.Read(&recv_request, tag(4)); + Verifier().Expect(3, true).Expect(4, true).Verify(cq.get()); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + cli_stream->Read(&recv_response, tag(6)); + Verifier().Expect(5, true).Expect(6, true).Verify(cq.get()); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + srv_stream.Read(&recv_request, tag(8)); + Verifier().Expect(7, true).Expect(8, false).Verify(cq.get()); + + srv_stream.Finish(Status::OK, tag(9)); + cli_stream->Finish(&recv_status, tag(10)); + Verifier().Expect(9, true).Expect(10, true).Verify(cq.get()); + + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, GenericRPCTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + AsyncGenericService service; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + builder.RegisterAsyncGenericService(&service); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + auto channel = CreateChannel(server_address, InsecureChannelCredentials()); + GenericStub generic_stub(channel); + + const grpc::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo"); + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + GenericServerContext srv_ctx; + GenericServerAsyncReaderWriter stream(&srv_ctx); + + // The string needs to be long enough to test heap-based slice. + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + + std::unique_ptr<GenericClientAsyncReaderWriter> call = + generic_stub.PrepareCall(&cli_ctx, kMethodName, cq.get()); + call->StartCall(tag(1)); + Verifier().Expect(1, true).Verify(cq.get()); + std::unique_ptr<ByteBuffer> send_buffer = + SerializeToByteBuffer(&send_request); + call->Write(*send_buffer, tag(2)); + // Send ByteBuffer can be destroyed after calling Write. + send_buffer.reset(); + Verifier().Expect(2, true).Verify(cq.get()); + call->WritesDone(tag(3)); + Verifier().Expect(3, true).Verify(cq.get()); + + service.RequestCall(&srv_ctx, &stream, cq.get(), cq.get(), tag(4)); + + Verifier().Expect(4, true).Verify(cq.get()); + EXPECT_EQ(kMethodName, srv_ctx.method()); + EXPECT_TRUE(CheckMetadata(srv_ctx.client_metadata(), "testkey", "testvalue")); + srv_ctx.AddTrailingMetadata("testkey", "testvalue"); + + ByteBuffer recv_buffer; + stream.Read(&recv_buffer, tag(5)); + Verifier().Expect(5, true).Verify(cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request)); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + send_buffer = SerializeToByteBuffer(&send_response); + stream.Write(*send_buffer, tag(6)); + send_buffer.reset(); + Verifier().Expect(6, true).Verify(cq.get()); + + stream.Finish(Status::OK, tag(7)); + Verifier().Expect(7, true).Verify(cq.get()); + + recv_buffer.Clear(); + call->Read(&recv_buffer, tag(8)); + Verifier().Expect(8, true).Verify(cq.get()); + EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response)); + + call->Finish(&recv_status, tag(9)); + Verifier().Expect(9, true).Verify(cq.get()); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); + EXPECT_TRUE(CheckMetadata(cli_ctx.GetServerTrailingMetadata(), "testkey", + "testvalue")); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +TEST_F(ServerInterceptorsAsyncEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto cq = builder.AddCompletionQueue(); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr<Channel> channel = + CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader( + stub->AsyncUnimplemented(&cli_ctx, send_request, cq.get())); + + response_reader->Finish(&recv_response, &recv_status, tag(4)); + Verifier().Expect(4, true).Verify(cq.get()); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + cq->Shutdown(); + void* ignored_tag; + bool ignored_ok; + while (cq->Next(&ignored_tag, &ignored_ok)) + ; + grpc_recycle_unused_port(port); +} + +class ServerInterceptorsSyncUnimplementedEnd2endTest : public ::testing::Test { +}; + +TEST_F(ServerInterceptorsSyncUnimplementedEnd2endTest, UnimplementedRpcTest) { + DummyInterceptor::Reset(); + int port = grpc_pick_unused_port_or_die(); + string server_address = "localhost:" + std::to_string(port); + ServerBuilder builder; + TestServiceImpl service; + builder.RegisterService(&service); + builder.AddListeningPort(server_address, InsecureServerCredentials()); + std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + creators; + for (auto i = 0; i < 20; i++) { + creators.push_back(std::unique_ptr<DummyInterceptorFactory>( + new DummyInterceptorFactory())); + } + builder.experimental().SetInterceptorCreators(std::move(creators)); + auto server = builder.BuildAndStart(); + + ChannelArguments args; + std::shared_ptr<Channel> channel = + CreateChannel(server_address, InsecureChannelCredentials()); + std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub; + stub = grpc::testing::UnimplementedEchoService::NewStub(channel); + EchoRequest send_request; + EchoResponse recv_response; + + ClientContext cli_ctx; + send_request.set_message("Hello"); + Status recv_status = + stub->Unimplemented(&cli_ctx, send_request, &recv_response); + + EXPECT_EQ(StatusCode::UNIMPLEMENTED, recv_status.error_code()); + EXPECT_EQ("", recv_status.error_message()); + + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); + + server->Shutdown(); + grpc_recycle_unused_port(port); +} + +} // namespace +} // namespace testing +} // namespace grpc + +int main(int argc, char** argv) { + grpc_test_init(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/cpp/interop/client_helper.h b/test/cpp/interop/client_helper.h index eada2f671f..7dee85cc98 100644 --- a/test/cpp/interop/client_helper.h +++ b/test/cpp/interop/client_helper.h @@ -19,10 +19,12 @@ #ifndef GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H #define GRPC_TEST_CPP_INTEROP_CLIENT_HELPER_H +#include <functional> #include <memory> #include <unordered_map> #include <grpcpp/channel.h> +#include <grpcpp/client_context.h> #include "src/core/lib/surface/call_test_only.h" |