diff options
author | Vijay Pai <vpai@google.com> | 2017-11-30 10:16:21 -0800 |
---|---|---|
committer | Vijay Pai <vpai@google.com> | 2017-12-04 00:57:58 -0800 |
commit | 6389457ed2ef7af141a3723c3eb82a1cdc81293b (patch) | |
tree | f3996840252916596572f1758e776a4d4954075c | |
parent | d1945788c5ef92c3ed772da945e4f74a1681381a (diff) |
Adjust stream cancellation point and fix races in sync client
-rw-r--r-- | test/cpp/qps/client_sync.cc | 193 |
1 files changed, 121 insertions, 72 deletions
diff --git a/test/cpp/qps/client_sync.cc b/test/cpp/qps/client_sync.cc index 9f20b148eb..cb0945b05b 100644 --- a/test/cpp/qps/client_sync.cc +++ b/test/cpp/qps/client_sync.cc @@ -62,11 +62,13 @@ class SynchronousClient virtual ~SynchronousClient(){}; - virtual void InitThreadFuncImpl(size_t thread_idx) = 0; + virtual bool InitThreadFuncImpl(size_t thread_idx) = 0; virtual bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) = 0; void ThreadFunc(size_t thread_idx, Thread* t) override { - InitThreadFuncImpl(thread_idx); + if (!InitThreadFuncImpl(thread_idx)) { + return; + } for (;;) { // run the loop body HistogramEntry entry; @@ -109,9 +111,6 @@ class SynchronousClient size_t num_threads_; std::vector<SimpleResponse> responses_; - - private: - void DestroyMultithreading() override final { EndThreads(); } }; class SynchronousUnaryClient final : public SynchronousClient { @@ -122,7 +121,7 @@ class SynchronousUnaryClient final : public SynchronousClient { } ~SynchronousUnaryClient() {} - void InitThreadFuncImpl(size_t thread_idx) override {} + bool InitThreadFuncImpl(size_t thread_idx) override { return true; } bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { if (!WaitToIssue(thread_idx)) { @@ -140,6 +139,9 @@ class SynchronousUnaryClient final : public SynchronousClient { entry->set_status(s.error_code()); return true; } + + private: + void DestroyMultithreading() override final { EndThreads(); } }; template <class StreamType> @@ -149,31 +151,30 @@ class SynchronousStreamingClient : public SynchronousClient { : SynchronousClient(config), context_(num_threads_), stream_(num_threads_), + stream_mu_(num_threads_), + shutdown_(num_threads_), messages_per_stream_(config.messages_per_stream()), messages_issued_(num_threads_) { StartThreads(num_threads_); } virtual ~SynchronousStreamingClient() { - std::vector<std::thread> cleanup_threads; - for (size_t i = 0; i < num_threads_; i++) { - cleanup_threads.emplace_back([this, i]() { - auto stream = &stream_[i]; - if (*stream) { - // forcibly cancel the streams, then finish - context_[i].TryCancel(); - (*stream)->Finish().IgnoreError(); - // don't log any error message on !ok since this was canceled - } - }); - } - for (auto& th : cleanup_threads) { - th.join(); - } + OnAllStreams([](ClientContext* ctx, StreamType* s) -> bool { + // don't log any kind of error since we might have canceled it + s->Finish().IgnoreError(); + return true; + }); } protected: std::vector<grpc::ClientContext> context_; std::vector<std::unique_ptr<StreamType>> stream_; + // stream_mu_ is only needed when changing an element of stream_ or context_ + std::vector<std::mutex> stream_mu_; + struct Bool { + bool val; + Bool() : val(false) {} + }; + std::vector<Bool> shutdown_; const int messages_per_stream_; std::vector<int> messages_issued_; @@ -185,9 +186,34 @@ class SynchronousStreamingClient : public SynchronousClient { gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", thread_idx, s.error_message().c_str()); } + // Lock the stream_mu_ now because the client context could change + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); context_[thread_idx].~ClientContext(); new (&context_[thread_idx]) ClientContext(); } + void OnAllStreams(std::function<bool(ClientContext*, StreamType*)> cleaner) { + std::vector<std::thread> cleanup_threads; + for (size_t i = 0; i < num_threads_; i++) { + cleanup_threads.emplace_back([this, i, cleaner]() { + std::lock_guard<std::mutex> l(stream_mu_[i]); + if (stream_[i]) { + shutdown_[i].val = cleaner(&context_[i], stream_[i].get()); + } + }); + } + for (auto& th : cleanup_threads) { + th.join(); + } + } + + private: + void DestroyMultithreading() override final { + OnAllStreams([](ClientContext* ctx, StreamType* s) -> bool { + ctx->TryCancel(); + return true; + }); + EndThreads(); + } }; class SynchronousStreamingPingPongClient final @@ -197,24 +223,24 @@ class SynchronousStreamingPingPongClient final SynchronousStreamingPingPongClient(const ClientConfig& config) : SynchronousStreamingClient(config) {} ~SynchronousStreamingPingPongClient() { - std::vector<std::thread> cleanup_threads; - for (size_t i = 0; i < num_threads_; i++) { - cleanup_threads.emplace_back([this, i]() { - auto stream = &stream_[i]; - if (*stream) { - (*stream)->WritesDone(); - } - }); - } - for (auto& th : cleanup_threads) { - th.join(); - } + OnAllStreams( + [](ClientContext* ctx, + grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>* s) -> bool { + s->WritesDone(); + return true; + }); } - void InitThreadFuncImpl(size_t thread_idx) override { + bool InitThreadFuncImpl(size_t thread_idx) override { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); + } else { + return false; + } messages_issued_[thread_idx] = 0; + return true; } bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { @@ -239,7 +265,13 @@ class SynchronousStreamingPingPongClient final stream_[thread_idx]->WritesDone(); FinishStream(entry, thread_idx); auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); + } else { + stream_[thread_idx].reset(); + return false; + } messages_issued_[thread_idx] = 0; return true; } @@ -251,25 +283,24 @@ class SynchronousStreamingFromClientClient final SynchronousStreamingFromClientClient(const ClientConfig& config) : SynchronousStreamingClient(config), last_issue_(num_threads_) {} ~SynchronousStreamingFromClientClient() { - std::vector<std::thread> cleanup_threads; - for (size_t i = 0; i < num_threads_; i++) { - cleanup_threads.emplace_back([this, i]() { - auto stream = &stream_[i]; - if (*stream) { - (*stream)->WritesDone(); - } - }); - } - for (auto& th : cleanup_threads) { - th.join(); - } + OnAllStreams( + [](ClientContext* ctx, grpc::ClientWriter<SimpleRequest>* s) -> bool { + s->WritesDone(); + return true; + }); } - void InitThreadFuncImpl(size_t thread_idx) override { + bool InitThreadFuncImpl(size_t thread_idx) override { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], - &responses_[thread_idx]); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + } else { + return false; + } last_issue_[thread_idx] = UsageTimer::Now(); + return true; } bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { @@ -287,8 +318,14 @@ class SynchronousStreamingFromClientClient final stream_[thread_idx]->WritesDone(); FinishStream(entry, thread_idx); auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], - &responses_[thread_idx]); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + } else { + stream_[thread_idx].reset(); + return false; + } return true; } @@ -301,11 +338,17 @@ class SynchronousStreamingFromServerClient final public: SynchronousStreamingFromServerClient(const ClientConfig& config) : SynchronousStreamingClient(config), last_recv_(num_threads_) {} - void InitThreadFuncImpl(size_t thread_idx) override { + bool InitThreadFuncImpl(size_t thread_idx) override { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = - stub->StreamingFromServer(&context_[thread_idx], request_); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + } else { + return false; + } last_recv_[thread_idx] = UsageTimer::Now(); + return true; } bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { GPR_TIMER_SCOPE("SynchronousStreamingFromServerClient::ThreadFunc", 0); @@ -317,8 +360,14 @@ class SynchronousStreamingFromServerClient final } FinishStream(entry, thread_idx); auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = - stub->StreamingFromServer(&context_[thread_idx], request_); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + } else { + stream_[thread_idx].reset(); + return false; + } return true; } @@ -333,23 +382,23 @@ class SynchronousStreamingBothWaysClient final SynchronousStreamingBothWaysClient(const ClientConfig& config) : SynchronousStreamingClient(config) {} ~SynchronousStreamingBothWaysClient() { - std::vector<std::thread> cleanup_threads; - for (size_t i = 0; i < num_threads_; i++) { - cleanup_threads.emplace_back([this, i]() { - auto stream = &stream_[i]; - if (*stream) { - (*stream)->WritesDone(); - } - }); - } - for (auto& th : cleanup_threads) { - th.join(); - } + OnAllStreams( + [](ClientContext* ctx, + grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>* s) -> bool { + s->WritesDone(); + return true; + }); } - void InitThreadFuncImpl(size_t thread_idx) override { + bool InitThreadFuncImpl(size_t thread_idx) override { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]); + std::lock_guard<std::mutex> l(stream_mu_[thread_idx]); + if (!shutdown_[thread_idx].val) { + stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]); + } else { + return false; + } + return true; } bool ThreadFuncImpl(HistogramEntry* entry, size_t thread_idx) override { // TODO (vjpai): Do this |