| author | Vijay Pai <vpai@google.com> | 2017-05-05 11:24:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-05-05 11:24:07 -0700 |
| commit | ea4adc1508cfd8325b11f811926e68b1beeb1db9 (patch) | |
| tree | dd8ae4629164ba7ba8e78da4e2959d8f1d947063 /test | |
| parent | 6815e414a4dbca4d0d4dd62b5ec3c6faa60c9bb9 (diff) | |
| parent | eea8cf0fe3a836b78e9ba122a01f6f1552ad8402 (diff) | |
Merge pull request #10313 from vjpai/onesided
Add QPS tests for 1-sided unconstrained streaming
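
The substance of this merge is to fold the per-RPC-type client factories into two entry points, CreateSynchronousClient and CreateAsyncClient, which dispatch on the configured RPC type, and to add dedicated client and server code paths for the two one-sided (client-to-server and server-to-client) streaming variants. As a rough, self-contained sketch of that dispatch pattern (the enum and class names below are illustrative stand-ins, not the actual QPS framework types):

```cpp
// Illustrative stand-in for the new single-factory dispatch: one factory keyed
// on the RPC kind replaces separate unary/streaming creation functions.
// RpcKind, DemoClient, and MakeClient are hypothetical names for this sketch.
#include <cassert>
#include <iostream>
#include <memory>

enum class RpcKind { kUnary, kStreamingPingPong, kStreamingFromClient, kStreamingFromServer };

struct DemoClient {
  explicit DemoClient(const char* kind) { std::cout << "created " << kind << " client\n"; }
};

std::unique_ptr<DemoClient> MakeClient(RpcKind kind) {
  switch (kind) {
    case RpcKind::kUnary:
      return std::make_unique<DemoClient>("unary");
    case RpcKind::kStreamingPingPong:
      return std::make_unique<DemoClient>("streaming ping-pong");
    case RpcKind::kStreamingFromClient:
      return std::make_unique<DemoClient>("streaming-from-client");
    case RpcKind::kStreamingFromServer:
      return std::make_unique<DemoClient>("streaming-from-server");
  }
  assert(false);  // both-ways streaming is still a TODO in this change
  return nullptr;
}

int main() { MakeClient(RpcKind::kStreamingFromServer); }
```

The actual change (client_async.cc and client_sync.cc below) does the same thing with AsyncUnaryClient, AsyncStreamingPingPongClient, AsyncStreamingFromClientClient, and AsyncStreamingFromServerClient, and leaves STREAMING_BOTH_WAYS asserting as unimplemented on the async side.
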
Diffstat (limited to 'test')
| -rw-r--r-- | test/cpp/qps/client.h | 7 |
| -rw-r--r-- | test/cpp/qps/client_async.cc | 296 |
| -rw-r--r-- | test/cpp/qps/client_sync.cc | 220 |
| -rw-r--r-- | test/cpp/qps/qps_worker.cc | 12 |
| -rw-r--r-- | test/cpp/qps/server_async.cc | 190 |
| -rw-r--r-- | test/cpp/qps/server_sync.cc | 111 |
6 files changed, 748 insertions, 88 deletions
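
One detail worth noting before reading the client_sync.cc changes: the new one-sided clients time each message against the previous one rather than timing a full request/response round trip; a successful Write (client-streaming) or Read (server-streaming) records the interval since the last issue or receive, scaled to nanoseconds. A minimal self-contained sketch of that bookkeeping follows (the clock and helper names here are illustrative; the diff itself uses UsageTimer::Now() and HistogramEntry):

```cpp
// Minimal sketch of interval-based latency recording, as used by the new
// one-sided streaming clients: each successful stream operation records the
// time since the previous one, in nanoseconds (hence the 1e9 scaling).
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

static double NowSeconds() {
  using std::chrono::duration;
  using std::chrono::steady_clock;
  return duration<double>(steady_clock::now().time_since_epoch()).count();
}

int main() {
  std::vector<double> latencies_ns;
  double last_issue = NowSeconds();
  for (int i = 0; i < 5; i++) {
    // Stand-in for a successful stream Write()/Read() plus whatever work
    // the peer does in between.
    std::this_thread::sleep_for(std::chrono::milliseconds(2));
    double now = NowSeconds();
    latencies_ns.push_back((now - last_issue) * 1e9);
    last_issue = now;
  }
  for (double v : latencies_ns) std::printf("%.0f ns\n", v);
}
```
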
diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 25a19a5a74..c3197eb622 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -443,11 +443,8 @@ class ClientImpl : public Client { create_stub_; }; -std::unique_ptr<Client> CreateSynchronousUnaryClient(const ClientConfig& args); -std::unique_ptr<Client> CreateSynchronousStreamingClient( - const ClientConfig& args); -std::unique_ptr<Client> CreateAsyncUnaryClient(const ClientConfig& args); -std::unique_ptr<Client> CreateAsyncStreamingClient(const ClientConfig& args); +std::unique_ptr<Client> CreateSynchronousClient(const ClientConfig& args); +std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& args); std::unique_ptr<Client> CreateGenericAsyncStreamingClient( const ClientConfig& args); diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc index 29a79e7343..01856f714a 100644 --- a/test/cpp/qps/client_async.cc +++ b/test/cpp/qps/client_async.cc @@ -313,9 +313,9 @@ class AsyncUnaryClient final }; template <class RequestType, class ResponseType> -class ClientRpcContextStreamingImpl : public ClientRpcContext { +class ClientRpcContextStreamingPingPongImpl : public ClientRpcContext { public: - ClientRpcContextStreamingImpl( + ClientRpcContextStreamingPingPongImpl( BenchmarkService::Stub* stub, const RequestType& req, std::function<gpr_timespec()> next_issue, std::function<std::unique_ptr< @@ -333,7 +333,7 @@ class ClientRpcContextStreamingImpl : public ClientRpcContext { callback_(on_done), next_issue_(next_issue), start_req_(start_req) {} - ~ClientRpcContextStreamingImpl() override {} + ~ClientRpcContextStreamingPingPongImpl() override {} void Start(CompletionQueue* cq, const ClientConfig& config) override { StartInternal(cq, config.messages_per_stream()); } @@ -394,8 +394,8 @@ class ClientRpcContextStreamingImpl : public ClientRpcContext { } } void StartNewClone(CompletionQueue* cq) override { - auto* clone = new ClientRpcContextStreamingImpl(stub_, req_, next_issue_, - start_req_, callback_); + auto* clone = new ClientRpcContextStreamingPingPongImpl( + stub_, req_, next_issue_, start_req_, callback_); clone->StartInternal(cq, messages_per_stream_); } @@ -434,23 +434,23 @@ class ClientRpcContextStreamingImpl : public ClientRpcContext { void StartInternal(CompletionQueue* cq, int messages_per_stream) { cq_ = cq; - next_state_ = State::STREAM_IDLE; - stream_ = start_req_(stub_, &context_, cq, ClientRpcContext::tag(this)); messages_per_stream_ = messages_per_stream; messages_issued_ = 0; + next_state_ = State::STREAM_IDLE; + stream_ = start_req_(stub_, &context_, cq, ClientRpcContext::tag(this)); } }; -class AsyncStreamingClient final +class AsyncStreamingPingPongClient final : public AsyncClient<BenchmarkService::Stub, SimpleRequest> { public: - explicit AsyncStreamingClient(const ClientConfig& config) + explicit AsyncStreamingPingPongClient(const ClientConfig& config) : AsyncClient<BenchmarkService::Stub, SimpleRequest>( config, SetupCtx, BenchmarkStubCreator) { StartThreads(num_async_threads_); } - ~AsyncStreamingClient() override {} + ~AsyncStreamingPingPongClient() override {} private: static void CheckDone(grpc::Status s, SimpleResponse* response) {} @@ -464,9 +464,250 @@ class AsyncStreamingClient final static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, std::function<gpr_timespec()> next_issue, const SimpleRequest& req) { - return new ClientRpcContextStreamingImpl<SimpleRequest, SimpleResponse>( - stub, req, next_issue, AsyncStreamingClient::StartReq, - 
AsyncStreamingClient::CheckDone); + return new ClientRpcContextStreamingPingPongImpl<SimpleRequest, + SimpleResponse>( + stub, req, next_issue, AsyncStreamingPingPongClient::StartReq, + AsyncStreamingPingPongClient::CheckDone); + } +}; + +template <class RequestType, class ResponseType> +class ClientRpcContextStreamingFromClientImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingFromClientImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function<gpr_timespec()> next_issue, + std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>( + BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*, + CompletionQueue*, void*)> + start_req, + std::function<void(grpc::Status, ResponseType*)> on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(on_done), + next_issue_(next_issue), + start_req_(start_req) {} + ~ClientRpcContextStreamingFromClientImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + StartInternal(cq); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!next_issue_) { // ready to issue + next_state_ = State::READY_TO_WRITE; + } else { + next_state_ = State::WAIT; + } + break; // loop around, don't return + case State::WAIT: + alarm_.reset( + new Alarm(cq_, next_issue_(), ClientRpcContext::tag(this))); + next_state_ = State::READY_TO_WRITE; + return true; + case State::READY_TO_WRITE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::WRITE_DONE; + stream_->Write(req_, ClientRpcContext::tag(this)); + return true; + case State::WRITE_DONE: + if (!ok) { + return false; + } + entry->set_value((UsageTimer::Now() - start_) * 1e9); + next_state_ = State::STREAM_IDLE; + break; // loop around + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextStreamingFromClientImpl( + stub_, req_, next_issue_, start_req_, callback_); + clone->StartInternal(cq); + } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr<Alarm> alarm_; + RequestType req_; + ResponseType response_; + enum State { + INVALID, + STREAM_IDLE, + WAIT, + READY_TO_WRITE, + WRITE_DONE, + }; + State next_state_; + std::function<void(grpc::Status, ResponseType*)> callback_; + std::function<gpr_timespec()> next_issue_; + std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>( + BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*, + CompletionQueue*, void*)> + start_req_; + grpc::Status status_; + double start_; + std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> stream_; + + void StartInternal(CompletionQueue* cq) { + cq_ = cq; + stream_ = start_req_(stub_, &context_, &response_, cq, + ClientRpcContext::tag(this)); + next_state_ = State::STREAM_IDLE; + } +}; + +class AsyncStreamingFromClientClient final + : public AsyncClient<BenchmarkService::Stub, SimpleRequest> { + public: + explicit AsyncStreamingFromClientClient(const ClientConfig& config) + : AsyncClient<BenchmarkService::Stub, SimpleRequest>( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + + ~AsyncStreamingFromClientClient() override {} + + private: + static void CheckDone(grpc::Status s, SimpleResponse* response) {} + static 
std::unique_ptr<grpc::ClientAsyncWriter<SimpleRequest>> StartReq( + BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + SimpleResponse* resp, CompletionQueue* cq, void* tag) { + auto stream = stub->AsyncStreamingFromClient(ctx, resp, cq, tag); + return stream; + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function<gpr_timespec()> next_issue, + const SimpleRequest& req) { + return new ClientRpcContextStreamingFromClientImpl<SimpleRequest, + SimpleResponse>( + stub, req, next_issue, AsyncStreamingFromClientClient::StartReq, + AsyncStreamingFromClientClient::CheckDone); + } +}; + +template <class RequestType, class ResponseType> +class ClientRpcContextStreamingFromServerImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingFromServerImpl( + BenchmarkService::Stub* stub, const RequestType& req, + std::function<gpr_timespec()> next_issue, + std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*, void*)> + start_req, + std::function<void(grpc::Status, ResponseType*)> on_done) + : context_(), + stub_(stub), + cq_(nullptr), + req_(req), + response_(), + next_state_(State::INVALID), + callback_(on_done), + next_issue_(next_issue), + start_req_(start_req) {} + ~ClientRpcContextStreamingFromServerImpl() override {} + void Start(CompletionQueue* cq, const ClientConfig& config) override { + StartInternal(cq); + } + bool RunNextState(bool ok, HistogramEntry* entry) override { + while (true) { + switch (next_state_) { + case State::STREAM_IDLE: + if (!ok) { + return false; + } + start_ = UsageTimer::Now(); + next_state_ = State::READ_DONE; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + case State::READ_DONE: + if (!ok) { + return false; + } + entry->set_value((UsageTimer::Now() - start_) * 1e9); + callback_(status_, &response_); + next_state_ = State::STREAM_IDLE; + break; // loop around + default: + GPR_ASSERT(false); + return false; + } + } + } + void StartNewClone(CompletionQueue* cq) override { + auto* clone = new ClientRpcContextStreamingFromServerImpl( + stub_, req_, next_issue_, start_req_, callback_); + clone->StartInternal(cq); + } + + private: + grpc::ClientContext context_; + BenchmarkService::Stub* stub_; + CompletionQueue* cq_; + std::unique_ptr<Alarm> alarm_; + RequestType req_; + ResponseType response_; + enum State { INVALID, STREAM_IDLE, READ_DONE }; + State next_state_; + std::function<void(grpc::Status, ResponseType*)> callback_; + std::function<gpr_timespec()> next_issue_; + std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>( + BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&, + CompletionQueue*, void*)> + start_req_; + grpc::Status status_; + double start_; + std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> stream_; + + void StartInternal(CompletionQueue* cq) { + // TODO(vjpai): Add support to rate-pace this + cq_ = cq; + next_state_ = State::STREAM_IDLE; + stream_ = + start_req_(stub_, &context_, req_, cq, ClientRpcContext::tag(this)); + } +}; + +class AsyncStreamingFromServerClient final + : public AsyncClient<BenchmarkService::Stub, SimpleRequest> { + public: + explicit AsyncStreamingFromServerClient(const ClientConfig& config) + : AsyncClient<BenchmarkService::Stub, SimpleRequest>( + config, SetupCtx, BenchmarkStubCreator) { + StartThreads(num_async_threads_); + } + + ~AsyncStreamingFromServerClient() override {} + + private: + static void 
CheckDone(grpc::Status s, SimpleResponse* response) {} + static std::unique_ptr<grpc::ClientAsyncReader<SimpleResponse>> StartReq( + BenchmarkService::Stub* stub, grpc::ClientContext* ctx, + const SimpleRequest& req, CompletionQueue* cq, void* tag) { + auto stream = stub->AsyncStreamingFromServer(ctx, req, cq, tag); + return stream; + }; + static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub, + std::function<gpr_timespec()> next_issue, + const SimpleRequest& req) { + return new ClientRpcContextStreamingFromServerImpl<SimpleRequest, + SimpleResponse>( + stub, req, next_issue, AsyncStreamingFromServerClient::StartReq, + AsyncStreamingFromServerClient::CheckDone); } }; @@ -591,11 +832,11 @@ class ClientRpcContextGenericStreamingImpl : public ClientRpcContext { cq_ = cq; const grpc::string kMethodName( "/grpc.testing.BenchmarkService/StreamingCall"); + messages_per_stream_ = messages_per_stream; + messages_issued_ = 0; next_state_ = State::STREAM_IDLE; stream_ = start_req_(stub_, &context_, kMethodName, cq, ClientRpcContext::tag(this)); - messages_per_stream_ = messages_per_stream; - messages_issued_ = 0; } }; @@ -632,11 +873,26 @@ class GenericAsyncStreamingClient final } }; -std::unique_ptr<Client> CreateAsyncUnaryClient(const ClientConfig& args) { - return std::unique_ptr<Client>(new AsyncUnaryClient(args)); -} -std::unique_ptr<Client> CreateAsyncStreamingClient(const ClientConfig& args) { - return std::unique_ptr<Client>(new AsyncStreamingClient(args)); +std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& config) { + switch (config.rpc_type()) { + case UNARY: + return std::unique_ptr<Client>(new AsyncUnaryClient(config)); + case STREAMING: + return std::unique_ptr<Client>(new AsyncStreamingPingPongClient(config)); + case STREAMING_FROM_CLIENT: + return std::unique_ptr<Client>( + new AsyncStreamingFromClientClient(config)); + case STREAMING_FROM_SERVER: + return std::unique_ptr<Client>( + new AsyncStreamingFromServerClient(config)); + case STREAMING_BOTH_WAYS: + // TODO(vjpai): Implement this + assert(false); + return nullptr; + default: + assert(false); + return nullptr; + } } std::unique_ptr<Client> CreateGenericAsyncStreamingClient( const ClientConfig& args) { diff --git a/test/cpp/qps/client_sync.cc b/test/cpp/qps/client_sync.cc index f8ce2cccbe..9075033bd4 100644 --- a/test/cpp/qps/client_sync.cc +++ b/test/cpp/qps/client_sync.cc @@ -137,7 +137,8 @@ class SynchronousUnaryClient final : public SynchronousClient { } }; -class SynchronousStreamingClient final : public SynchronousClient { +template <class StreamType> +class SynchronousStreamingClient : public SynchronousClient { public: SynchronousStreamingClient(const ClientConfig& config) : SynchronousClient(config), @@ -145,30 +146,69 @@ class SynchronousStreamingClient final : public SynchronousClient { stream_(num_threads_), messages_per_stream_(config.messages_per_stream()), messages_issued_(num_threads_) { + StartThreads(num_threads_); + } + virtual ~SynchronousStreamingClient() { + std::vector<std::thread> cleanup_threads; + for (size_t i = 0; i < num_threads_; i++) { + cleanup_threads.emplace_back([this, i]() { + auto stream = &stream_[i]; + if (*stream) { + // forcibly cancel the streams, then finish + context_[i].TryCancel(); + (*stream)->Finish(); + // don't log any error message on !ok since this was canceled + } + }); + } + for (auto& th : cleanup_threads) { + th.join(); + } + } + + protected: + std::vector<grpc::ClientContext> context_; + std::vector<std::unique_ptr<StreamType>> stream_; + const int 
messages_per_stream_; + std::vector<int> messages_issued_; + + void FinishStream(HistogramEntry* entry, size_t thread_idx) { + Status s = stream_[thread_idx]->Finish(); + // don't set the value since the stream is failed and shouldn't be timed + entry->set_status(s.error_code()); + if (!s.ok()) { + gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", thread_idx, + s.error_message().c_str()); + } + context_[thread_idx].~ClientContext(); + new (&context_[thread_idx]) ClientContext(); + } +}; + +class SynchronousStreamingPingPongClient final + : public SynchronousStreamingClient< + grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>> { + public: + SynchronousStreamingPingPongClient(const ClientConfig& config) + : SynchronousStreamingClient(config) { for (size_t thread_idx = 0; thread_idx < num_threads_; thread_idx++) { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); messages_issued_[thread_idx] = 0; } - StartThreads(num_threads_); } - ~SynchronousStreamingClient() { + ~SynchronousStreamingPingPongClient() { std::vector<std::thread> cleanup_threads; for (size_t i = 0; i < num_threads_; i++) { cleanup_threads.emplace_back([this, i]() { auto stream = &stream_[i]; if (*stream) { (*stream)->WritesDone(); - Status s = (*stream)->Finish(); - if (!s.ok()) { - gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", i, - s.error_message().c_str()); - } } }); } - for (size_t i = 0; i < num_threads_; i++) { - cleanup_threads[i].join(); + for (auto& th : cleanup_threads) { + th.join(); } } @@ -176,7 +216,7 @@ class SynchronousStreamingClient final : public SynchronousClient { if (!WaitToIssue(thread_idx)) { return true; } - GPR_TIMER_SCOPE("SynchronousStreamingClient::ThreadFunc", 0); + GPR_TIMER_SCOPE("SynchronousStreamingPingPongClient::ThreadFunc", 0); double start = UsageTimer::Now(); if (stream_[thread_idx]->Write(request_) && stream_[thread_idx]->Read(&responses_[thread_idx])) { @@ -192,40 +232,148 @@ class SynchronousStreamingClient final : public SynchronousClient { } } stream_[thread_idx]->WritesDone(); - Status s = stream_[thread_idx]->Finish(); - // don't set the value since this is either a failure (shouldn't be timed) - // or a stream-end (already has been timed) - entry->set_status(s.error_code()); - if (!s.ok()) { - gpr_log(GPR_ERROR, "Stream %" PRIuPTR " received an error %s", thread_idx, - s.error_message().c_str()); - } + FinishStream(entry, thread_idx); auto* stub = channels_[thread_idx % channels_.size()].get_stub(); - context_[thread_idx].~ClientContext(); - new (&context_[thread_idx]) ClientContext(); stream_[thread_idx] = stub->StreamingCall(&context_[thread_idx]); messages_issued_[thread_idx] = 0; return true; } +}; + +class SynchronousStreamingFromClientClient final + : public SynchronousStreamingClient<grpc::ClientWriter<SimpleRequest>> { + public: + SynchronousStreamingFromClientClient(const ClientConfig& config) + : SynchronousStreamingClient(config), last_issue_(num_threads_) { + for (size_t thread_idx = 0; thread_idx < num_threads_; thread_idx++) { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + last_issue_[thread_idx] = UsageTimer::Now(); + } + } + ~SynchronousStreamingFromClientClient() { + std::vector<std::thread> cleanup_threads; + for (size_t i = 0; i < num_threads_; i++) { + cleanup_threads.emplace_back([this, i]() { + auto stream = &stream_[i]; 
+ if (*stream) { + (*stream)->WritesDone(); + } + }); + } + for (auto& th : cleanup_threads) { + th.join(); + } + } + + bool ThreadFunc(HistogramEntry* entry, size_t thread_idx) override { + // Figure out how to make histogram sensible if this is rate-paced + if (!WaitToIssue(thread_idx)) { + return true; + } + GPR_TIMER_SCOPE("SynchronousStreamingFromClientClient::ThreadFunc", 0); + if (stream_[thread_idx]->Write(request_)) { + double now = UsageTimer::Now(); + entry->set_value((now - last_issue_[thread_idx]) * 1e9); + last_issue_[thread_idx] = now; + return true; + } + stream_[thread_idx]->WritesDone(); + FinishStream(entry, thread_idx); + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_[thread_idx] = stub->StreamingFromClient(&context_[thread_idx], + &responses_[thread_idx]); + return true; + } private: - // These are both conceptually std::vector but cannot be for old compilers - // that expect contained classes to support copy constructors - std::vector<grpc::ClientContext> context_; - std::vector< - std::unique_ptr<grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>>> - stream_; - const int messages_per_stream_; - std::vector<int> messages_issued_; + std::vector<double> last_issue_; }; -std::unique_ptr<Client> CreateSynchronousUnaryClient( - const ClientConfig& config) { - return std::unique_ptr<Client>(new SynchronousUnaryClient(config)); -} -std::unique_ptr<Client> CreateSynchronousStreamingClient( - const ClientConfig& config) { - return std::unique_ptr<Client>(new SynchronousStreamingClient(config)); +class SynchronousStreamingFromServerClient final + : public SynchronousStreamingClient<grpc::ClientReader<SimpleResponse>> { + public: + SynchronousStreamingFromServerClient(const ClientConfig& config) + : SynchronousStreamingClient(config), last_recv_(num_threads_) { + for (size_t thread_idx = 0; thread_idx < num_threads_; thread_idx++) { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + last_recv_[thread_idx] = UsageTimer::Now(); + } + } + bool ThreadFunc(HistogramEntry* entry, size_t thread_idx) override { + GPR_TIMER_SCOPE("SynchronousStreamingFromServerClient::ThreadFunc", 0); + if (stream_[thread_idx]->Read(&responses_[thread_idx])) { + double now = UsageTimer::Now(); + entry->set_value((now - last_recv_[thread_idx]) * 1e9); + last_recv_[thread_idx] = now; + return true; + } + FinishStream(entry, thread_idx); + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_[thread_idx] = + stub->StreamingFromServer(&context_[thread_idx], request_); + return true; + } + + private: + std::vector<double> last_recv_; +}; + +class SynchronousStreamingBothWaysClient final + : public SynchronousStreamingClient< + grpc::ClientReaderWriter<SimpleRequest, SimpleResponse>> { + public: + SynchronousStreamingBothWaysClient(const ClientConfig& config) + : SynchronousStreamingClient(config) { + for (size_t thread_idx = 0; thread_idx < num_threads_; thread_idx++) { + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_[thread_idx] = stub->StreamingBothWays(&context_[thread_idx]); + } + } + ~SynchronousStreamingBothWaysClient() { + std::vector<std::thread> cleanup_threads; + for (size_t i = 0; i < num_threads_; i++) { + cleanup_threads.emplace_back([this, i]() { + auto stream = &stream_[i]; + if (*stream) { + (*stream)->WritesDone(); + } + }); + } + for (auto& th : cleanup_threads) { + th.join(); + } + } + + bool 
ThreadFunc(HistogramEntry* entry, size_t thread_idx) override { + // TODO (vjpai): Do this + return true; + } +}; + +std::unique_ptr<Client> CreateSynchronousClient(const ClientConfig& config) { + switch (config.rpc_type()) { + case UNARY: + return std::unique_ptr<Client>(new SynchronousUnaryClient(config)); + case STREAMING: + return std::unique_ptr<Client>( + new SynchronousStreamingPingPongClient(config)); + case STREAMING_FROM_CLIENT: + return std::unique_ptr<Client>( + new SynchronousStreamingFromClientClient(config)); + case STREAMING_FROM_SERVER: + return std::unique_ptr<Client>( + new SynchronousStreamingFromServerClient(config)); + case STREAMING_BOTH_WAYS: + return std::unique_ptr<Client>( + new SynchronousStreamingBothWaysClient(config)); + default: + assert(false); + return nullptr; + } } } // namespace testing diff --git a/test/cpp/qps/qps_worker.cc b/test/cpp/qps/qps_worker.cc index d437920e68..92408974bd 100644 --- a/test/cpp/qps/qps_worker.cc +++ b/test/cpp/qps/qps_worker.cc @@ -68,15 +68,11 @@ static std::unique_ptr<Client> CreateClient(const ClientConfig& config) { switch (config.client_type()) { case ClientType::SYNC_CLIENT: - return (config.rpc_type() == RpcType::UNARY) - ? CreateSynchronousUnaryClient(config) - : CreateSynchronousStreamingClient(config); + return CreateSynchronousClient(config); case ClientType::ASYNC_CLIENT: - return (config.rpc_type() == RpcType::UNARY) - ? CreateAsyncUnaryClient(config) - : (config.payload_config().has_bytebuf_params() - ? CreateGenericAsyncStreamingClient(config) - : CreateAsyncStreamingClient(config)); + return config.payload_config().has_bytebuf_params() + ? CreateGenericAsyncStreamingClient(config) + : CreateAsyncClient(config); default: abort(); } diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index b499b82091..84f1579c2f 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -71,6 +71,18 @@ class AsyncQpsServerTest final : public grpc::testing::Server { ServerAsyncReaderWriter<ResponseType, RequestType> *, CompletionQueue *, ServerCompletionQueue *, void *)> request_streaming_function, + std::function<void(ServiceType *, ServerContextType *, + ServerAsyncReader<ResponseType, RequestType> *, + CompletionQueue *, ServerCompletionQueue *, void *)> + request_streaming_from_client_function, + std::function<void(ServiceType *, ServerContextType *, RequestType *, + ServerAsyncWriter<ResponseType> *, CompletionQueue *, + ServerCompletionQueue *, void *)> + request_streaming_from_server_function, + std::function<void(ServiceType *, ServerContextType *, + ServerAsyncReaderWriter<ResponseType, RequestType> *, + CompletionQueue *, ServerCompletionQueue *, void *)> + request_streaming_both_ways_function, std::function<grpc::Status(const PayloadConfig &, const RequestType *, ResponseType *)> process_rpc) @@ -107,7 +119,7 @@ class AsyncQpsServerTest final : public grpc::testing::Server { std::bind(process_rpc, config.payload_config(), std::placeholders::_1, std::placeholders::_2); - for (int i = 0; i < 15000; i++) { + for (int i = 0; i < 5000; i++) { for (int j = 0; j < num_threads; j++) { if (request_unary_function) { auto request_unary = std::bind( @@ -125,6 +137,26 @@ class AsyncQpsServerTest final : public grpc::testing::Server { contexts_.emplace_back(new ServerRpcContextStreamingImpl( request_streaming, process_rpc_bound)); } + if (request_streaming_from_client_function) { + auto request_streaming_from_client = std::bind( + request_streaming_from_client_function, &async_service_, + 
std::placeholders::_1, std::placeholders::_2, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_3); + contexts_.emplace_back(new ServerRpcContextStreamingFromClientImpl( + request_streaming_from_client, process_rpc_bound)); + } + if (request_streaming_from_server_function) { + auto request_streaming_from_server = + std::bind(request_streaming_from_server_function, &async_service_, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, srv_cqs_[j].get(), + srv_cqs_[j].get(), std::placeholders::_4); + contexts_.emplace_back(new ServerRpcContextStreamingFromServerImpl( + request_streaming_from_server, process_rpc_bound)); + } + if (request_streaming_both_ways_function) { + // TODO(vjpai): Add this code + } } } @@ -289,8 +321,8 @@ class AsyncQpsServerTest final : public grpc::testing::Server { if (!ok) { return false; } - stream_.Read(&req_, AsyncQpsServerTest::tag(this)); next_state_ = &ServerRpcContextStreamingImpl::read_done; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); return true; } @@ -300,23 +332,23 @@ class AsyncQpsServerTest final : public grpc::testing::Server { // Call the RPC processing function grpc::Status status = invoke_method_(&req_, &response_); // initiate the write - stream_.Write(response_, AsyncQpsServerTest::tag(this)); next_state_ = &ServerRpcContextStreamingImpl::write_done; + stream_.Write(response_, AsyncQpsServerTest::tag(this)); } else { // client has sent writes done // finish the stream - stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); next_state_ = &ServerRpcContextStreamingImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); } return true; } bool write_done(bool ok) { // now go back and get another streaming read! if (ok) { - stream_.Read(&req_, AsyncQpsServerTest::tag(this)); next_state_ = &ServerRpcContextStreamingImpl::read_done; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); } else { - stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); next_state_ = &ServerRpcContextStreamingImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); } return true; } @@ -335,6 +367,146 @@ class AsyncQpsServerTest final : public grpc::testing::Server { grpc::ServerAsyncReaderWriter<ResponseType, RequestType> stream_; }; + class ServerRpcContextStreamingFromClientImpl final + : public ServerRpcContext { + public: + ServerRpcContextStreamingFromClientImpl( + std::function<void(ServerContextType *, + grpc::ServerAsyncReader<ResponseType, RequestType> *, + void *)> + request_method, + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextStreamingFromClientImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingFromClientImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + stream_ = + grpc::ServerAsyncReader<ResponseType, RequestType>(srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextStreamingFromClientImpl::request_done; + request_method_(srv_ctx_.get(), &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) { + return false; + } + next_state_ = &ServerRpcContextStreamingFromClientImpl::read_done; + 
stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + return true; + } + + bool read_done(bool ok) { + if (ok) { + // In this case, just do another read + // next_state_ is unchanged + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + return true; + } else { // client has sent writes done + // invoke the method + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + // finish the stream + next_state_ = &ServerRpcContextStreamingFromClientImpl::finish_done; + stream_.Finish(response_, Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool finish_done(bool ok) { return false; /* reset the context */ } + + std::unique_ptr<ServerContextType> srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextStreamingFromClientImpl::*next_state_)(bool); + std::function<void(ServerContextType *, + grpc::ServerAsyncReader<ResponseType, RequestType> *, + void *)> + request_method_; + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method_; + grpc::ServerAsyncReader<ResponseType, RequestType> stream_; + }; + + class ServerRpcContextStreamingFromServerImpl final + : public ServerRpcContext { + public: + ServerRpcContextStreamingFromServerImpl( + std::function<void(ServerContextType *, RequestType *, + grpc::ServerAsyncWriter<ResponseType> *, void *)> + request_method, + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method) + : srv_ctx_(new ServerContextType), + next_state_(&ServerRpcContextStreamingFromServerImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(srv_ctx_.get()) { + request_method_(srv_ctx_.get(), &req_, &stream_, + AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingFromServerImpl() override {} + bool RunNextState(bool ok) override { return (this->*next_state_)(ok); } + void Reset() override { + srv_ctx_.reset(new ServerContextType); + req_ = RequestType(); + stream_ = grpc::ServerAsyncWriter<ResponseType>(srv_ctx_.get()); + + // Then request the method + next_state_ = &ServerRpcContextStreamingFromServerImpl::request_done; + request_method_(srv_ctx_.get(), &req_, &stream_, + AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) { + return false; + } + // invoke the method + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response_); + + next_state_ = &ServerRpcContextStreamingFromServerImpl::write_done; + stream_.Write(response_, AsyncQpsServerTest::tag(this)); + return true; + } + + bool write_done(bool ok) { + if (ok) { + // Do another write! 
+ // next_state_ is unchanged + stream_.Write(response_, AsyncQpsServerTest::tag(this)); + } else { // must be done so let's finish + next_state_ = &ServerRpcContextStreamingFromServerImpl::finish_done; + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + } + return true; + } + bool finish_done(bool ok) { return false; /* reset the context */ } + + std::unique_ptr<ServerContextType> srv_ctx_; + RequestType req_; + ResponseType response_; + bool (ServerRpcContextStreamingFromServerImpl::*next_state_)(bool); + std::function<void(ServerContextType *, RequestType *, + grpc::ServerAsyncWriter<ResponseType> *, void *)> + request_method_; + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method_; + grpc::ServerAsyncWriter<ResponseType> stream_; + }; + std::vector<std::thread> threads_; std::unique_ptr<grpc::Server> server_; std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> srv_cqs_; @@ -390,6 +562,9 @@ std::unique_ptr<Server> CreateAsyncServer(const ServerConfig &config) { config, RegisterBenchmarkService, &BenchmarkService::AsyncService::RequestUnaryCall, &BenchmarkService::AsyncService::RequestStreamingCall, + &BenchmarkService::AsyncService::RequestStreamingFromClient, + &BenchmarkService::AsyncService::RequestStreamingFromServer, + &BenchmarkService::AsyncService::RequestStreamingBothWays, ProcessSimpleRPC)); } std::unique_ptr<Server> CreateAsyncGenericServer(const ServerConfig &config) { @@ -397,7 +572,8 @@ std::unique_ptr<Server> CreateAsyncGenericServer(const ServerConfig &config) { new AsyncQpsServerTest<ByteBuffer, ByteBuffer, grpc::AsyncGenericService, grpc::GenericServerContext>( config, RegisterGenericService, nullptr, - &grpc::AsyncGenericService::RequestCall, ProcessGenericRPC)); + &grpc::AsyncGenericService::RequestCall, nullptr, nullptr, nullptr, + ProcessGenericRPC)); } } // namespace testing diff --git a/test/cpp/qps/server_sync.cc b/test/cpp/qps/server_sync.cc index f79284d225..f04465e261 100644 --- a/test/cpp/qps/server_sync.cc +++ b/test/cpp/qps/server_sync.cc @@ -31,6 +31,9 @@ * */ +#include <atomic> +#include <thread> + #include <grpc++/resource_quota.h> #include <grpc++/security/server_credentials.h> #include <grpc++/server.h> @@ -52,12 +55,9 @@ class BenchmarkServiceImpl final : public BenchmarkService::Service { public: Status UnaryCall(ServerContext* context, const SimpleRequest* request, SimpleResponse* response) override { - if (request->response_size() > 0) { - if (!Server::SetPayload(request->response_type(), - request->response_size(), - response->mutable_payload())) { - return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); - } + auto s = SetResponse(request, response); + if (!s.ok()) { + return s; } return Status::OK; } @@ -67,12 +67,9 @@ class BenchmarkServiceImpl final : public BenchmarkService::Service { SimpleRequest request; while (stream->Read(&request)) { SimpleResponse response; - if (request.response_size() > 0) { - if (!Server::SetPayload(request.response_type(), - request.response_size(), - response.mutable_payload())) { - return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); - } + auto s = SetResponse(&request, &response); + if (!s.ok()) { + return s; } if (!stream->Write(response)) { return Status(StatusCode::INTERNAL, "Server couldn't respond"); @@ -80,6 +77,96 @@ class BenchmarkServiceImpl final : public BenchmarkService::Service { } return Status::OK; } + Status StreamingFromClient(ServerContext* context, + ServerReader<SimpleRequest>* stream, + SimpleResponse* response) 
override { + auto s = ClientPull(context, stream, response); + if (!s.ok()) { + return s; + } + return Status::OK; + } + Status StreamingFromServer(ServerContext* context, + const SimpleRequest* request, + ServerWriter<SimpleResponse>* stream) override { + SimpleResponse response; + auto s = SetResponse(request, &response); + if (!s.ok()) { + return s; + } + return ServerPush(context, stream, response, nullptr); + } + Status StreamingBothWays( + ServerContext* context, + ServerReaderWriter<SimpleResponse, SimpleRequest>* stream) override { + // Read the first client message to setup server response + SimpleRequest request; + if (!stream->Read(&request)) { + return Status::OK; + } + SimpleResponse response; + auto s = SetResponse(&request, &response); + if (!s.ok()) { + return s; + } + std::atomic_bool done; + Status sp; + std::thread t([context, stream, &response, &done, &sp]() { + sp = ServerPush(context, stream, response, [&done]() { + return done.load(std::memory_order_relaxed); + }); + }); + SimpleResponse dummy; + auto cp = ClientPull(context, stream, &dummy); + done.store(true, std::memory_order_relaxed); // can be lazy + t.join(); + if (!cp.ok()) { + return cp; + } + if (!sp.ok()) { + return sp; + } + return Status::OK; + } + + private: + static Status ClientPull(ServerContext* context, + ReaderInterface<SimpleRequest>* stream, + SimpleResponse* response) { + SimpleRequest request; + while (stream->Read(&request)) { + } + if (request.response_size() > 0) { + if (!Server::SetPayload(request.response_type(), request.response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; + } + static Status ServerPush(ServerContext* context, + WriterInterface<SimpleResponse>* stream, + const SimpleResponse& response, + std::function<bool()> done) { + while ((done == nullptr) || !done()) { + // TODO(vjpai): Add potential for rate-pacing on this + if (!stream->Write(response)) { + return Status(StatusCode::INTERNAL, "Server couldn't push"); + } + } + return Status::OK; + } + static Status SetResponse(const SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { + if (!Server::SetPayload(request->response_type(), + request->response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; + } }; class SynchronousServer final : public grpc::testing::Server { |