From 699c10386d6f0cbed7364e5a94144f0794f469d5 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Fri, 9 Nov 2018 19:43:00 -0800 Subject: Add method to fail recv msg for hijacked rpcs --- .../end2end/client_interceptors_end2end_test.cc | 101 +++++++++++++++++++++ 1 file changed, 101 insertions(+) (limited to 'test') diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 0b34ec93ae..b94a3d752e 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -269,6 +269,92 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class ServerStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), static_cast(1)); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + if (++count > 10) { + methods->FailHijackedRecvMessage(); + } + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + int count = 0; +}; + +class ServerStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ServerStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -535,6 +621,21 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back( + std::unique_ptr( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeServerStreamingCall(channel); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); -- cgit v1.2.3 From 5d7d6c0fbdcac75ea482e1fde3e128cd0c1646c1 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Wed, 14 Nov 2018 17:35:26 -0800 Subject: Add method to fail hijacked send messages --- include/grpcpp/impl/codegen/call_op_set.h | 10 ++- include/grpcpp/impl/codegen/interceptor.h | 4 ++ include/grpcpp/impl/codegen/interceptor_common.h | 18 +++++- .../end2end/client_interceptors_end2end_test.cc | 73 ++++++++++++++++++++++ 4 files changed, 102 insertions(+), 3 deletions(-) (limited to 'test') diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index b4c34a01c9..1f2b88e9e1 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -314,14 +314,19 @@ class CallOpSendMessage { // Flags are per-message: clear them after use. write_options_.Clear(); } - void FinishOp(bool* status) { send_buf_.Clear(); } + void FinishOp(bool* status) { + send_buf_.Clear(); + if (hijacked_ && failed_send_) { + *status = false; + } + } void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); - interceptor_methods->SetSendMessage(&send_buf_); + interceptor_methods->SetSendMessage(&send_buf_, &failed_send_); } void SetFinishInterceptionHookPoint( @@ -333,6 +338,7 @@ class CallOpSendMessage { private: bool hijacked_ = false; + bool failed_send_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index e449e44a23..47239332c8 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -118,6 +118,10 @@ class InterceptorBatchMethods { // only interceptors after the current interceptor are created from the // factory objects registered with the channel. virtual std::unique_ptr GetInterceptedChannel() = 0; + + // On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND + // MESSAGE op + virtual void FailHijackedSendMessage() = 0; }; class Interceptor { diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index d0aa23cb0a..601a929afe 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -110,12 +110,21 @@ class InterceptorBatchMethodsImpl Status* GetRecvStatus() override { return recv_status_; } + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(hooks_[static_cast( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + *fail_send_message_ = true; + } + std::multimap* GetRecvTrailingMetadata() override { return recv_trailing_metadata_->map(); } - void SetSendMessage(ByteBuffer* buf) { send_message_ = buf; } + void SetSendMessage(ByteBuffer* buf, bool* fail_send_message) { + send_message_ = buf; + fail_send_message_ = fail_send_message; + } void SetSendInitialMetadata( std::multimap* metadata) { @@ -334,6 +343,7 @@ class InterceptorBatchMethodsImpl std::function callback_; ByteBuffer* send_message_ = nullptr; + bool* fail_send_message_ = nullptr; std::multimap* send_initial_metadata_; @@ -451,6 +461,12 @@ class CancelInterceptorBatchMethods "method which has a Cancel notification"); return std::unique_ptr(nullptr); } + + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(false && + "It is illegal to call FailHijackedSendMessage on a " + "method which has a Cancel notification"); + } }; } // namespace internal } // namespace grpc diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 0b34ec93ae..81efd15452 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -269,6 +269,49 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; +}; +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -535,6 +578,36 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back( + std::unique_ptr( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // Expect that the interceptor will reject the 11th message + EXPECT_FALSE(writer->Write(req)); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); -- cgit v1.2.3 From ceeb28a360d48389edb8c1c91a1efac46e24a9b7 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Wed, 14 Nov 2018 18:02:54 -0800 Subject: s/count/count_/ --- test/cpp/end2end/client_interceptors_end2end_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index b94a3d752e..459788ba51 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -318,7 +318,7 @@ class ServerStreamingRpcHijackingInterceptor } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { - if (++count > 10) { + if (++count_ > 10) { methods->FailHijackedRecvMessage(); } EchoResponse* resp = @@ -343,7 +343,7 @@ class ServerStreamingRpcHijackingInterceptor private: experimental::ClientRpcInfo* info_; - int count = 0; + int count_ = 0; }; class ServerStreamingRpcHijackingInterceptorFactory -- cgit v1.2.3 From 0911e489e3fe22e2ca5d7c927dac83358f2f05b7 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 15 Nov 2018 13:55:56 -0800 Subject: Add a method to check whether the message was received successfully --- include/grpcpp/impl/codegen/call_op_set.h | 6 ++++-- include/grpcpp/impl/codegen/interceptor.h | 3 +++ include/grpcpp/impl/codegen/interceptor_common.h | 10 ++++++++++ test/cpp/end2end/client_interceptors_end2end_test.cc | 12 ++++++++++++ 4 files changed, 29 insertions(+), 2 deletions(-) (limited to 'test') diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index aae8b9d3e3..3699ec94f2 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -406,12 +406,13 @@ class CallOpRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (message_ == nullptr) return; interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (message_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } @@ -501,12 +502,13 @@ class CallOpGenericRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (!deserialize_) return; interceptor_methods->SetRecvMessage(message_, &got_message); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { - if (!got_message) return; + if (!deserialize_) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 943376a545..b977c35016 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -103,6 +103,9 @@ class InterceptorBatchMethods { // is already deserialized virtual void* GetRecvMessage() = 0; + // Checks whether the RECV MESSAGE op completed successfully + virtual bool GetRecvMessageStatus() = 0; + // Returns a modifiable multimap of the received initial metadata virtual std::multimap* GetRecvInitialMetadata() = 0; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index d23b71f8a7..b2e92dd6f3 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -103,6 +103,8 @@ class InterceptorBatchMethodsImpl void* GetRecvMessage() override { return recv_message_; } + bool GetRecvMessageStatus() override { return *got_message_; } + std::multimap* GetRecvInitialMetadata() override { return recv_initial_metadata_->map(); @@ -432,6 +434,14 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetRecvMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetRecvMessageStatus on a method which " + "has a Cancel notification"); + return false; + } + std::multimap* GetRecvInitialMetadata() override { GPR_CODEGEN_ASSERT(false && diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 459788ba51..4d4760a652 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -325,6 +325,12 @@ class ServerStreamingRpcHijackingInterceptor static_cast(methods->GetRecvMessage()); resp->set_message("Hello"); } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + // Only the last message will be a failure + EXPECT_FALSE(got_failed_message_); + got_failed_message_ = !methods->GetRecvMessageStatus(); + } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* map = methods->GetRecvTrailingMetadata(); @@ -341,11 +347,16 @@ class ServerStreamingRpcHijackingInterceptor } } + static bool GotFailedMessage() { return got_failed_message_; } + private: experimental::ClientRpcInfo* info_; + static bool got_failed_message_; int count_ = 0; }; +bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; + class ServerStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -634,6 +645,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); } TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { -- cgit v1.2.3 From d4ebd30eb2eb94f77ac9b52c44880e3d70c6aef0 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 15 Nov 2018 15:57:43 -0800 Subject: Add method to get status of send message op on POST_SEND_MESSAGE --- include/grpcpp/impl/codegen/call_op_set.h | 18 ++++++++++++++++-- include/grpcpp/impl/codegen/interceptor.h | 6 +++++- include/grpcpp/impl/codegen/interceptor_common.h | 10 ++++++++++ test/cpp/end2end/client_interceptors_end2end_test.cc | 18 ++++++++++++++++-- 4 files changed, 47 insertions(+), 5 deletions(-) (limited to 'test') diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index 1f2b88e9e1..f330679ffc 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -315,9 +315,16 @@ class CallOpSendMessage { write_options_.Clear(); } void FinishOp(bool* status) { - send_buf_.Clear(); + if (!send_buf_.Valid()) { + return; + } if (hijacked_ && failed_send_) { + // Hijacking interceptor failed this Op *status = false; + } else if (!*status) { + // This Op was passed down to core and the Op failed + gpr_log(GPR_ERROR, "failure status"); + failed_send_ = true; } } @@ -330,7 +337,14 @@ class CallOpSendMessage { } void SetFinishInterceptionHookPoint( - InterceptorBatchMethodsImpl* interceptor_methods) {} + InterceptorBatchMethodsImpl* interceptor_methods) { + if (send_buf_.Valid()) { + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE); + // We had already registered failed_send_ earlier. No need to do it again. + } + send_buf_.Clear(); + } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { hijacked_ = true; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 47239332c8..154172dd81 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -41,9 +41,10 @@ class InterceptedMessage { }; enum class InterceptionHookPoints { - /* The first two in this list are for clients and servers */ + /* The first three in this list are for clients and servers */ PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, + POST_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be @@ -85,6 +86,9 @@ class InterceptorBatchMethods { // sent virtual ByteBuffer* GetSendMessage() = 0; + // Checks whether the SEND MESSAGE op succeeded + virtual bool GetSendMessageStatus() = 0; + // Returns a modifiable multimap of the initial metadata to be sent virtual std::multimap* GetSendInitialMetadata() = 0; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 601a929afe..21326df73b 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -81,6 +81,8 @@ class InterceptorBatchMethodsImpl ByteBuffer* GetSendMessage() override { return send_message_; } + bool GetSendMessageStatus() override { return !*fail_send_message_; } + std::multimap* GetSendInitialMetadata() override { return send_initial_metadata_; } @@ -113,6 +115,7 @@ class InterceptorBatchMethodsImpl void FailHijackedSendMessage() override { GPR_CODEGEN_ASSERT(hooks_[static_cast( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + gpr_log(GPR_ERROR, "failing"); *fail_send_message_ = true; } @@ -396,6 +399,13 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetSendMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetSendMessageStatus on a method which " + "has a Cancel notification"); + } + std::multimap* GetSendInitialMetadata() override { GPR_CODEGEN_ASSERT(false && "It is illegal to call GetSendInitialMetadata on a " diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 81efd15452..97947e7393 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -287,6 +287,13 @@ class ClientStreamingRpcHijackingInterceptor methods->FailHijackedSendMessage(); } } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + gpr_log(GPR_ERROR, "%d", got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + gpr_log(GPR_ERROR, "%d", got_failed_send_); + } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* status = methods->GetRecvStatus(); @@ -299,10 +306,16 @@ class ClientStreamingRpcHijackingInterceptor } } + static bool GotFailedSend() { return got_failed_send_; } + private: experimental::ClientRpcInfo* info_; int count_ = 0; + static bool got_failed_send_; }; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + class ClientStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -602,10 +615,11 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { EXPECT_TRUE(writer->Write(req)); expected_resp += "Hello"; } - // Expect that the interceptor will reject the 11th message - EXPECT_FALSE(writer->Write(req)); + // The interceptor will reject the 11th message + writer->Write(req); Status s = writer->Finish(); EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); } TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { -- cgit v1.2.3 From 00c9c40004d011f01c72d253a530edb3364992bf Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 15 Nov 2018 16:00:33 -0800 Subject: Remove extraneous logging statements --- include/grpcpp/impl/codegen/call_op_set.h | 1 - include/grpcpp/impl/codegen/interceptor_common.h | 1 - test/cpp/end2end/client_interceptors_end2end_test.cc | 2 -- 3 files changed, 4 deletions(-) (limited to 'test') diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index f330679ffc..ac3ba17bd9 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -323,7 +323,6 @@ class CallOpSendMessage { *status = false; } else if (!*status) { // This Op was passed down to core and the Op failed - gpr_log(GPR_ERROR, "failure status"); failed_send_ = true; } } diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 21326df73b..321691236b 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -115,7 +115,6 @@ class InterceptorBatchMethodsImpl void FailHijackedSendMessage() override { GPR_CODEGEN_ASSERT(hooks_[static_cast( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); - gpr_log(GPR_ERROR, "failing"); *fail_send_message_ = true; } diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 97947e7393..3708c11235 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -290,9 +290,7 @@ class ClientStreamingRpcHijackingInterceptor if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { EXPECT_FALSE(got_failed_send_); - gpr_log(GPR_ERROR, "%d", got_failed_send_); got_failed_send_ = !methods->GetSendMessageStatus(); - gpr_log(GPR_ERROR, "%d", got_failed_send_); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { -- cgit v1.2.3 From e6e10814993b5e6af93675bd468755786de1e20d Mon Sep 17 00:00:00 2001 From: Moiz Haidry Date: Tue, 11 Dec 2018 10:29:08 -0800 Subject: Add support for Callback Client Streaming benchmarks --- test/cpp/qps/client.h | 28 ++++--- test/cpp/qps/client_callback.cc | 176 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 174 insertions(+), 30 deletions(-) (limited to 'test') diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 668d941916..0b9837660b 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -236,6 +236,21 @@ class Client { return 0; } + bool IsClosedLoop() { return closed_loop_; } + + gpr_timespec NextIssueTime(int thread_idx) { + const gpr_timespec result = next_time_[thread_idx]; + next_time_[thread_idx] = + gpr_time_add(next_time_[thread_idx], + gpr_time_from_nanos(interarrival_timer_.next(thread_idx), + GPR_TIMESPAN)); + return result; + } + + bool ThreadCompleted() { + return static_cast(gpr_atm_acq_load(&thread_pool_done_)); + } + protected: bool closed_loop_; gpr_atm thread_pool_done_; @@ -289,14 +304,6 @@ class Client { } } - gpr_timespec NextIssueTime(int thread_idx) { - const gpr_timespec result = next_time_[thread_idx]; - next_time_[thread_idx] = - gpr_time_add(next_time_[thread_idx], - gpr_time_from_nanos(interarrival_timer_.next(thread_idx), - GPR_TIMESPAN)); - return result; - } std::function NextIssuer(int thread_idx) { return closed_loop_ ? std::function() : std::bind(&Client::NextIssueTime, this, thread_idx); @@ -380,10 +387,6 @@ class Client { double interval_start_time_; }; - bool ThreadCompleted() { - return static_cast(gpr_atm_acq_load(&thread_pool_done_)); - } - virtual void ThreadFunc(size_t thread_idx, Client::Thread* t) = 0; std::vector> threads_; @@ -442,6 +445,7 @@ class ClientImpl : public Client { config.payload_config()); } virtual ~ClientImpl() {} + const RequestType* request() { return &request_; } protected: const int cores_; diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc index 87889e36dc..00d5853a8e 100644 --- a/test/cpp/qps/client_callback.cc +++ b/test/cpp/qps/client_callback.cc @@ -73,6 +73,20 @@ class CallbackClient virtual ~CallbackClient() {} + /** + * The main thread of the benchmark will be waiting on DestroyMultithreading. + * Increment the rpcs_done_ variable to signify that the Callback RPC + * after thread completion is done. When the last outstanding rpc increments + * the counter it should also signal the main thread's conditional variable. + */ + void NotifyMainThreadOfThreadCompletion() { + std::lock_guard l(shutdown_mu_); + rpcs_done_++; + if (rpcs_done_ == total_outstanding_rpcs_) { + shutdown_cv_.notify_one(); + } + } + protected: size_t num_threads_; size_t total_outstanding_rpcs_; @@ -93,23 +107,6 @@ class CallbackClient ThreadFuncImpl(t, thread_idx); } - virtual void ScheduleRpc(Thread* t, size_t thread_idx, - size_t ctx_vector_idx) = 0; - - /** - * The main thread of the benchmark will be waiting on DestroyMultithreading. - * Increment the rpcs_done_ variable to signify that the Callback RPC - * after thread completion is done. When the last outstanding rpc increments - * the counter it should also signal the main thread's conditional variable. - */ - void NotifyMainThreadOfThreadCompletion() { - std::lock_guard l(shutdown_mu_); - rpcs_done_++; - if (rpcs_done_ == total_outstanding_rpcs_) { - shutdown_cv_.notify_one(); - } - } - private: int NumThreads(const ClientConfig& config) { int num_threads = config.async_client_threads(); @@ -157,7 +154,7 @@ class CallbackUnaryClient final : public CallbackClient { void InitThreadFuncImpl(size_t thread_idx) override { return; } private: - void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) override { + void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) { if (!closed_loop_) { gpr_timespec next_issue_time = NextIssueTime(thread_idx); // Start an alarm callback to run the internal callback after @@ -199,11 +196,154 @@ class CallbackUnaryClient final : public CallbackClient { } }; +class CallbackStreamingClient : public CallbackClient { + public: + CallbackStreamingClient(const ClientConfig& config) + : CallbackClient(config), + messages_per_stream_(config.messages_per_stream()) { + for (int ch = 0; ch < config.client_channels(); ch++) { + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + ctx_.emplace_back( + new CallbackClientRpcContext(channels_[ch].get_stub())); + } + } + StartThreads(num_threads_); + } + ~CallbackStreamingClient() {} + + void AddHistogramEntry(double start_, bool ok, void* thread_ptr) { + // Update Histogram with data from the callback run + HistogramEntry entry; + if (ok) { + entry.set_value((UsageTimer::Now() - start_) * 1e9); + } + ((Client::Thread*)thread_ptr)->UpdateHistogram(&entry); + } + + int messages_per_stream() { return messages_per_stream_; } + + protected: + const int messages_per_stream_; +}; + +class CallbackStreamingPingPongClient : public CallbackStreamingClient { + public: + CallbackStreamingPingPongClient(const ClientConfig& config) + : CallbackStreamingClient(config) {} + ~CallbackStreamingPingPongClient() {} +}; + +class CallbackStreamingPingPongReactor final + : public grpc::experimental::ClientBidiReactor { + public: + CallbackStreamingPingPongReactor( + CallbackStreamingPingPongClient* client, + std::unique_ptr ctx) + : client_(client), ctx_(std::move(ctx)), messages_issued_(0) {} + + void StartNewRpc() { + if (client_->ThreadCompleted()) return; + start_ = UsageTimer::Now(); + ctx_->stub_->experimental_async()->StreamingCall(&(ctx_->context_), this); + StartWrite(client_->request()); + StartCall(); + } + + void OnWriteDone(bool ok) override { + if (!ok || client_->ThreadCompleted()) { + if (!ok) gpr_log(GPR_ERROR, "Error writing RPC"); + StartWritesDone(); + return; + } + StartRead(&ctx_->response_); + } + + void OnReadDone(bool ok) override { + client_->AddHistogramEntry(start_, ok, thread_ptr_); + + if (client_->ThreadCompleted() || !ok || + (client_->messages_per_stream() != 0 && + ++messages_issued_ >= client_->messages_per_stream())) { + if (!ok) { + gpr_log(GPR_ERROR, "Error reading RPC"); + } + StartWritesDone(); + return; + } + StartWrite(client_->request()); + } + + void OnDone(const Status& s) override { + if (client_->ThreadCompleted() || !s.ok()) { + client_->NotifyMainThreadOfThreadCompletion(); + return; + } + ctx_.reset(new CallbackClientRpcContext(ctx_->stub_)); + ScheduleRpc(); + } + + void ScheduleRpc() { + if (client_->ThreadCompleted()) return; + + if (!client_->IsClosedLoop()) { + gpr_timespec next_issue_time = client_->NextIssueTime(thread_idx_); + // Start an alarm callback to run the internal callback after + // next_issue_time + ctx_->alarm_.experimental().Set(next_issue_time, + [this](bool ok) { StartNewRpc(); }); + } else { + StartNewRpc(); + } + } + + void set_thread_ptr(void* ptr) { thread_ptr_ = ptr; } + void set_thread_idx(int thread_idx) { thread_idx_ = thread_idx; } + + CallbackStreamingPingPongClient* client_; + std::unique_ptr ctx_; + int thread_idx_; // Needed to update histogram entries + void* thread_ptr_; // Needed to update histogram entries + double start_; // Track message start time + int messages_issued_; // Messages issued by this stream +}; + +class CallbackStreamingPingPongClientImpl final + : public CallbackStreamingPingPongClient { + public: + CallbackStreamingPingPongClientImpl(const ClientConfig& config) + : CallbackStreamingPingPongClient(config) { + for (size_t i = 0; i < total_outstanding_rpcs_; i++) + reactor_.emplace_back( + new CallbackStreamingPingPongReactor(this, std::move(ctx_[i]))); + } + ~CallbackStreamingPingPongClientImpl() {} + + bool ThreadFuncImpl(Client::Thread* t, size_t thread_idx) override { + for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; + vector_idx += num_threads_) { + reactor_[vector_idx]->set_thread_ptr(t); + reactor_[vector_idx]->set_thread_idx(thread_idx); + reactor_[vector_idx]->ScheduleRpc(); + } + return true; + } + + void InitThreadFuncImpl(size_t thread_idx) override {} + + private: + std::vector> reactor_; +}; + +// TODO(mhaidry) : Implement Streaming from client, server and both ways + std::unique_ptr CreateCallbackClient(const ClientConfig& config) { switch (config.rpc_type()) { case UNARY: return std::unique_ptr(new CallbackUnaryClient(config)); case STREAMING: + return std::unique_ptr( + new CallbackStreamingPingPongClientImpl(config)); case STREAMING_FROM_CLIENT: case STREAMING_FROM_SERVER: case STREAMING_BOTH_WAYS: -- cgit v1.2.3 From 5ec78a286d7be61aec929b133c031a7a1af262df Mon Sep 17 00:00:00 2001 From: Moiz Haidry Date: Fri, 14 Dec 2018 10:36:51 -0800 Subject: Added support for fixed load benchmarks, all the rpcs access one requestor to the get the next issue time for the RPC --- test/cpp/qps/client_callback.cc | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) (limited to 'test') diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc index 00d5853a8e..1880f46d43 100644 --- a/test/cpp/qps/client_callback.cc +++ b/test/cpp/qps/client_callback.cc @@ -66,7 +66,10 @@ class CallbackClient config, BenchmarkStubCreator) { num_threads_ = NumThreads(config); rpcs_done_ = 0; - SetupLoadTest(config, num_threads_); + + // Don't divide the fixed load among threads as the user threads + // only bootstrap the RPCs + SetupLoadTest(config, 1); total_outstanding_rpcs_ = config.client_channels() * config.outstanding_rpcs_per_channel(); } @@ -87,6 +90,11 @@ class CallbackClient } } + gpr_timespec NextIssueTime() { + std::lock_guard l(next_issue_time_mu_); + return Client::NextIssueTime(0); + } + protected: size_t num_threads_; size_t total_outstanding_rpcs_; @@ -108,6 +116,8 @@ class CallbackClient } private: + std::mutex next_issue_time_mu_; // Used by next issue time + int NumThreads(const ClientConfig& config) { int num_threads = config.async_client_threads(); if (num_threads <= 0) { // Use dynamic sizing @@ -146,7 +156,7 @@ class CallbackUnaryClient final : public CallbackClient { bool ThreadFuncImpl(Thread* t, size_t thread_idx) override { for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; vector_idx += num_threads_) { - ScheduleRpc(t, thread_idx, vector_idx); + ScheduleRpc(t, vector_idx); } return true; } @@ -154,26 +164,26 @@ class CallbackUnaryClient final : public CallbackClient { void InitThreadFuncImpl(size_t thread_idx) override { return; } private: - void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) { + void ScheduleRpc(Thread* t, size_t vector_idx) { if (!closed_loop_) { - gpr_timespec next_issue_time = NextIssueTime(thread_idx); + gpr_timespec next_issue_time = NextIssueTime(); // Start an alarm callback to run the internal callback after // next_issue_time ctx_[vector_idx]->alarm_.experimental().Set( - next_issue_time, [this, t, thread_idx, vector_idx](bool ok) { - IssueUnaryCallbackRpc(t, thread_idx, vector_idx); + next_issue_time, [this, t, vector_idx](bool ok) { + IssueUnaryCallbackRpc(t, vector_idx); }); } else { - IssueUnaryCallbackRpc(t, thread_idx, vector_idx); + IssueUnaryCallbackRpc(t, vector_idx); } } - void IssueUnaryCallbackRpc(Thread* t, size_t thread_idx, size_t vector_idx) { + void IssueUnaryCallbackRpc(Thread* t, size_t vector_idx) { GPR_TIMER_SCOPE("CallbackUnaryClient::ThreadFunc", 0); double start = UsageTimer::Now(); ctx_[vector_idx]->stub_->experimental_async()->UnaryCall( (&ctx_[vector_idx]->context_), &request_, &ctx_[vector_idx]->response_, - [this, t, thread_idx, start, vector_idx](grpc::Status s) { + [this, t, start, vector_idx](grpc::Status s) { // Update Histogram with data from the callback run HistogramEntry entry; if (s.ok()) { @@ -190,7 +200,7 @@ class CallbackUnaryClient final : public CallbackClient { ctx_[vector_idx].reset( new CallbackClientRpcContext(ctx_[vector_idx]->stub_)); // Schedule a new RPC - ScheduleRpc(t, thread_idx, vector_idx); + ScheduleRpc(t, vector_idx); } }); } @@ -287,7 +297,7 @@ class CallbackStreamingPingPongReactor final if (client_->ThreadCompleted()) return; if (!client_->IsClosedLoop()) { - gpr_timespec next_issue_time = client_->NextIssueTime(thread_idx_); + gpr_timespec next_issue_time = client_->NextIssueTime(); // Start an alarm callback to run the internal callback after // next_issue_time ctx_->alarm_.experimental().Set(next_issue_time, @@ -298,11 +308,9 @@ class CallbackStreamingPingPongReactor final } void set_thread_ptr(void* ptr) { thread_ptr_ = ptr; } - void set_thread_idx(int thread_idx) { thread_idx_ = thread_idx; } CallbackStreamingPingPongClient* client_; std::unique_ptr ctx_; - int thread_idx_; // Needed to update histogram entries void* thread_ptr_; // Needed to update histogram entries double start_; // Track message start time int messages_issued_; // Messages issued by this stream @@ -323,7 +331,6 @@ class CallbackStreamingPingPongClientImpl final for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_; vector_idx += num_threads_) { reactor_[vector_idx]->set_thread_ptr(t); - reactor_[vector_idx]->set_thread_idx(thread_idx); reactor_[vector_idx]->ScheduleRpc(); } return true; -- cgit v1.2.3 From 31a775b425eac37bb43c301cfb25e1f6a4bde106 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Tue, 18 Dec 2018 12:52:14 -0800 Subject: Add missing argument --- include/grpcpp/impl/codegen/call_op_set.h | 3 +-- include/grpcpp/impl/codegen/interceptor_common.h | 1 + test/cpp/end2end/client_interceptors_end2end_test.cc | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) (limited to 'test') diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index 3db9f48bff..1c0ccbab52 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -340,12 +340,11 @@ class CallOpSendMessage { if (send_buf_.Valid()) { interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_SEND_MESSAGE); - // We had already registered failed_send_ earlier. No need to do it again. } send_buf_.Clear(); // The contents of the SendMessage value that was previously set // has had its references stolen by core's operations - interceptor_methods->SetSendMessage(nullptr); + interceptor_methods->SetSendMessage(nullptr, &failed_send_); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 321691236b..b01706af8d 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -403,6 +403,7 @@ class CancelInterceptorBatchMethods false && "It is illegal to call GetSendMessageStatus on a method which " "has a Cancel notification"); + return false; } std::multimap* GetSendInitialMetadata() override { diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 33773e3b3b..c55eaab4d6 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -580,11 +580,9 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { ChannelArguments args; - auto creators = std::unique_ptr>>( - new std::vector< - std::unique_ptr>()); - creators->push_back( + std::vector> + creators; + creators.push_back( std::unique_ptr( new ClientStreamingRpcHijackingInterceptorFactory())); auto channel = experimental::CreateCustomChannelWithInterceptors( -- cgit v1.2.3 From 7bb853ebdd0b6e057de447147ad60ebf42e0903d Mon Sep 17 00:00:00 2001 From: Moiz Haidry Date: Fri, 21 Dec 2018 22:40:38 -0800 Subject: Addressed PR comments. Made Client::Thread public and removed use of void ptr to refer it. Avoided overloading of NextIssue TIme by renaming it NextRPCIssueTime --- test/cpp/qps/client.h | 116 ++++++++++++++++++++-------------------- test/cpp/qps/client_callback.cc | 18 +++---- 2 files changed, 67 insertions(+), 67 deletions(-) (limited to 'test') diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 0b9837660b..4b8ac9bd94 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -251,64 +251,6 @@ class Client { return static_cast(gpr_atm_acq_load(&thread_pool_done_)); } - protected: - bool closed_loop_; - gpr_atm thread_pool_done_; - double median_latency_collection_interval_seconds_; // In seconds - - void StartThreads(size_t num_threads) { - gpr_atm_rel_store(&thread_pool_done_, static_cast(false)); - threads_remaining_ = num_threads; - for (size_t i = 0; i < num_threads; i++) { - threads_.emplace_back(new Thread(this, i)); - } - } - - void EndThreads() { - MaybeStartRequests(); - threads_.clear(); - } - - virtual void DestroyMultithreading() = 0; - - void SetupLoadTest(const ClientConfig& config, size_t num_threads) { - // Set up the load distribution based on the number of threads - const auto& load = config.load_params(); - - std::unique_ptr random_dist; - switch (load.load_case()) { - case LoadParams::kClosedLoop: - // Closed-loop doesn't use random dist at all - break; - case LoadParams::kPoisson: - random_dist.reset( - new ExpDist(load.poisson().offered_load() / num_threads)); - break; - default: - GPR_ASSERT(false); - } - - // Set closed_loop_ based on whether or not random_dist is set - if (!random_dist) { - closed_loop_ = true; - } else { - closed_loop_ = false; - // set up interarrival timer according to random dist - interarrival_timer_.init(*random_dist, num_threads); - const auto now = gpr_now(GPR_CLOCK_MONOTONIC); - for (size_t i = 0; i < num_threads; i++) { - next_time_.push_back(gpr_time_add( - now, - gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN))); - } - } - } - - std::function NextIssuer(int thread_idx) { - return closed_loop_ ? std::function() - : std::bind(&Client::NextIssueTime, this, thread_idx); - } - class Thread { public: Thread(Client* client, size_t idx) @@ -387,6 +329,64 @@ class Client { double interval_start_time_; }; + protected: + bool closed_loop_; + gpr_atm thread_pool_done_; + double median_latency_collection_interval_seconds_; // In seconds + + void StartThreads(size_t num_threads) { + gpr_atm_rel_store(&thread_pool_done_, static_cast(false)); + threads_remaining_ = num_threads; + for (size_t i = 0; i < num_threads; i++) { + threads_.emplace_back(new Thread(this, i)); + } + } + + void EndThreads() { + MaybeStartRequests(); + threads_.clear(); + } + + virtual void DestroyMultithreading() = 0; + + void SetupLoadTest(const ClientConfig& config, size_t num_threads) { + // Set up the load distribution based on the number of threads + const auto& load = config.load_params(); + + std::unique_ptr random_dist; + switch (load.load_case()) { + case LoadParams::kClosedLoop: + // Closed-loop doesn't use random dist at all + break; + case LoadParams::kPoisson: + random_dist.reset( + new ExpDist(load.poisson().offered_load() / num_threads)); + break; + default: + GPR_ASSERT(false); + } + + // Set closed_loop_ based on whether or not random_dist is set + if (!random_dist) { + closed_loop_ = true; + } else { + closed_loop_ = false; + // set up interarrival timer according to random dist + interarrival_timer_.init(*random_dist, num_threads); + const auto now = gpr_now(GPR_CLOCK_MONOTONIC); + for (size_t i = 0; i < num_threads; i++) { + next_time_.push_back(gpr_time_add( + now, + gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN))); + } + } + } + + std::function NextIssuer(int thread_idx) { + return closed_loop_ ? std::function() + : std::bind(&Client::NextIssueTime, this, thread_idx); + } + virtual void ThreadFunc(size_t thread_idx, Client::Thread* t) = 0; std::vector> threads_; diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc index 1880f46d43..4a06325f2b 100644 --- a/test/cpp/qps/client_callback.cc +++ b/test/cpp/qps/client_callback.cc @@ -90,7 +90,7 @@ class CallbackClient } } - gpr_timespec NextIssueTime() { + gpr_timespec NextRPCIssueTime() { std::lock_guard l(next_issue_time_mu_); return Client::NextIssueTime(0); } @@ -166,7 +166,7 @@ class CallbackUnaryClient final : public CallbackClient { private: void ScheduleRpc(Thread* t, size_t vector_idx) { if (!closed_loop_) { - gpr_timespec next_issue_time = NextIssueTime(); + gpr_timespec next_issue_time = NextRPCIssueTime(); // Start an alarm callback to run the internal callback after // next_issue_time ctx_[vector_idx]->alarm_.experimental().Set( @@ -221,13 +221,13 @@ class CallbackStreamingClient : public CallbackClient { } ~CallbackStreamingClient() {} - void AddHistogramEntry(double start_, bool ok, void* thread_ptr) { + void AddHistogramEntry(double start_, bool ok, Thread* thread_ptr) { // Update Histogram with data from the callback run HistogramEntry entry; if (ok) { entry.set_value((UsageTimer::Now() - start_) * 1e9); } - ((Client::Thread*)thread_ptr)->UpdateHistogram(&entry); + thread_ptr->UpdateHistogram(&entry); } int messages_per_stream() { return messages_per_stream_; } @@ -297,7 +297,7 @@ class CallbackStreamingPingPongReactor final if (client_->ThreadCompleted()) return; if (!client_->IsClosedLoop()) { - gpr_timespec next_issue_time = client_->NextIssueTime(); + gpr_timespec next_issue_time = client_->NextRPCIssueTime(); // Start an alarm callback to run the internal callback after // next_issue_time ctx_->alarm_.experimental().Set(next_issue_time, @@ -307,13 +307,13 @@ class CallbackStreamingPingPongReactor final } } - void set_thread_ptr(void* ptr) { thread_ptr_ = ptr; } + void set_thread_ptr(Client::Thread* ptr) { thread_ptr_ = ptr; } CallbackStreamingPingPongClient* client_; std::unique_ptr ctx_; - void* thread_ptr_; // Needed to update histogram entries - double start_; // Track message start time - int messages_issued_; // Messages issued by this stream + Client::Thread* thread_ptr_; // Needed to update histogram entries + double start_; // Track message start time + int messages_issued_; // Messages issued by this stream }; class CallbackStreamingPingPongClientImpl final -- cgit v1.2.3 From aecc5f7285faedec634c99aff0b48eea86d3861a Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Fri, 28 Dec 2018 16:03:20 -0800 Subject: Add client interceptor test for bidi streaming hijacking interceptor --- .../end2end/client_interceptors_end2end_test.cc | 91 ++++++++++++++++++++++ test/cpp/end2end/interceptors_util.cc | 10 +++ test/cpp/end2end/interceptors_util.h | 3 + 3 files changed, 104 insertions(+) (limited to 'test') diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 8abf4eb3f4..ab387aa914 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -270,6 +270,84 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { + public: + BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message().find("Hello"), 0); + msg = req.message(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", + "testvalue"); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message(msg); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EXPECT_EQ(static_cast(methods->GetRecvMessage()) + ->message() + .find("Hello"), + 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + grpc::string msg; +}; + +class BidiStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new BidiStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -546,6 +624,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new BidiStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index e0ad7d1526..900f02b5f3 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap& map, return false; } +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first == key && pair.second == value) { + return true; + } + } + return false; +} + std::vector> CreateDummyClientInterceptors() { std::vector> diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index 659e613d2e..419845e5f6 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr& channel); bool CheckMetadata(const std::multimap& map, const string& key, const string& value); +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value); + std::vector> CreateDummyClientInterceptors(); -- cgit v1.2.3 From a5ed3d245e448c1e0e0e28b93e5821aaa7a3e439 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Fri, 4 Jan 2019 11:17:35 -0800 Subject: Avoid unsigned signed comparison issues --- test/cpp/end2end/client_interceptors_end2end_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'test') diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index ab387aa914..fc75fdb290 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -291,7 +291,7 @@ class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { EXPECT_TRUE( SerializationTraits::Deserialize(&copied_buffer, &req) .ok()); - EXPECT_EQ(req.message().find("Hello"), 0); + EXPECT_EQ(req.message().find("Hello"), 0u); msg = req.message(); } if (methods->QueryInterceptionHookPoint( @@ -316,7 +316,7 @@ class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { EXPECT_EQ(static_cast(methods->GetRecvMessage()) ->message() .find("Hello"), - 0); + 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { @@ -370,7 +370,7 @@ class LoggingInterceptor : public experimental::Interceptor { EXPECT_TRUE( SerializationTraits::Deserialize(&copied_buffer, &req) .ok()); - EXPECT_TRUE(req.message().find("Hello") == 0); + EXPECT_TRUE(req.message().find("Hello") == 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { @@ -386,7 +386,7 @@ class LoggingInterceptor : public experimental::Interceptor { experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast(methods->GetRecvMessage()); - EXPECT_TRUE(resp->message().find("Hello") == 0); + EXPECT_TRUE(resp->message().find("Hello") == 0u); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { -- cgit v1.2.3