/* * * Copyright 2018 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ #include #include #include #include #include #include #include #include #include #include #include #include "src/proto/grpc/testing/echo.grpc.pb.h" #include "test/core/util/port.h" #include "test/core/util/test_config.h" #include "test/cpp/end2end/test_service_impl.h" #include "test/cpp/util/byte_buffer_proto_helper.h" #include "test/cpp/util/string_ref_helper.h" #include namespace grpc { namespace testing { namespace { class EchoTestServiceStreamingImpl : public EchoTestService::Service { public: ~EchoTestServiceStreamingImpl() override {} Status BidiStream( ServerContext* context, grpc::ServerReaderWriter* stream) override { EchoRequest req; EchoResponse resp; auto client_metadata = context->client_metadata(); for (const auto& pair : client_metadata) { context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); } while (stream->Read(&req)) { resp.set_message(req.message()); stream->Write(resp, grpc::WriteOptions()); } return Status::OK; } }; class ClientInterceptorsStreamingEnd2EndTest : public ::testing::Test { protected: ClientInterceptorsStreamingEnd2EndTest() { int port = grpc_pick_unused_port_or_die(); ServerBuilder builder; server_address_ = "localhost:" + std::to_string(port); builder.AddListeningPort(server_address_, InsecureServerCredentials()); builder.RegisterService(&service_); server_ = builder.BuildAndStart(); } std::string server_address_; EchoTestServiceStreamingImpl service_; std::unique_ptr server_; }; class ClientInterceptorsEnd2endTest : public ::testing::Test { protected: ClientInterceptorsEnd2endTest() { int port = grpc_pick_unused_port_or_die(); ServerBuilder builder; server_address_ = "localhost:" + std::to_string(port); builder.AddListeningPort(server_address_, InsecureServerCredentials()); builder.RegisterService(&service_); server_ = builder.BuildAndStart(); } ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } std::string server_address_; TestServiceImpl service_; std::unique_ptr server_; }; /* This interceptor does nothing. Just keeps a global count on the number of * times it was invoked. */ class DummyInterceptor : public experimental::Interceptor { public: DummyInterceptor(experimental::ClientRpcInfo* info) {} virtual void Intercept(experimental::InterceptorBatchMethods* methods) { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { num_times_run_++; } else if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints:: POST_RECV_INITIAL_METADATA)) { num_times_run_reverse_++; } methods->Proceed(); } static void Reset() { num_times_run_.store(0); num_times_run_reverse_.store(0); } static int GetNumTimesRun() { EXPECT_EQ(num_times_run_.load(), num_times_run_reverse_.load()); return num_times_run_.load(); } private: static std::atomic num_times_run_; static std::atomic num_times_run_reverse_; }; std::atomic DummyInterceptor::num_times_run_; std::atomic DummyInterceptor::num_times_run_reverse_; class DummyInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: virtual experimental::Interceptor* CreateClientInterceptor( experimental::ClientRpcInfo* info) override { return new DummyInterceptor(info); } }; /* Hijacks Echo RPC and fills in the expected values */ class HijackingInterceptor : public experimental::Interceptor { public: HijackingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; // Make sure it is the right method EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { // gpr_log(GPR_ERROR, "ran this"); bool hijack = false; if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { auto* map = methods->GetSendInitialMetadata(); // Check that we can see the test metadata ASSERT_EQ(map->size(), static_cast(1)); auto iterator = map->begin(); EXPECT_EQ("testkey", iterator->first); EXPECT_EQ("testvalue", iterator->second); hijack = true; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; auto* buffer = methods->GetSendMessage(); auto copied_buffer = *buffer; SerializationTraits::Deserialize(&copied_buffer, &req); EXPECT_EQ(req.message(), "Hello"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { // Got nothing to do here for now } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here for now EXPECT_EQ(map->size(), static_cast(0)); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast(methods->GetRecvMessage()); // Check that we got the hijacked message, and re-insert the expected // message EXPECT_EQ(resp->message(), "Hello1"); resp->set_message("Hello"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); bool found = false; // Check that we received the metadata as an echo for (const auto& pair : *map) { found = pair.first.starts_with("testkey") && pair.second.starts_with("testvalue"); if (found) break; } EXPECT_EQ(found, true); auto* status = methods->GetRecvStatus(); EXPECT_EQ(status->ok(), true); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here at the moment EXPECT_EQ(map->size(), static_cast(0)); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { // Insert a different message than expected EchoResponse* resp = static_cast(methods->GetRecvMessage()); resp->set_message("Hello1"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); // insert the metadata that we want EXPECT_EQ(map->size(), static_cast(0)); map->insert(std::make_pair("testkey", "testvalue")); auto* status = methods->GetRecvStatus(); *status = Status(StatusCode::OK, ""); } if (hijack) { methods->Hijack(); } else { methods->Proceed(); } } private: experimental::ClientRpcInfo* info_; }; class HijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: virtual experimental::Interceptor* CreateClientInterceptor( experimental::ClientRpcInfo* info) override { return new HijackingInterceptor(info); } }; class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { public: HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { info_ = info; // Make sure it is the right method EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { // gpr_log(GPR_ERROR, "ran this"); if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { auto* map = methods->GetSendInitialMetadata(); // Check that we can see the test metadata ASSERT_EQ(map->size(), static_cast(1)); auto iterator = map->begin(); EXPECT_EQ("testkey", iterator->first); EXPECT_EQ("testvalue", iterator->second); // Make a copy of the map metadata_map_ = *map; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; auto* buffer = methods->GetSendMessage(); auto copied_buffer = *buffer; SerializationTraits::Deserialize(&copied_buffer, &req); EXPECT_EQ(req.message(), "Hello"); req_ = req; stub_ = grpc::testing::EchoTestService::NewStub( methods->GetInterceptedChannel()); ctx_.AddMetadata(metadata_map_.begin()->first, metadata_map_.begin()->second); stub_->experimental_async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) { EXPECT_EQ(s.ok(), true); EXPECT_EQ(resp_.message(), "Hello"); methods->Hijack(); }); // There isn't going to be any other interesting operation in this batch, // so it is fine to return return; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { // Got nothing to do here for now } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here for now EXPECT_EQ(map->size(), static_cast(0)); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast(methods->GetRecvMessage()); // Check that we got the hijacked message, and re-insert the expected // message EXPECT_EQ(resp->message(), "Hello"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); bool found = false; // Check that we received the metadata as an echo for (const auto& pair : *map) { found = pair.first.starts_with("testkey") && pair.second.starts_with("testvalue"); if (found) break; } EXPECT_EQ(found, true); auto* status = methods->GetRecvStatus(); EXPECT_EQ(status->ok(), true); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here at the moment EXPECT_EQ(map->size(), static_cast(0)); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { // Insert a different message than expected EchoResponse* resp = static_cast(methods->GetRecvMessage()); resp->set_message(resp_.message()); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); // insert the metadata that we want EXPECT_EQ(map->size(), static_cast(0)); *map = ctx_.GetServerTrailingMetadata(); auto* status = methods->GetRecvStatus(); *status = Status(StatusCode::OK, ""); } methods->Proceed(); } private: experimental::ClientRpcInfo* info_; std::multimap metadata_map_; ClientContext ctx_; EchoRequest req_; EchoResponse resp_; std::unique_ptr stub_; }; class HijackingInterceptorMakesAnotherCallFactory : public experimental::ClientInterceptorFactoryInterface { public: virtual experimental::Interceptor* CreateClientInterceptor( experimental::ClientRpcInfo* info) override { return new HijackingInterceptorMakesAnotherCall(info); } }; class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { // gpr_log(GPR_ERROR, "ran this"); if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { auto* map = methods->GetSendInitialMetadata(); // Check that we can see the test metadata ASSERT_EQ(map->size(), static_cast(1)); auto iterator = map->begin(); EXPECT_EQ("testkey", iterator->first); EXPECT_EQ("testvalue", iterator->second); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; auto* buffer = methods->GetSendMessage(); auto copied_buffer = *buffer; SerializationTraits::Deserialize(&copied_buffer, &req); EXPECT_TRUE(req.message().find("Hello") == 0); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { // Got nothing to do here for now } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here for now EXPECT_EQ(map->size(), static_cast(0)); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast(methods->GetRecvMessage()); EXPECT_TRUE(resp->message().find("Hello") == 0); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); bool found = false; // Check that we received the metadata as an echo for (const auto& pair : *map) { found = pair.first.starts_with("testkey") && pair.second.starts_with("testvalue"); if (found) break; } EXPECT_EQ(found, true); auto* status = methods->GetRecvStatus(); EXPECT_EQ(status->ok(), true); } methods->Proceed(); } private: experimental::ClientRpcInfo* info_; }; class LoggingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: virtual experimental::Interceptor* CreateClientInterceptor( experimental::ClientRpcInfo* info) override { return new LoggingInterceptor(info); } }; void MakeCall(const std::shared_ptr& channel) { auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; EchoRequest req; req.mutable_param()->set_echo_metadata(true); ctx.AddMetadata("testkey", "testvalue"); req.set_message("Hello"); EchoResponse resp; Status s = stub->Echo(&ctx, req, &resp); EXPECT_EQ(s.ok(), true); EXPECT_EQ(resp.message(), "Hello"); } void MakeCallbackCall(const std::shared_ptr& channel) { auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; EchoRequest req; std::mutex mu; std::condition_variable cv; bool done = false; req.mutable_param()->set_echo_metadata(true); ctx.AddMetadata("testkey", "testvalue"); req.set_message("Hello"); EchoResponse resp; stub->experimental_async()->Echo(&ctx, &req, &resp, [&resp, &mu, &done, &cv](Status s) { // gpr_log(GPR_ERROR, "got the callback"); EXPECT_EQ(s.ok(), true); EXPECT_EQ(resp.message(), "Hello"); std::lock_guard l(mu); done = true; cv.notify_one(); }); std::unique_lock l(mu); while (!done) { cv.wait(l); } } void MakeStreamingCall(const std::shared_ptr& channel) { auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; EchoRequest req; EchoResponse resp; ctx.AddMetadata("testkey", "testvalue"); auto stream = stub->BidiStream(&ctx); for (auto i = 0; i < 10; i++) { req.set_message("Hello" + std::to_string(i)); stream->Write(req); stream->Read(&resp); EXPECT_EQ(req.message(), resp.message()); } ASSERT_TRUE(stream->WritesDone()); Status s = stream->Finish(); EXPECT_EQ(s.ok(), true); } TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { ChannelArguments args; DummyInterceptor::Reset(); auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); creators->push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeCall(channel); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); // Add 20 dummy interceptors before hijacking interceptor for (auto i = 0; i < 20; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } creators->push_back(std::unique_ptr( new HijackingInterceptorFactory())); // Add 20 dummy interceptors after hijacking interceptor for (auto i = 0; i < 20; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeCall(channel); // Make sure only 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { ChannelArguments args; auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); creators->push_back(std::unique_ptr( new LoggingInterceptorFactory())); creators->push_back(std::unique_ptr( new HijackingInterceptorFactory())); auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeCall(channel); } TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingMakesAnotherCallTest) { ChannelArguments args; DummyInterceptor::Reset(); auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); // Add 5 dummy interceptors before hijacking interceptor for (auto i = 0; i < 5; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } creators->push_back( std::unique_ptr( new HijackingInterceptorMakesAnotherCallFactory())); // Add 7 dummy interceptors after hijacking interceptor for (auto i = 0; i < 7; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } // auto channel = experimental::CreateCustomChannelWithInterceptors( // server_address_, InsecureChannelCredentials(), args, // std::move(creators)); auto channel = server_->experimental().InProcessChannelWithInterceptors( args, std::move(creators)); MakeCall(channel); // Make sure all interceptors were run once, since the hijacking interceptor // makes an RPC on the intercepted channel EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12); } TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTestWithCallback) { ChannelArguments args; DummyInterceptor::Reset(); auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); creators->push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = server_->experimental().InProcessChannelWithInterceptors( args, std::move(creators)); MakeCallbackCall(channel); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } TEST_F(ClientInterceptorsStreamingEnd2EndTest, ClientInterceptorLoggingTest) { ChannelArguments args; DummyInterceptor::Reset(); auto creators = std::unique_ptr>>( new std::vector< std::unique_ptr>()); creators->push_back(std::unique_ptr( new LoggingInterceptorFactory())); // Add 20 dummy interceptors for (auto i = 0; i < 20; i++) { creators->push_back(std::unique_ptr( new DummyInterceptorFactory())); } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeStreamingCall(channel); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } } // namespace } // namespace testing } // namespace grpc int main(int argc, char** argv) { grpc_test_init(argc, argv); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }