diff options
-rw-r--r-- | include/grpcpp/impl/codegen/call_op_set.h | 6 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor.h | 3 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor_common.h | 10 | ||||
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 12 |
4 files changed, 29 insertions, 2 deletions
diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index aae8b9d3e3..3699ec94f2 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -406,12 +406,13 @@ class CallOpRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (message_ == nullptr) return; interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (message_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } @@ -501,12 +502,13 @@ class CallOpGenericRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (!deserialize_) return; interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (!deserialize_) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 943376a545..b977c35016 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -103,6 +103,9 @@ class InterceptorBatchMethods { // is already deserialized virtual void* GetRecvMessage() = 0; + // Checks whether the RECV MESSAGE op completed successfully + virtual bool GetRecvMessageStatus() = 0; + // Returns a modifiable multimap of the received initial metadata virtual std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() = 0; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index d23b71f8a7..b2e92dd6f3 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -103,6 +103,8 @@ class InterceptorBatchMethodsImpl void* GetRecvMessage() override { return recv_message_; } + bool GetRecvMessageStatus() override { return *got_message_; } + std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() override { return recv_initial_metadata_->map(); @@ -432,6 +434,14 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetRecvMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetRecvMessageStatus on a method which " + "has a Cancel notification"); + return false; + } + std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() override { GPR_CODEGEN_ASSERT(false && diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 459788ba51..4d4760a652 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -326,6 +326,12 @@ class ServerStreamingRpcHijackingInterceptor resp->set_message("Hello"); } if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = !methods->GetRecvMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); // insert the metadata that we want @@ -341,11 +347,16 @@ class ServerStreamingRpcHijackingInterceptor } } + static bool GotFailedMessage() { return got_failed_message_; } + private: experimental::ClientRpcInfo* info_; + static bool got_failed_message_; int count_ = 0; }; +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + class ServerStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -634,6 +645,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); } TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { |