diff options
Diffstat (limited to 'test/cpp')
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 3db709e956..9ec89ce1ab 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -270,6 +270,60 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -550,6 +604,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); |