diff options
author | Yash Tibrewal <yashkt@google.com> | 2019-01-05 12:24:31 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-05 12:24:31 -0800 |
commit | 46bd2f7adb926053345665d5c487fa20acd2b5b0 (patch) | |
tree | f375e509e895f0d11888c930d44fdaee6f62b8cb | |
parent | 6de81f54bbba10caa79fc72a253c0ea53fa05273 (diff) | |
parent | 059459a9ee082538d79be65dda3a131ef634cef1 (diff) |
Merge pull request #17179 from yashykt/failhijackedrecv
Add interceptor methods to fail recv msg for hijacked rpcs and set recv message to nullptr on failure
-rw-r--r-- | include/grpcpp/impl/codegen/call_op_set.h | 13 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor.h | 9 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor_common.h | 18 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/server_interface.h | 2 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 4 | ||||
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 111 |
6 files changed, 147 insertions, 10 deletions
diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index b52bc98b89..1c75560c04 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -453,14 +453,16 @@ class CallOpRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - interceptor_methods->SetRecvMessage(message_); + if (message_ == nullptr) return; + interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (message_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { hijacked_ = true; @@ -548,20 +550,23 @@ class CallOpGenericRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - interceptor_methods->SetRecvMessage(message_); + if (!deserialize_) return; + interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (!deserialize_) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { hijacked_ = true; if (!deserialize_) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + got_message = true; } private: diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 519c65c717..a57a3fccbb 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -168,8 +168,13 @@ class InterceptorBatchMethods { /// list. virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; - // On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND - // MESSAGE op + /// On a hijacked RPC, an interceptor can decide to fail a PRE_RECV_MESSAGE + /// op. This would be a signal to the reader that there will be no more + /// messages, or the stream has failed or been cancelled. + virtual void FailHijackedRecvMessage() = 0; + + /// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND + /// MESSAGE op virtual void FailHijackedSendMessage() = 0; }; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 734860615f..345127c830 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -149,7 +149,10 @@ class InterceptorBatchMethodsImpl send_trailing_metadata_ = metadata; } - void SetRecvMessage(void* message) { recv_message_ = message; } + void SetRecvMessage(void* message, bool* got_message) { + recv_message_ = message; + got_message_ = got_message; + } void SetRecvInitialMetadata(MetadataMap* map) { recv_initial_metadata_ = map; @@ -172,6 +175,12 @@ class InterceptorBatchMethodsImpl info->channel(), current_interceptor_index_ + 1)); } + void FailHijackedRecvMessage() override { + GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); + *got_message_ = false; + } + // Clears all state void ClearState() { reverse_ = false; @@ -362,6 +371,7 @@ class InterceptorBatchMethodsImpl std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr; void* recv_message_ = nullptr; + bool* got_message_ = nullptr; MetadataMap* recv_initial_metadata_ = nullptr; @@ -485,6 +495,12 @@ class CancelInterceptorBatchMethods return std::unique_ptr<ChannelInterface>(nullptr); } + void FailHijackedRecvMessage() override { + GPR_CODEGEN_ASSERT(false && + "It is illegal to call FailHijackedRecvMessage on a " + "method which has a Cancel notification"); + } + void FailHijackedSendMessage() override { GPR_CODEGEN_ASSERT(false && "It is illegal to call FailHijackedSendMessage on a " diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index e0e2629827..890a5650d0 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -272,7 +272,7 @@ class ServerInterface : public internal::CallHook { /* Set interception point for recv message */ interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - interceptor_methods_.SetRecvMessage(request_); + interceptor_methods_.SetRecvMessage(request_, nullptr); return RegisteredAsyncRequest::FinalizeResult(tag, status); } diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 1e3c57446f..13741ce7aa 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -278,7 +278,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { request_payload_ = nullptr; interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - interceptor_methods_.SetRecvMessage(request_); + interceptor_methods_.SetRecvMessage(request_, nullptr); } if (interceptor_methods_.RunInterceptors( @@ -446,7 +446,7 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { req_->request_payload_ = nullptr; req_->interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - req_->interceptor_methods_.SetRecvMessage(req_->request_); + req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr); } if (req_->interceptor_methods_.RunInterceptors( diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index f8728fc595..9fbfd8c84a 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -393,6 +393,103 @@ class ClientStreamingRpcHijackingInterceptorFactory } }; +class ServerStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast<unsigned>(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedRecvMessage(); + } + EchoResponse* resp = + static_cast<EchoResponse*>(methods->GetRecvMessage()); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = methods->GetRecvMessage() == nullptr; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast<unsigned>(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedMessage() { return got_failed_message_; } + + private: + experimental::ClientRpcInfo* info_; + static bool got_failed_message_; + int count_ = 0; +}; + +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + +class ServerStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ServerStreamingRpcHijackingInterceptor(info); + } +}; + class BidiStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -711,6 +808,20 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); |