diff options
Diffstat (limited to 'test/cpp/end2end')
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 299 | ||||
-rw-r--r-- | test/cpp/end2end/interceptors_util.cc | 10 | ||||
-rw-r--r-- | test/cpp/end2end/interceptors_util.h | 3 | ||||
-rw-r--r-- | test/cpp/end2end/server_interceptors_end2end_test.cc | 81 |
4 files changed, 386 insertions, 7 deletions
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 8abf4eb3f4..177922f457 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -68,7 +68,7 @@ class HijackingInterceptor : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - auto* buffer = methods->GetSendMessage(); + auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) @@ -173,7 +173,7 @@ class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - auto* buffer = methods->GetSendMessage(); + auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) @@ -270,6 +270,235 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { + public: + BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message().find("Hello"), 0u); + msg = req.message(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", + "testvalue"); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message(msg); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage()) + ->message() + .find("Hello"), + 0u); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + grpc::string msg; +}; + +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + +class ServerStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedRecvMessage(); + } + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = methods->GetRecvMessage() == nullptr; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedMessage() { return got_failed_message_; } + + private: + experimental::ClientRpcInfo* info_; + static bool got_failed_message_; + int count_ = 0; +}; + +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + +class ServerStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ServerStreamingRpcHijackingInterceptor(info); + } +}; + +class BidiStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new BidiStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -287,12 +516,16 @@ class LoggingInterceptor : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - auto* buffer = methods->GetSendMessage(); + EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage()) + ->message() + .find("Hello"), + 0u); + auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) .ok()); - EXPECT_TRUE(req.message().find("Hello") == 0); + EXPECT_TRUE(req.message().find("Hello") == 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { @@ -308,7 +541,7 @@ class LoggingInterceptor : public experimental::Interceptor { experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast<EchoResponse*>(methods->GetRecvMessage()); - EXPECT_TRUE(resp->message().find("Hello") == 0); + EXPECT_TRUE(resp->message().find("Hello") == 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { @@ -546,6 +779,62 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>( + new BidiStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index e0ad7d1526..900f02b5f3 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, return false; } +bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first == key && pair.second == value) { + return true; + } + } + return false; +} + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> CreateDummyClientInterceptors() { std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index 659e613d2e..419845e5f6 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel); bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map, const string& key, const string& value); +bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map, + const string& key, const string& value); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> CreateDummyClientInterceptors(); diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc index 53d8c4dc96..82f142ba91 100644 --- a/test/cpp/end2end/server_interceptors_end2end_test.cc +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -73,7 +73,7 @@ class LoggingInterceptor : public experimental::Interceptor { type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)); } - virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + void Intercept(experimental::InterceptorBatchMethods* methods) override { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { auto* map = methods->GetSendInitialMetadata(); @@ -83,7 +83,7 @@ class LoggingInterceptor : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - auto* buffer = methods->GetSendMessage(); + auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) @@ -142,6 +142,71 @@ class LoggingInterceptorFactory } }; +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageTester : public experimental::Interceptor { + public: + SyncSendMessageTester(experimental::ServerRpcInfo* info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("Hello"), 0u); + new_msg_.set_message("World" + old_msg); + methods->ModifySendMessage(&new_msg_); + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageTesterFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageTester(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageVerifier : public experimental::Interceptor { + public: + SyncSendMessageVerifier(experimental::ServerRpcInfo* info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + // Make sure that the changes made in SyncSendMessageTester persisted + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("World"), 0u); + + // Remove the "World" part of the string that we added earlier + new_msg_.set_message(old_msg.erase(0, 5)); + methods->ModifySendMessage(&new_msg_); + + // LoggingInterceptor verifies that changes got reverted + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageVerifierFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageVerifier(info); + } +}; + void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) { auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; @@ -175,6 +240,12 @@ class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { creators; creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( new LoggingInterceptorFactory())); // Add 20 dummy interceptor factories and null interceptor factories for (auto i = 0; i < 20; i++) { @@ -215,6 +286,12 @@ class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { creators; creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageTesterFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( new LoggingInterceptorFactory())); for (auto i = 0; i < 20; i++) { creators.push_back(std::unique_ptr<DummyInterceptorFactory>( |