aboutsummaryrefslogtreecommitdiffhomepage
path: root/test/cpp/end2end/client_interceptors_end2end_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'test/cpp/end2end/client_interceptors_end2end_test.cc')
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc83
1 files changed, 83 insertions, 0 deletions
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 3f9820aba4..f8728fc595 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -339,6 +339,60 @@ class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
grpc::string msg;
};
+class ClientStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedSendMessage();
+ }
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
+ EXPECT_FALSE(got_failed_send_);
+ got_failed_send_ = !methods->GetSendMessageStatus();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedSend() { return got_failed_send_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ int count_ = 0;
+ static bool got_failed_send_;
+};
+
+bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
+
+class ClientStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ClientStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class BidiStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
@@ -628,6 +682,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
+ ChannelArguments args;
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
+ new ClientStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ req.mutable_param()->set_echo_metadata(true);
+ req.set_message("Hello");
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(req));
+ expected_resp += "Hello";
+ }
+ // The interceptor will reject the 11th message
+ writer->Write(req);
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), false);
+ EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();