diff options
author | Yash Tibrewal <yashkt@google.com> | 2019-01-04 17:20:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-04 17:20:52 -0800 |
commit | 6de81f54bbba10caa79fc72a253c0ea53fa05273 (patch) | |
tree | a25bfcde7f1525565dbf8e0a2dcd894318b9e099 | |
parent | b542cc8917e18e6d921d42cf2308705dedeb3539 (diff) | |
parent | 8ba5922e8767e4c10ca38cb46b19d5878c310598 (diff) |
Merge pull request #17220 from yashykt/failhijackedsend
Add interceptor method to fail hijacked send messages and get status on POST_SEND_MESSAGE
-rw-r--r-- | include/grpcpp/impl/codegen/call_op_set.h | 23 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor.h | 11 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor_common.h | 27 | ||||
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 83 |
4 files changed, 139 insertions, 5 deletions
diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index 310bea93ca..b52bc98b89 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -326,21 +326,37 @@ class CallOpSendMessage { // Flags are per-message: clear them after use. write_options_.Clear(); } - void FinishOp(bool* status) { send_buf_.Clear(); } + void FinishOp(bool* status) { + if (!send_buf_.Valid()) { + return; + } + if (hijacked_ && failed_send_) { + // Hijacking interceptor failed this Op + *status = false; + } else if (!*status) { + // This Op was passed down to core and the Op failed + failed_send_ = true; + } + } void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); - interceptor_methods->SetSendMessage(&send_buf_, msg_); + interceptor_methods->SetSendMessage(&send_buf_, msg_, &failed_send_); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (send_buf_.Valid()) { + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE); + } + send_buf_.Clear(); // The contents of the SendMessage value that was previously set // has had its references stolen by core's operations - interceptor_methods->SetSendMessage(nullptr, nullptr); + interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { @@ -350,6 +366,7 @@ class CallOpSendMessage { private: const void* msg_ = nullptr; // The original non-serialized message bool hijacked_ = false; + bool failed_send_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 5a9a3a44e6..519c65c717 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -46,9 +46,10 @@ namespace experimental { /// operation has been requested and it is available. POST_RECV means that a /// result is available but has not yet been passed back to the application. enum class InterceptionHookPoints { - /// The first two in this list are for clients and servers + /// The first three in this list are for clients and servers PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, + POST_SEND_MESSAGE, PRE_SEND_STATUS, // server only PRE_SEND_CLOSE, // client only: WritesDone for stream; after write in unary /// The following three are for hijacked clients only and can only be @@ -117,6 +118,10 @@ class InterceptorBatchMethods { /// only supported for sync and callback APIs at the present moment. virtual const void* GetSendMessage() = 0; + /// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE + /// interceptions. + virtual bool GetSendMessageStatus() = 0; + /// Returns a modifiable multimap of the initial metadata to be sent. Valid /// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates /// that this field is not valid. @@ -162,6 +167,10 @@ class InterceptorBatchMethods { /// started from interceptors without infinite regress through the interceptor /// list. virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; + + // On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND + // MESSAGE op + virtual void FailHijackedSendMessage() = 0; }; /// Interface for an interceptor. Interceptor authors must create a class diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 4b7eaefee1..734860615f 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -83,6 +83,8 @@ class InterceptorBatchMethodsImpl const void* GetSendMessage() override { return orig_send_message_; } + bool GetSendMessageStatus() override { return !*fail_send_message_; } + std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { return send_initial_metadata_; } @@ -112,14 +114,22 @@ class InterceptorBatchMethodsImpl Status* GetRecvStatus() override { return recv_status_; } + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + *fail_send_message_ = true; + } + std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() override { return recv_trailing_metadata_->map(); } - void SetSendMessage(ByteBuffer* buf, const void* msg) { + void SetSendMessage(ByteBuffer* buf, const void* msg, + bool* fail_send_message) { send_message_ = buf; orig_send_message_ = msg; + fail_send_message_ = fail_send_message; } void SetSendInitialMetadata( @@ -339,6 +349,7 @@ class InterceptorBatchMethodsImpl std::function<void(void)> callback_; ByteBuffer* send_message_ = nullptr; + bool* fail_send_message_ = nullptr; const void* orig_send_message_ = nullptr; std::multimap<grpc::string, grpc::string>* send_initial_metadata_; @@ -392,6 +403,14 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetSendMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetSendMessageStatus on a method which " + "has a Cancel notification"); + return false; + } + const void* GetSendMessage() override { GPR_CODEGEN_ASSERT( false && @@ -465,6 +484,12 @@ class CancelInterceptorBatchMethods "method which has a Cancel notification"); return std::unique_ptr<ChannelInterface>(nullptr); } + + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(false && + "It is illegal to call FailHijackedSendMessage on a " + "method which has a Cancel notification"); + } }; } // namespace internal } // namespace grpc diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 3f9820aba4..f8728fc595 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -339,6 +339,60 @@ class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { grpc::string msg; }; +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + class BidiStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -628,6 +682,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> + creators; + creators.push_back( + std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); |