diff options
author | Yash Tibrewal <yashkt@google.com> | 2019-01-07 16:32:24 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-07 16:32:24 -0800 |
commit | 8dcda4dc36aa4e4d3a4c46023f6470b4c1ec7bca (patch) | |
tree | 1b4c33de61f67f7ed45a7e1f9204f34ff93be0ca | |
parent | bcd29821ec609996367bafe905b7f38b0e84d661 (diff) | |
parent | 34d77aae5ec26063e2eb5dc4b47ba4dce90c7136 (diff) |
Merge pull request #17630 from yashykt/nocopyinterception
Modifying semantics for GetSendMessage and GetSerializedSendMessage. Also adding ModifySendMessage
-rw-r--r-- | include/grpcpp/impl/codegen/call_op_set.h | 57 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor.h | 2 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/interceptor_common.h | 35 | ||||
-rw-r--r-- | test/cpp/end2end/client_interceptors_end2end_test.cc | 8 | ||||
-rw-r--r-- | test/cpp/end2end/server_interceptors_end2end_test.cc | 71 |
5 files changed, 132 insertions, 41 deletions
diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index 1c75560c04..880c62344b 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -317,7 +317,15 @@ class CallOpSendMessage { protected: void AddOp(grpc_op* ops, size_t* nops) { - if (!send_buf_.Valid() || hijacked_) return; + if (msg_ == nullptr && !send_buf_.Valid()) return; + if (hijacked_) { + serializer_ = nullptr; + return; + } + if (msg_ != nullptr) { + GPR_CODEGEN_ASSERT(serializer_(msg_).ok()); + } + serializer_ = nullptr; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_MESSAGE; op->flags = write_options_.flags(); @@ -327,9 +335,7 @@ class CallOpSendMessage { write_options_.Clear(); } void FinishOp(bool* status) { - if (!send_buf_.Valid()) { - return; - } + if (msg_ == nullptr && !send_buf_.Valid()) return; if (hijacked_ && failed_send_) { // Hijacking interceptor failed this Op *status = false; @@ -341,22 +347,25 @@ class CallOpSendMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!send_buf_.Valid()) return; + if (msg_ == nullptr && !send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); - interceptor_methods->SetSendMessage(&send_buf_, msg_, &failed_send_); + interceptor_methods->SetSendMessage(&send_buf_, &msg_, &failed_send_, + serializer_); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (send_buf_.Valid()) { + if (msg_ != nullptr || send_buf_.Valid()) { interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_SEND_MESSAGE); } send_buf_.Clear(); + msg_ = nullptr; // The contents of the SendMessage value that was previously set // has had its references stolen by core's operations - interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_); + interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_, + nullptr); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { @@ -369,22 +378,32 @@ class CallOpSendMessage { bool failed_send_ = false; ByteBuffer send_buf_; WriteOptions write_options_; + std::function<Status(const void*)> serializer_; }; template <class M> Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { write_options_ = options; - bool own_buf; - // TODO(vjpai): Remove the void below when possible - // The void in the template parameter below should not be needed - // (since it should be implicit) but is needed due to an observed - // difference in behavior between clang and gcc for certain internal users - Status result = SerializationTraits<M, void>::Serialize( - message, send_buf_.bbuf_ptr(), &own_buf); - if (!own_buf) { - send_buf_.Duplicate(); - } - return result; + serializer_ = [this](const void* message) { + bool own_buf; + send_buf_.Clear(); + // TODO(vjpai): Remove the void below when possible + // The void in the template parameter below should not be needed + // (since it should be implicit) but is needed due to an observed + // difference in behavior between clang and gcc for certain internal users + Status result = SerializationTraits<M, void>::Serialize( + *static_cast<const M*>(message), send_buf_.bbuf_ptr(), &own_buf); + if (!own_buf) { + send_buf_.Duplicate(); + } + return result; + }; + // Serialize immediately only if we do not have access to the message pointer + if (msg_ == nullptr) { + return serializer_(&message); + serializer_ = nullptr; + } + return Status(); } template <class M> diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index a57a3fccbb..d749d8578a 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -118,6 +118,8 @@ class InterceptorBatchMethods { /// only supported for sync and callback APIs at the present moment. virtual const void* GetSendMessage() = 0; + virtual void ModifySendMessage(const void* message) = 0; + /// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE /// interceptions. virtual bool GetSendMessageStatus() = 0; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 345127c830..09721343ff 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -79,9 +79,24 @@ class InterceptorBatchMethodsImpl hooks_[static_cast<size_t>(type)] = true; } - ByteBuffer* GetSerializedSendMessage() override { return send_message_; } + ByteBuffer* GetSerializedSendMessage() override { + GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); + if (*orig_send_message_ != nullptr) { + GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok()); + *orig_send_message_ = nullptr; + } + return send_message_; + } + + const void* GetSendMessage() override { + GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); + return *orig_send_message_; + } - const void* GetSendMessage() override { return orig_send_message_; } + void ModifySendMessage(const void* message) override { + GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); + *orig_send_message_ = message; + } bool GetSendMessageStatus() override { return !*fail_send_message_; } @@ -125,11 +140,13 @@ class InterceptorBatchMethodsImpl return recv_trailing_metadata_->map(); } - void SetSendMessage(ByteBuffer* buf, const void* msg, - bool* fail_send_message) { + void SetSendMessage(ByteBuffer* buf, const void** msg, + bool* fail_send_message, + std::function<Status(const void*)> serializer) { send_message_ = buf; orig_send_message_ = msg; fail_send_message_ = fail_send_message; + serializer_ = serializer; } void SetSendInitialMetadata( @@ -359,7 +376,8 @@ class InterceptorBatchMethodsImpl ByteBuffer* send_message_ = nullptr; bool* fail_send_message_ = nullptr; - const void* orig_send_message_ = nullptr; + const void** orig_send_message_ = nullptr; + std::function<Status(const void*)> serializer_; std::multimap<grpc::string, grpc::string>* send_initial_metadata_; @@ -429,6 +447,13 @@ class CancelInterceptorBatchMethods return nullptr; } + void ModifySendMessage(const void* message) override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call ModifySendMessage on a method which " + "has a Cancel notification"); + } + std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { GPR_CODEGEN_ASSERT(false && "It is illegal to call GetSendInitialMetadata on a " diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 9fbfd8c84a..177922f457 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -516,16 +516,16 @@ class LoggingInterceptor : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; + EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage()) + ->message() + .find("Hello"), + 0u); auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) .ok()); EXPECT_TRUE(req.message().find("Hello") == 0u); - EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage()) - ->message() - .find("Hello"), - 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc index 09e855b0d0..82f142ba91 100644 --- a/test/cpp/end2end/server_interceptors_end2end_test.cc +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -142,29 +142,68 @@ class LoggingInterceptorFactory } }; -// Test if GetSendMessage works as expected -class GetSendMessageTester : public experimental::Interceptor { +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageTester : public experimental::Interceptor { public: - GetSendMessageTester(experimental::ServerRpcInfo* info) {} + SyncSendMessageTester(experimental::ServerRpcInfo* info) {} void Intercept(experimental::InterceptorBatchMethods* methods) override { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { - EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage()) - ->message() - .find("Hello"), - 0u); + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("Hello"), 0u); + new_msg_.set_message("World" + old_msg); + methods->ModifySendMessage(&new_msg_); } methods->Proceed(); } + + private: + EchoRequest new_msg_; }; -class GetSendMessageTesterFactory +class SyncSendMessageTesterFactory : public experimental::ServerInterceptorFactoryInterface { public: virtual experimental::Interceptor* CreateServerInterceptor( experimental::ServerRpcInfo* info) override { - return new GetSendMessageTester(info); + return new SyncSendMessageTester(info); + } +}; + +// Test if SendMessage function family works as expected for sync/callback apis +class SyncSendMessageVerifier : public experimental::Interceptor { + public: + SyncSendMessageVerifier(experimental::ServerRpcInfo* info) {} + + void Intercept(experimental::InterceptorBatchMethods* methods) override { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + // Make sure that the changes made in SyncSendMessageTester persisted + string old_msg = + static_cast<const EchoRequest*>(methods->GetSendMessage())->message(); + EXPECT_EQ(old_msg.find("World"), 0u); + + // Remove the "World" part of the string that we added earlier + new_msg_.set_message(old_msg.erase(0, 5)); + methods->ModifySendMessage(&new_msg_); + + // LoggingInterceptor verifies that changes got reverted + } + methods->Proceed(); + } + + private: + EchoRequest new_msg_; +}; + +class SyncSendMessageVerifierFactory + : public experimental::ServerInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateServerInterceptor( + experimental::ServerRpcInfo* info) override { + return new SyncSendMessageVerifier(info); } }; @@ -201,10 +240,13 @@ class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test { creators; creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( - new LoggingInterceptorFactory())); + new SyncSendMessageTesterFactory())); creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( - new GetSendMessageTesterFactory())); + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); // Add 20 dummy interceptor factories and null interceptor factories for (auto i = 0; i < 20; i++) { creators.push_back(std::unique_ptr<DummyInterceptorFactory>( @@ -244,10 +286,13 @@ class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test { creators; creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( - new LoggingInterceptorFactory())); + new SyncSendMessageTesterFactory())); creators.push_back( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( - new GetSendMessageTesterFactory())); + new SyncSendMessageVerifierFactory())); + creators.push_back( + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( + new LoggingInterceptorFactory())); for (auto i = 0; i < 20; i++) { creators.push_back(std::unique_ptr<DummyInterceptorFactory>( new DummyInterceptorFactory())); |