author     Yash Tibrewal <yashkt@google.com>  2019-01-07 10:09:06 -0800
committer  Yash Tibrewal <yashkt@google.com>  2019-01-07 10:09:06 -0800
commit     dd067fd39034d42bf9a1799d09c1c2a7daedc61a (patch)
tree       edb1db69a990b7ade9c25f2e4890ede8596c703f /test
parent     7d1491d64c9b6279233c780d290c514a7040c1c7 (diff)
parent     46bd2f7adb926053345665d5c487fa20acd2b5b0 (diff)
Merge branch 'master' into nocopyinterception
Diffstat (limited to 'test')
-rw-r--r--  test/cpp/end2end/client_interceptors_end2end_test.cc | 285
-rw-r--r--  test/cpp/end2end/interceptors_util.cc                |  10
-rw-r--r--  test/cpp/end2end/interceptors_util.h                 |   3
-rw-r--r--  test/cpp/qps/client.h                                | 118
-rw-r--r--  test/cpp/qps/client_callback.cc                      | 201
5 files changed, 533 insertions(+), 84 deletions(-)
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index ea72c1eda8..177922f457 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -270,6 +270,235 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
+class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
+ public:
+ BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message().find("Hello"), 0u);
+ msg = req.message();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
+ "testvalue");
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message(msg);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
+ ->message()
+ .find("Hello"),
+ 0u);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ grpc::string msg;
+};
+
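Note: each hijacking interceptor in this diff follows the same skeleton: query
every hook point offered in the batch, act on the ones it cares about, and
finish with exactly one of Hijack() or Proceed(). A minimal sketch of that
pattern, using only the interceptor API visible in this diff (the "handle
other hook points" placeholder stands in for any of the checks above):

  class MinimalHijackingInterceptor : public experimental::Interceptor {
   public:
    void Intercept(experimental::InterceptorBatchMethods* methods) override {
      bool hijack = false;
      if (methods->QueryInterceptionHookPoint(
              experimental::InterceptionHookPoints::
                  PRE_SEND_INITIAL_METADATA)) {
        // Hijacking at PRE_SEND_INITIAL_METADATA keeps the RPC off the wire;
        // the interceptor must then synthesize all replies itself.
        hijack = true;
      }
      // ... handle other hook points as needed ...
      // Exactly one of Hijack()/Proceed() must be called per batch.
      if (hijack) {
        methods->Hijack();
      } else {
        methods->Proceed();
      }
    }
  };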
+class ClientStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedSendMessage();
+ }
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
+ EXPECT_FALSE(got_failed_send_);
+ got_failed_send_ = !methods->GetSendMessageStatus();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedSend() { return got_failed_send_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ int count_ = 0;
+ static bool got_failed_send_;
+};
+
+bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
+
+class ClientStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ClientStreamingRpcHijackingInterceptor(info);
+ }
+};
+
+class ServerStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message(), "Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+      // Nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedRecvMessage();
+ }
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message("Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ // Only the last message will be a failure
+ EXPECT_FALSE(got_failed_message_);
+ got_failed_message_ = methods->GetRecvMessage() == nullptr;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedMessage() { return got_failed_message_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ static bool got_failed_message_;
+ int count_ = 0;
+};
+
+bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
+
+class ServerStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ServerStreamingRpcHijackingInterceptor(info);
+ }
+};
+
+class BidiStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new BidiStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@@ -550,6 +779,62 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
+ ChannelArguments args;
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
+ new ClientStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ req.mutable_param()->set_echo_metadata(true);
+ req.set_message("Hello");
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(req));
+ expected_resp += "Hello";
+ }
+ // The interceptor will reject the 11th message
+ writer->Write(req);
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), false);
+ EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
+ new ServerStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeServerStreamingCall(channel);
+ EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
+ new BidiStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeBidiStreamingCall(channel);
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc
index e0ad7d1526..900f02b5f3 100644
--- a/test/cpp/end2end/interceptors_util.cc
+++ b/test/cpp/end2end/interceptors_util.cc
@@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first == key && pair.second == value) {
+ return true;
+ }
+ }
+ return false;
+}
+
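Note: this overload mirrors the existing grpc::string_ref version but accepts
the std::multimap<grpc::string, grpc::string> containers that hook points such
as GetSendInitialMetadata() return; BidiStreamingRpcHijackingInterceptor above
relies on it. A small usage sketch:

  std::multimap<grpc::string, grpc::string> metadata;
  metadata.insert(std::make_pair("testkey", "testvalue"));
  EXPECT_TRUE(CheckMetadata(metadata, "testkey", "testvalue"));
  EXPECT_FALSE(CheckMetadata(metadata, "testkey", "wrongvalue"));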
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
index 659e613d2e..419845e5f6 100644
--- a/test/cpp/end2end/interceptors_util.h
+++ b/test/cpp/end2end/interceptors_util.h
@@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value);
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();
diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h
index 73f91eed2d..ceb5cdd710 100644
--- a/test/cpp/qps/client.h
+++ b/test/cpp/qps/client.h
@@ -236,58 +236,7 @@ class Client {
return 0;
}
- protected:
- bool closed_loop_;
- gpr_atm thread_pool_done_;
- double median_latency_collection_interval_seconds_; // In seconds
-
- void StartThreads(size_t num_threads) {
- gpr_atm_rel_store(&thread_pool_done_, static_cast<gpr_atm>(false));
- threads_remaining_ = num_threads;
- for (size_t i = 0; i < num_threads; i++) {
- threads_.emplace_back(new Thread(this, i));
- }
- }
-
- void EndThreads() {
- MaybeStartRequests();
- threads_.clear();
- }
-
- virtual void DestroyMultithreading() = 0;
-
- void SetupLoadTest(const ClientConfig& config, size_t num_threads) {
- // Set up the load distribution based on the number of threads
- const auto& load = config.load_params();
-
- std::unique_ptr<RandomDistInterface> random_dist;
- switch (load.load_case()) {
- case LoadParams::kClosedLoop:
- // Closed-loop doesn't use random dist at all
- break;
- case LoadParams::kPoisson:
- random_dist.reset(
- new ExpDist(load.poisson().offered_load() / num_threads));
- break;
- default:
- GPR_ASSERT(false);
- }
-
- // Set closed_loop_ based on whether or not random_dist is set
- if (!random_dist) {
- closed_loop_ = true;
- } else {
- closed_loop_ = false;
- // set up interarrival timer according to random dist
- interarrival_timer_.init(*random_dist, num_threads);
- const auto now = gpr_now(GPR_CLOCK_MONOTONIC);
- for (size_t i = 0; i < num_threads; i++) {
- next_time_.push_back(gpr_time_add(
- now,
- gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN)));
- }
- }
- }
+ bool IsClosedLoop() { return closed_loop_; }
gpr_timespec NextIssueTime(int thread_idx) {
const gpr_timespec result = next_time_[thread_idx];
@@ -297,9 +246,9 @@ class Client {
GPR_TIMESPAN));
return result;
}
- std::function<gpr_timespec()> NextIssuer(int thread_idx) {
- return closed_loop_ ? std::function<gpr_timespec()>()
- : std::bind(&Client::NextIssueTime, this, thread_idx);
+
+ bool ThreadCompleted() {
+ return static_cast<bool>(gpr_atm_acq_load(&thread_pool_done_));
}
class Thread {
@@ -380,8 +329,62 @@ class Client {
double interval_start_time_;
};
- bool ThreadCompleted() {
- return static_cast<bool>(gpr_atm_acq_load(&thread_pool_done_));
+ protected:
+ bool closed_loop_;
+ gpr_atm thread_pool_done_;
+ double median_latency_collection_interval_seconds_; // In seconds
+
+ void StartThreads(size_t num_threads) {
+ gpr_atm_rel_store(&thread_pool_done_, static_cast<gpr_atm>(false));
+ threads_remaining_ = num_threads;
+ for (size_t i = 0; i < num_threads; i++) {
+ threads_.emplace_back(new Thread(this, i));
+ }
+ }
+
+ void EndThreads() {
+ MaybeStartRequests();
+ threads_.clear();
+ }
+
+ virtual void DestroyMultithreading() = 0;
+
+ void SetupLoadTest(const ClientConfig& config, size_t num_threads) {
+ // Set up the load distribution based on the number of threads
+ const auto& load = config.load_params();
+
+ std::unique_ptr<RandomDistInterface> random_dist;
+ switch (load.load_case()) {
+ case LoadParams::kClosedLoop:
+ // Closed-loop doesn't use random dist at all
+ break;
+ case LoadParams::kPoisson:
+ random_dist.reset(
+ new ExpDist(load.poisson().offered_load() / num_threads));
+ break;
+ default:
+ GPR_ASSERT(false);
+ }
+
+ // Set closed_loop_ based on whether or not random_dist is set
+ if (!random_dist) {
+ closed_loop_ = true;
+ } else {
+ closed_loop_ = false;
+ // set up interarrival timer according to random dist
+ interarrival_timer_.init(*random_dist, num_threads);
+ const auto now = gpr_now(GPR_CLOCK_MONOTONIC);
+ for (size_t i = 0; i < num_threads; i++) {
+ next_time_.push_back(gpr_time_add(
+ now,
+ gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN)));
+ }
+ }
+ }
+
+ std::function<gpr_timespec()> NextIssuer(int thread_idx) {
+ return closed_loop_ ? std::function<gpr_timespec()>()
+ : std::bind(&Client::NextIssueTime, this, thread_idx);
}
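Note: an empty issuer signals closed-loop operation, where the next RPC fires
as soon as the previous one completes; a non-empty issuer yields precomputed
Poisson arrival times via NextIssueTime. A hedged sketch of the consuming
loop (IssueOneRpc is a hypothetical stand-in for the per-variant issue path):

  auto next_issue = NextIssuer(thread_idx);
  while (!ThreadCompleted()) {
    if (next_issue) {
      // Open loop: block until the scheduled issue time.
      gpr_sleep_until(next_issue());
    }
    IssueOneRpc();  // hypothetical: fire one RPC and record its latency
  }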
virtual void ThreadFunc(size_t thread_idx, Client::Thread* t) = 0;
@@ -436,6 +439,7 @@ class ClientImpl : public Client {
config.payload_config());
}
virtual ~ClientImpl() {}
+ const RequestType* request() { return &request_; }
void WaitForChannelsToConnect() {
int connect_deadline_seconds = 10;
diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc
index 87889e36dc..4a06325f2b 100644
--- a/test/cpp/qps/client_callback.cc
+++ b/test/cpp/qps/client_callback.cc
@@ -66,13 +66,35 @@ class CallbackClient
config, BenchmarkStubCreator) {
num_threads_ = NumThreads(config);
rpcs_done_ = 0;
- SetupLoadTest(config, num_threads_);
+
+ // Don't divide the fixed load among threads as the user threads
+ // only bootstrap the RPCs
+ SetupLoadTest(config, 1);
total_outstanding_rpcs_ =
config.client_channels() * config.outstanding_rpcs_per_channel();
}
virtual ~CallbackClient() {}
+  /**
+   * The main thread of the benchmark will be waiting in
+   * DestroyMultithreading. Increment rpcs_done_ to signal that this callback
+   * RPC has run to completion. When the last outstanding RPC increments the
+   * counter, it also signals the main thread's condition variable.
+   */
+ void NotifyMainThreadOfThreadCompletion() {
+ std::lock_guard<std::mutex> l(shutdown_mu_);
+ rpcs_done_++;
+ if (rpcs_done_ == total_outstanding_rpcs_) {
+ shutdown_cv_.notify_one();
+ }
+ }
+
+ gpr_timespec NextRPCIssueTime() {
+ std::lock_guard<std::mutex> l(next_issue_time_mu_);
+ return Client::NextIssueTime(0);
+ }
+
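Note: the waiting side referenced by the comment above is not part of this
hunk. Assuming the shutdown_mu_/shutdown_cv_ members used by the notify path,
the wait inside DestroyMultithreading presumably looks roughly like:

  std::unique_lock<std::mutex> l(shutdown_mu_);
  while (rpcs_done_ < total_outstanding_rpcs_) {
    shutdown_cv_.wait(l);
  }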
protected:
size_t num_threads_;
size_t total_outstanding_rpcs_;
@@ -93,24 +115,9 @@ class CallbackClient
ThreadFuncImpl(t, thread_idx);
}
- virtual void ScheduleRpc(Thread* t, size_t thread_idx,
- size_t ctx_vector_idx) = 0;
-
- /**
- * The main thread of the benchmark will be waiting on DestroyMultithreading.
- * Increment the rpcs_done_ variable to signify that the Callback RPC
- * after thread completion is done. When the last outstanding rpc increments
- * the counter it should also signal the main thread's conditional variable.
- */
- void NotifyMainThreadOfThreadCompletion() {
- std::lock_guard<std::mutex> l(shutdown_mu_);
- rpcs_done_++;
- if (rpcs_done_ == total_outstanding_rpcs_) {
- shutdown_cv_.notify_one();
- }
- }
-
private:
+  std::mutex next_issue_time_mu_;  // Guards the shared issue-time sequence
+
int NumThreads(const ClientConfig& config) {
int num_threads = config.async_client_threads();
if (num_threads <= 0) { // Use dynamic sizing
@@ -149,7 +156,7 @@ class CallbackUnaryClient final : public CallbackClient {
bool ThreadFuncImpl(Thread* t, size_t thread_idx) override {
for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_;
vector_idx += num_threads_) {
- ScheduleRpc(t, thread_idx, vector_idx);
+ ScheduleRpc(t, vector_idx);
}
return true;
}
@@ -157,26 +164,26 @@ class CallbackUnaryClient final : public CallbackClient {
void InitThreadFuncImpl(size_t thread_idx) override { return; }
private:
- void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) override {
+ void ScheduleRpc(Thread* t, size_t vector_idx) {
if (!closed_loop_) {
- gpr_timespec next_issue_time = NextIssueTime(thread_idx);
+ gpr_timespec next_issue_time = NextRPCIssueTime();
// Start an alarm callback to run the internal callback after
// next_issue_time
ctx_[vector_idx]->alarm_.experimental().Set(
- next_issue_time, [this, t, thread_idx, vector_idx](bool ok) {
- IssueUnaryCallbackRpc(t, thread_idx, vector_idx);
+ next_issue_time, [this, t, vector_idx](bool ok) {
+ IssueUnaryCallbackRpc(t, vector_idx);
});
} else {
- IssueUnaryCallbackRpc(t, thread_idx, vector_idx);
+ IssueUnaryCallbackRpc(t, vector_idx);
}
}
- void IssueUnaryCallbackRpc(Thread* t, size_t thread_idx, size_t vector_idx) {
+ void IssueUnaryCallbackRpc(Thread* t, size_t vector_idx) {
GPR_TIMER_SCOPE("CallbackUnaryClient::ThreadFunc", 0);
double start = UsageTimer::Now();
ctx_[vector_idx]->stub_->experimental_async()->UnaryCall(
(&ctx_[vector_idx]->context_), &request_, &ctx_[vector_idx]->response_,
- [this, t, thread_idx, start, vector_idx](grpc::Status s) {
+ [this, t, start, vector_idx](grpc::Status s) {
// Update Histogram with data from the callback run
HistogramEntry entry;
if (s.ok()) {
@@ -193,17 +200,157 @@ class CallbackUnaryClient final : public CallbackClient {
ctx_[vector_idx].reset(
new CallbackClientRpcContext(ctx_[vector_idx]->stub_));
// Schedule a new RPC
- ScheduleRpc(t, thread_idx, vector_idx);
+ ScheduleRpc(t, vector_idx);
}
});
}
};
+class CallbackStreamingClient : public CallbackClient {
+ public:
+ CallbackStreamingClient(const ClientConfig& config)
+ : CallbackClient(config),
+ messages_per_stream_(config.messages_per_stream()) {
+ for (int ch = 0; ch < config.client_channels(); ch++) {
+ for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) {
+ ctx_.emplace_back(
+ new CallbackClientRpcContext(channels_[ch].get_stub()));
+ }
+ }
+ StartThreads(num_threads_);
+ }
+ ~CallbackStreamingClient() {}
+
+ void AddHistogramEntry(double start_, bool ok, Thread* thread_ptr) {
+ // Update Histogram with data from the callback run
+ HistogramEntry entry;
+ if (ok) {
+ entry.set_value((UsageTimer::Now() - start_) * 1e9);
+ }
+ thread_ptr->UpdateHistogram(&entry);
+ }
+
+ int messages_per_stream() { return messages_per_stream_; }
+
+ protected:
+ const int messages_per_stream_;
+};
+
+class CallbackStreamingPingPongClient : public CallbackStreamingClient {
+ public:
+ CallbackStreamingPingPongClient(const ClientConfig& config)
+ : CallbackStreamingClient(config) {}
+ ~CallbackStreamingPingPongClient() {}
+};
+
+class CallbackStreamingPingPongReactor final
+ : public grpc::experimental::ClientBidiReactor<SimpleRequest,
+ SimpleResponse> {
+ public:
+ CallbackStreamingPingPongReactor(
+ CallbackStreamingPingPongClient* client,
+ std::unique_ptr<CallbackClientRpcContext> ctx)
+ : client_(client), ctx_(std::move(ctx)), messages_issued_(0) {}
+
+ void StartNewRpc() {
+ if (client_->ThreadCompleted()) return;
+ start_ = UsageTimer::Now();
+ ctx_->stub_->experimental_async()->StreamingCall(&(ctx_->context_), this);
+ StartWrite(client_->request());
+ StartCall();
+ }
+
+ void OnWriteDone(bool ok) override {
+ if (!ok || client_->ThreadCompleted()) {
+ if (!ok) gpr_log(GPR_ERROR, "Error writing RPC");
+ StartWritesDone();
+ return;
+ }
+ StartRead(&ctx_->response_);
+ }
+
+ void OnReadDone(bool ok) override {
+ client_->AddHistogramEntry(start_, ok, thread_ptr_);
+
+ if (client_->ThreadCompleted() || !ok ||
+ (client_->messages_per_stream() != 0 &&
+ ++messages_issued_ >= client_->messages_per_stream())) {
+ if (!ok) {
+ gpr_log(GPR_ERROR, "Error reading RPC");
+ }
+ StartWritesDone();
+ return;
+ }
+ StartWrite(client_->request());
+ }
+
+ void OnDone(const Status& s) override {
+ if (client_->ThreadCompleted() || !s.ok()) {
+ client_->NotifyMainThreadOfThreadCompletion();
+ return;
+ }
+ ctx_.reset(new CallbackClientRpcContext(ctx_->stub_));
+ ScheduleRpc();
+ }
+
+ void ScheduleRpc() {
+ if (client_->ThreadCompleted()) return;
+
+ if (!client_->IsClosedLoop()) {
+ gpr_timespec next_issue_time = client_->NextRPCIssueTime();
+ // Start an alarm callback to run the internal callback after
+ // next_issue_time
+ ctx_->alarm_.experimental().Set(next_issue_time,
+ [this](bool ok) { StartNewRpc(); });
+ } else {
+ StartNewRpc();
+ }
+ }
+
+ void set_thread_ptr(Client::Thread* ptr) { thread_ptr_ = ptr; }
+
+ CallbackStreamingPingPongClient* client_;
+ std::unique_ptr<CallbackClientRpcContext> ctx_;
+ Client::Thread* thread_ptr_; // Needed to update histogram entries
+ double start_; // Track message start time
+ int messages_issued_; // Messages issued by this stream
+};
+
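Note: each reactor drives one stream at a time and recycles itself for the
next RPC from OnDone(). The resulting per-stream flow, traced from the
callbacks above:

  // ScheduleRpc()        -> StartNewRpc() now, or via alarm_ in open loop
  // StartNewRpc()        -> StartCall() + StartWrite(request)
  // OnWriteDone(ok)      -> StartRead(&ctx_->response_)
  // OnReadDone(ok)       -> record latency, StartWrite(request) again,
  //                         until messages_per_stream_ is exhausted
  // then                 -> StartWritesDone()
  // OnDone(status)       -> reset ctx_, ScheduleRpc() for the next stream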
+class CallbackStreamingPingPongClientImpl final
+ : public CallbackStreamingPingPongClient {
+ public:
+ CallbackStreamingPingPongClientImpl(const ClientConfig& config)
+ : CallbackStreamingPingPongClient(config) {
+ for (size_t i = 0; i < total_outstanding_rpcs_; i++)
+ reactor_.emplace_back(
+ new CallbackStreamingPingPongReactor(this, std::move(ctx_[i])));
+ }
+ ~CallbackStreamingPingPongClientImpl() {}
+
+ bool ThreadFuncImpl(Client::Thread* t, size_t thread_idx) override {
+ for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_;
+ vector_idx += num_threads_) {
+ reactor_[vector_idx]->set_thread_ptr(t);
+ reactor_[vector_idx]->ScheduleRpc();
+ }
+ return true;
+ }
+
+ void InitThreadFuncImpl(size_t thread_idx) override {}
+
+ private:
+ std::vector<std::unique_ptr<CallbackStreamingPingPongReactor>> reactor_;
+};
+
+// TODO(mhaidry): Implement streaming from client, from server, and both ways
+
std::unique_ptr<Client> CreateCallbackClient(const ClientConfig& config) {
switch (config.rpc_type()) {
case UNARY:
return std::unique_ptr<Client>(new CallbackUnaryClient(config));
case STREAMING:
+ return std::unique_ptr<Client>(
+ new CallbackStreamingPingPongClientImpl(config));
case STREAMING_FROM_CLIENT:
case STREAMING_FROM_SERVER:
case STREAMING_BOTH_WAYS: