diff options
author | vjpai <vpai@google.com> | 2015-03-23 10:10:27 -0700 |
---|---|---|
committer | vjpai <vpai@google.com> | 2015-03-23 10:10:27 -0700 |
commit | 46f65239cf5bb55c106558196b9f293188746ef7 (patch) | |
tree | db6a554d6358196eb54a43a187f316c84b15a978 | |
parent | 99f1c1ee7efcceae76bf3dce17417ce5dc0ca0d7 (diff) |
Added streaming C++ tests for sync and sync cases
-rw-r--r-- | test/cpp/qps/client.h | 16 | ||||
-rw-r--r-- | test/cpp/qps/client_async.cc | 173 | ||||
-rw-r--r-- | test/cpp/qps/client_sync.cc | 70 | ||||
-rw-r--r-- | test/cpp/qps/qpstest.proto | 78 | ||||
-rw-r--r-- | test/cpp/qps/server.cc | 2 | ||||
-rw-r--r-- | test/cpp/qps/server_async.cc | 126 | ||||
-rw-r--r-- | test/cpp/qps/server_sync.cc | 19 | ||||
-rw-r--r-- | test/cpp/qps/worker.cc | 7 |
8 files changed, 369 insertions, 122 deletions
diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h index 221fb30fc5..cae7f44537 100644 --- a/test/cpp/qps/client.h +++ b/test/cpp/qps/client.h @@ -115,12 +115,12 @@ class Client { impl_([this, idx, client]() { for (;;) { // run the loop body - client->ThreadFunc(&histogram_, idx); + client->ThreadFunc(&histogram_, idx); // lock, see if we're done std::lock_guard<std::mutex> g(mu_); - if (done_) return; - // also check if we're marking, and swap out the histogram if so - if (new_) { + if (done_) {return;} + // check if we're marking, swap out the histogram if so + if (new_) { new_->Swap(&histogram_); new_ = nullptr; cv_.notify_one(); @@ -164,8 +164,12 @@ class Client { std::unique_ptr<Timer> timer_; }; -std::unique_ptr<Client> CreateSynchronousClient(const ClientConfig& args); -std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& args); +std::unique_ptr<Client> + CreateSynchronousUnaryClient(const ClientConfig& args); +std::unique_ptr<Client> + CreateSynchronousStreamingClient(const ClientConfig& args); +std::unique_ptr<Client> CreateAsyncUnaryClient(const ClientConfig& args); +std::unique_ptr<Client> CreateAsyncStreamingClient(const ClientConfig& args); } // namespace testing } // namespace grpc diff --git a/test/cpp/qps/client_async.cc b/test/cpp/qps/client_async.cc index c6535bebf8..30317d11e1 100644 --- a/test/cpp/qps/client_async.cc +++ b/test/cpp/qps/client_async.cc @@ -46,6 +46,7 @@ #include <grpc++/async_unary_call.h> #include <grpc++/client_context.h> #include <grpc++/status.h> +#include <grpc++/stream.h> #include "test/core/util/grpc_profiler.h" #include "test/cpp/util/create_test_channel.h" #include "test/cpp/qps/qpstest.pb.h" @@ -59,13 +60,13 @@ class ClientRpcContext { public: ClientRpcContext() {} virtual ~ClientRpcContext() {} - virtual bool RunNextState() = 0; // do next state, return false if steps done + // next state, return false if done. Collect stats when appropriate + virtual bool RunNextState(bool, Histogram *hist) = 0; virtual void StartNewClone() = 0; static void *tag(ClientRpcContext *c) { return reinterpret_cast<void *>(c); } static ClientRpcContext *detag(void *t) { return reinterpret_cast<ClientRpcContext *>(t); } - virtual void report_stats(Histogram *hist) = 0; }; template <class RequestType, class ResponseType> @@ -89,9 +90,12 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { response_reader_( start_req(stub_, &context_, req_, ClientRpcContext::tag(this))) {} ~ClientRpcContextUnaryImpl() GRPC_OVERRIDE {} - bool RunNextState() GRPC_OVERRIDE { return (this->*next_state_)(); } - void report_stats(Histogram *hist) GRPC_OVERRIDE { - hist->Add((Timer::Now() - start_) * 1e9); + bool RunNextState(bool ok, Histogram *hist) GRPC_OVERRIDE { + bool ret = (this->*next_state_)(ok); + if (!ret) { + hist->Add((Timer::Now() - start_) * 1e9); + } + return ret; } void StartNewClone() GRPC_OVERRIDE { @@ -99,16 +103,16 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { } private: - bool ReqSent() { + bool ReqSent(bool) { next_state_ = &ClientRpcContextUnaryImpl::RespDone; response_reader_->Finish(&response_, &status_, ClientRpcContext::tag(this)); return true; } - bool RespDone() { + bool RespDone(bool) { next_state_ = &ClientRpcContextUnaryImpl::DoCallBack; return false; } - bool DoCallBack() { + bool DoCallBack(bool) { callback_(status_, &response_); return false; } @@ -116,7 +120,7 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { TestService::Stub *stub_; RequestType req_; ResponseType response_; - bool (ClientRpcContextUnaryImpl::*next_state_)(); + bool (ClientRpcContextUnaryImpl::*next_state_)(bool); std::function<void(grpc::Status, ResponseType *)> callback_; std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>( TestService::Stub *, grpc::ClientContext *, const RequestType &, void *)> @@ -127,9 +131,9 @@ class ClientRpcContextUnaryImpl : public ClientRpcContext { response_reader_; }; -class AsyncClient GRPC_FINAL : public Client { +class AsyncUnaryClient GRPC_FINAL : public Client { public: - explicit AsyncClient(const ClientConfig &config) : Client(config) { + explicit AsyncUnaryClient(const ClientConfig &config) : Client(config) { for (int i = 0; i < config.async_client_threads(); i++) { cli_cqs_.emplace_back(new CompletionQueue); } @@ -162,7 +166,7 @@ class AsyncClient GRPC_FINAL : public Client { StartThreads(config.async_client_threads()); } - ~AsyncClient() GRPC_OVERRIDE { + ~AsyncUnaryClient() GRPC_OVERRIDE { EndThreads(); for (auto &cq : cli_cqs_) { @@ -181,10 +185,9 @@ class AsyncClient GRPC_FINAL : public Client { cli_cqs_[thread_idx]->Next(&got_tag, &ok); ClientRpcContext *ctx = ClientRpcContext::detag(got_tag); - if (ctx->RunNextState() == false) { + if (ctx->RunNextState(ok, histogram) == false) { // call the callback and then delete it - ctx->report_stats(histogram); - ctx->RunNextState(); + ctx->RunNextState(ok, histogram); ctx->StartNewClone(); delete ctx; } @@ -193,8 +196,144 @@ class AsyncClient GRPC_FINAL : public Client { std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_; }; -std::unique_ptr<Client> CreateAsyncClient(const ClientConfig &args) { - return std::unique_ptr<Client>(new AsyncClient(args)); +template <class RequestType, class ResponseType> +class ClientRpcContextStreamingImpl : public ClientRpcContext { + public: + ClientRpcContextStreamingImpl( + TestService::Stub *stub, const RequestType &req, + std::function< + std::unique_ptr<grpc::ClientAsyncReaderWriter< + RequestType,ResponseType>>( + TestService::Stub *, grpc::ClientContext *, void *)> start_req, + std::function<void(grpc::Status, ResponseType *)> on_done) + : context_(), + stub_(stub), + req_(req), + response_(), + next_state_(&ClientRpcContextStreamingImpl::ReqSent), + callback_(on_done), + start_req_(start_req), + start_(Timer::Now()), + stream_(start_req_(stub_, &context_, ClientRpcContext::tag(this))) {} + ~ClientRpcContextStreamingImpl() GRPC_OVERRIDE {} + bool RunNextState(bool ok, Histogram *hist) GRPC_OVERRIDE { + return (this->*next_state_)(ok, hist); + } + void StartNewClone() GRPC_OVERRIDE { + new ClientRpcContextStreamingImpl(stub_, req_, start_req_, callback_); + } + + private: + bool ReqSent(bool ok, Histogram *) { + return StartWrite(ok); + } + bool StartWrite(bool ok) { + if (!ok) { + return(false); + } + start_ = Timer::Now(); + next_state_ = &ClientRpcContextStreamingImpl::WriteDone; + stream_->Write(req_, ClientRpcContext::tag(this)); + return true; + } + bool WriteDone(bool ok, Histogram *) { + if (!ok) { + return(false); + } + next_state_ = &ClientRpcContextStreamingImpl::ReadDone; + stream_->Read(&response_, ClientRpcContext::tag(this)); + return true; + } + bool ReadDone(bool ok, Histogram *hist) { + hist->Add((Timer::Now() - start_) * 1e9); + return StartWrite(ok); + } + grpc::ClientContext context_; + TestService::Stub *stub_; + RequestType req_; + ResponseType response_; + bool (ClientRpcContextStreamingImpl::*next_state_)(bool, Histogram *); + std::function<void(grpc::Status, ResponseType *)> callback_; + std::function<std::unique_ptr<grpc::ClientAsyncReaderWriter< + RequestType,ResponseType>>( + TestService::Stub *, grpc::ClientContext *, void *)> start_req_; + grpc::Status status_; + double start_; + std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType,ResponseType>> + stream_; +}; + +class AsyncStreamingClient GRPC_FINAL : public Client { + public: + explicit AsyncStreamingClient(const ClientConfig &config) : Client(config) { + for (int i = 0; i < config.async_client_threads(); i++) { + cli_cqs_.emplace_back(new CompletionQueue); + } + + auto payload_size = config.payload_size(); + auto check_done = [payload_size](grpc::Status s, SimpleResponse *response) { + GPR_ASSERT(s.IsOk() && (response->payload().type() == + grpc::testing::PayloadType::COMPRESSABLE) && + (response->payload().body().length() == + static_cast<size_t>(payload_size))); + }; + + int t = 0; + for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) { + for (auto &channel : channels_) { + auto *cq = cli_cqs_[t].get(); + t = (t + 1) % cli_cqs_.size(); + auto start_req = [cq](TestService::Stub *stub, grpc::ClientContext *ctx, + void *tag) { + auto stream = stub->AsyncStreamingCall(ctx, cq, tag); + return stream; + }; + + TestService::Stub *stub = channel.get_stub(); + const SimpleRequest &request = request_; + new ClientRpcContextStreamingImpl<SimpleRequest, SimpleResponse>( + stub, request, start_req, check_done); + } + } + + StartThreads(config.async_client_threads()); + } + + ~AsyncStreamingClient() GRPC_OVERRIDE { + EndThreads(); + + for (auto &cq : cli_cqs_) { + cq->Shutdown(); + void *got_tag; + bool ok; + while (cq->Next(&got_tag, &ok)) { + delete ClientRpcContext::detag(got_tag); + } + } + } + + void ThreadFunc(Histogram *histogram, size_t thread_idx) GRPC_OVERRIDE { + void *got_tag; + bool ok; + cli_cqs_[thread_idx]->Next(&got_tag, &ok); + + ClientRpcContext *ctx = ClientRpcContext::detag(got_tag); + if (ctx->RunNextState(ok, histogram) == false) { + // call the callback and then delete it + ctx->RunNextState(ok, histogram); + ctx->StartNewClone(); + delete ctx; + } + } + + std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_; +}; + +std::unique_ptr<Client> CreateAsyncUnaryClient(const ClientConfig &args) { + return std::unique_ptr<Client>(new AsyncUnaryClient(args)); +} +std::unique_ptr<Client> CreateAsyncStreamingClient(const ClientConfig &args) { + return std::unique_ptr<Client>(new AsyncStreamingClient(args)); } } // namespace testing diff --git a/test/cpp/qps/client_sync.cc b/test/cpp/qps/client_sync.cc index 7bb7231c6f..1e14aa85c5 100644 --- a/test/cpp/qps/client_sync.cc +++ b/test/cpp/qps/client_sync.cc @@ -48,9 +48,11 @@ #include <grpc/support/host_port.h> #include <gflags/gflags.h> #include <grpc++/client_context.h> -#include <grpc++/status.h> #include <grpc++/server.h> #include <grpc++/server_builder.h> +#include <grpc++/status.h> +#include <grpc++/stream.h> +#include <gtest/gtest.h> #include "test/core/util/grpc_profiler.h" #include "test/cpp/util/create_test_channel.h" #include "test/cpp/qps/client.h" @@ -61,18 +63,28 @@ namespace grpc { namespace testing { -class SynchronousClient GRPC_FINAL : public Client { +class SynchronousClient : public Client { public: SynchronousClient(const ClientConfig& config) : Client(config) { - size_t num_threads = - config.outstanding_rpcs_per_channel() * config.client_channels(); - responses_.resize(num_threads); - StartThreads(num_threads); + num_threads_ = + config.outstanding_rpcs_per_channel() * config.client_channels(); + responses_.resize(num_threads_); } - ~SynchronousClient() { EndThreads(); } + virtual ~SynchronousClient() { EndThreads(); } + + protected: + size_t num_threads_; + std::vector<SimpleResponse> responses_; +}; - void ThreadFunc(Histogram* histogram, size_t thread_idx) { +class SynchronousUnaryClient GRPC_FINAL : public SynchronousClient { + public: + SynchronousUnaryClient(const ClientConfig& config): + SynchronousClient(config) {StartThreads(num_threads_);} + ~SynchronousUnaryClient() {} + + void ThreadFunc(Histogram* histogram, size_t thread_idx) GRPC_OVERRIDE { auto* stub = channels_[thread_idx % channels_.size()].get_stub(); double start = Timer::Now(); grpc::ClientContext context; @@ -80,13 +92,47 @@ class SynchronousClient GRPC_FINAL : public Client { stub->UnaryCall(&context, request_, &responses_[thread_idx]); histogram->Add((Timer::Now() - start) * 1e9); } +}; - private: - std::vector<SimpleResponse> responses_; +class SynchronousStreamingClient GRPC_FINAL : public SynchronousClient { + public: + SynchronousStreamingClient(const ClientConfig& config): + SynchronousClient(config) { + for (size_t thread_idx=0;thread_idx<num_threads_;thread_idx++){ + auto* stub = channels_[thread_idx % channels_.size()].get_stub(); + stream_ = stub->StreamingCall(&context_); + } + StartThreads(num_threads_); + } + ~SynchronousStreamingClient() { + if (stream_) { + SimpleResponse response; + stream_->WritesDone(); + EXPECT_FALSE(stream_->Read(&response)); + + Status s = stream_->Finish(); + EXPECT_TRUE(s.IsOk()); + } + } + + void ThreadFunc(Histogram* histogram, size_t thread_idx) GRPC_OVERRIDE { + double start = Timer::Now(); + EXPECT_TRUE(stream_->Write(request_)); + EXPECT_TRUE(stream_->Read(&responses_[thread_idx])); + histogram->Add((Timer::Now() - start) * 1e9); + } + private: + grpc::ClientContext context_; + std::unique_ptr<grpc::ClientReaderWriter<SimpleRequest,SimpleResponse>> stream_; }; -std::unique_ptr<Client> CreateSynchronousClient(const ClientConfig& config) { - return std::unique_ptr<Client>(new SynchronousClient(config)); +std::unique_ptr<Client> +CreateSynchronousUnaryClient(const ClientConfig& config) { + return std::unique_ptr<Client>(new SynchronousUnaryClient(config)); +} +std::unique_ptr<Client> +CreateSynchronousStreamingClient(const ClientConfig& config) { + return std::unique_ptr<Client>(new SynchronousStreamingClient(config)); } } // namespace testing diff --git a/test/cpp/qps/qpstest.proto b/test/cpp/qps/qpstest.proto index 6a7170bf58..70cc926f16 100644 --- a/test/cpp/qps/qpstest.proto +++ b/test/cpp/qps/qpstest.proto @@ -87,15 +87,21 @@ enum ServerType { ASYNC_SERVER = 2; } +enum TestType { + UNARY_TEST = 1; + STREAMING_TEST = 2; +} + message ClientConfig { repeated string server_targets = 1; required ClientType client_type = 2; - required bool enable_ssl = 3; + optional bool enable_ssl = 3 [default=false]; required int32 outstanding_rpcs_per_channel = 4; required int32 client_channels = 5; required int32 payload_size = 6; // only for async client: optional int32 async_client_threads = 7; + optional TestType test_type = 8 [default=UNARY_TEST]; } // Request current stats @@ -121,8 +127,8 @@ message ClientStatus { message ServerConfig { required ServerType server_type = 1; - required int32 threads = 2; - required bool enable_ssl = 3; + optional int32 threads = 2 [default=1]; + optional bool enable_ssl = 3 [default=false]; } message ServerArgs { @@ -144,7 +150,7 @@ message SimpleRequest { // Desired payload size in the response from the server. // If response_type is COMPRESSABLE, this denotes the size before compression. - optional int32 response_size = 2; + optional int32 response_size = 2 [default=0]; // Optional input payload sent along with the request. optional Payload payload = 3; @@ -154,72 +160,14 @@ message SimpleResponse { optional Payload payload = 1; } -message StreamingInputCallRequest { - // Optional input payload sent along with the request. - optional Payload payload = 1; - - // Not expecting any payload from the response. -} - -message StreamingInputCallResponse { - // Aggregated size of payloads received from the client. - optional int32 aggregated_payload_size = 1; -} - -message ResponseParameters { - // Desired payload sizes in responses from the server. - // If response_type is COMPRESSABLE, this denotes the size before compression. - required int32 size = 1; - - // Desired interval between consecutive responses in the response stream in - // microseconds. - required int32 interval_us = 2; -} - -message StreamingOutputCallRequest { - // Desired payload type in the response from the server. - // If response_type is RANDOM, the payload from each response in the stream - // might be of different types. This is to simulate a mixed type of payload - // stream. - optional PayloadType response_type = 1 [default=COMPRESSABLE]; - - repeated ResponseParameters response_parameters = 2; - - // Optional input payload sent along with the request. - optional Payload payload = 3; -} - -message StreamingOutputCallResponse { - optional Payload payload = 1; -} - service TestService { // One request followed by one response. // The server returns the client payload as-is. rpc UnaryCall(SimpleRequest) returns (SimpleResponse); - // One request followed by a sequence of responses (streamed download). - // The server returns the payload with client desired type and sizes. - rpc StreamingOutputCall(StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); - - // A sequence of requests followed by one response (streamed upload). - // The server returns the aggregated size of client payload as the result. - rpc StreamingInputCall(stream StreamingInputCallRequest) - returns (StreamingInputCallResponse); - - // A sequence of requests with each request served by the server immediately. - // As one request could lead to multiple responses, this interface - // demonstrates the idea of full duplexing. - rpc FullDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); - - // A sequence of requests followed by a sequence of responses. - // The server buffers all the client requests and then serves them in order. A - // stream of responses are returned to the client when the server starts with - // first request. - rpc HalfDuplexCall(stream StreamingOutputCallRequest) - returns (stream StreamingOutputCallResponse); + // One request followed by one response. + // The server returns the client payload as-is. + rpc StreamingCall(stream SimpleRequest) returns (stream SimpleResponse); } service Worker { diff --git a/test/cpp/qps/server.cc b/test/cpp/qps/server.cc index 005f0f9c5e..c6c6c9543d 100644 --- a/test/cpp/qps/server.cc +++ b/test/cpp/qps/server.cc @@ -115,7 +115,7 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { } Status UnaryCall(ServerContext* context, const SimpleRequest* request, SimpleResponse* response) { - if (request->has_response_size() && request->response_size() > 0) { + if (request->response_size() > 0) { if (!SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index 19778e5a7c..4312f597b2 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -48,6 +48,7 @@ #include <grpc++/server_context.h> #include <grpc++/server_credentials.h> #include <grpc++/status.h> +#include <grpc++/stream.h> #include <gtest/gtest.h> #include "src/cpp/server/thread_pool.h" #include "test/core/util/grpc_profiler.h" @@ -78,10 +79,16 @@ class AsyncQpsServerTest : public Server { using namespace std::placeholders; request_unary_ = std::bind(&TestService::AsyncService::RequestUnaryCall, &async_service_, _1, _2, _3, &srv_cq_, _4); + request_streaming_ = + std::bind(&TestService::AsyncService::RequestStreamingCall, + &async_service_, _1, _2, &srv_cq_, _3); for (int i = 0; i < 100; i++) { contexts_.push_front( new ServerRpcContextUnaryImpl<SimpleRequest, SimpleResponse>( - request_unary_, UnaryCall)); + request_unary_, ProcessRPC)); + contexts_.push_front( + new ServerRpcContextStreamingImpl<SimpleRequest, SimpleResponse>( + request_streaming_, ProcessRPC)); } for (int i = 0; i < config.threads(); i++) { threads_.push_back(std::thread([=]() { @@ -89,14 +96,12 @@ class AsyncQpsServerTest : public Server { bool ok; void *got_tag; while (srv_cq_.Next(&got_tag, &ok)) { - if (ok) { - ServerRpcContext *ctx = detag(got_tag); - // The tag is a pointer to an RPC context to invoke - if (ctx->RunNextState() == false) { - // this RPC context is done, so refresh it - ctx->Reset(); - } - } + ServerRpcContext *ctx = detag(got_tag); + // The tag is a pointer to an RPC context to invoke + if (ctx->RunNextState(ok) == false) { + // this RPC context is done, so refresh it + ctx->Reset(); + } } return; })); @@ -119,7 +124,7 @@ class AsyncQpsServerTest : public Server { public: ServerRpcContext() {} virtual ~ServerRpcContext(){}; - virtual bool RunNextState() = 0; // do next state, return false if all done + virtual bool RunNextState(bool) = 0; // next state, return false if done virtual void Reset() = 0; // start this back at a clean state }; static void *tag(ServerRpcContext *func) { @@ -130,7 +135,7 @@ class AsyncQpsServerTest : public Server { } template <class RequestType, class ResponseType> - class ServerRpcContextUnaryImpl : public ServerRpcContext { + class ServerRpcContextUnaryImpl GRPC_FINAL : public ServerRpcContext { public: ServerRpcContextUnaryImpl( std::function<void(ServerContext *, RequestType *, @@ -146,7 +151,7 @@ class AsyncQpsServerTest : public Server { AsyncQpsServerTest::tag(this)); } ~ServerRpcContextUnaryImpl() GRPC_OVERRIDE {} - bool RunNextState() GRPC_OVERRIDE { return (this->*next_state_)(); } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} void Reset() GRPC_OVERRIDE { srv_ctx_ = ServerContext(); req_ = RequestType(); @@ -160,8 +165,11 @@ class AsyncQpsServerTest : public Server { } private: - bool finisher() { return false; } - bool invoker() { + bool finisher(bool) { return false; } + bool invoker(bool ok) { + if (!ok) + return false; + ResponseType response; // Call the RPC processing function @@ -174,7 +182,7 @@ class AsyncQpsServerTest : public Server { } ServerContext srv_ctx_; RequestType req_; - bool (ServerRpcContextUnaryImpl::*next_state_)(); + bool (ServerRpcContextUnaryImpl::*next_state_)(bool); std::function<void(ServerContext *, RequestType *, grpc::ServerAsyncResponseWriter<ResponseType> *, void *)> request_method_; @@ -183,9 +191,88 @@ class AsyncQpsServerTest : public Server { grpc::ServerAsyncResponseWriter<ResponseType> response_writer_; }; - static Status UnaryCall(const SimpleRequest *request, - SimpleResponse *response) { - if (request->has_response_size() && request->response_size() > 0) { + template <class RequestType, class ResponseType> + class ServerRpcContextStreamingImpl GRPC_FINAL : public ServerRpcContext { + public: + ServerRpcContextStreamingImpl( + std::function<void(ServerContext *, + grpc::ServerAsyncReaderWriter<ResponseType, + RequestType> *, void *)> request_method, + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method) + : next_state_(&ServerRpcContextStreamingImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(&srv_ctx_) { + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingImpl() GRPC_OVERRIDE { + } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} + void Reset() GRPC_OVERRIDE { + srv_ctx_ = ServerContext(); + req_ = RequestType(); + stream_ = grpc::ServerAsyncReaderWriter<ResponseType, + RequestType>(&srv_ctx_); + + // Then request the method + next_state_ = &ServerRpcContextStreamingImpl::request_done; + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) + return false; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + return true; + } + + bool read_done(bool ok) { + if (ok) { + // invoke the method + ResponseType response; + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response); + // initiate the write + stream_.Write(response, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::write_done; + } else { // client has sent writes done + // finish the stream + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool write_done(bool ok) { + // now go back and get another streaming read! + if (ok) { + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + } + else { + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool finish_done(bool ok) {return false; /* reset the context */ } + + ServerContext srv_ctx_; + RequestType req_; + bool (ServerRpcContextStreamingImpl::*next_state_)(bool); + std::function<void(ServerContext *, + grpc::ServerAsyncReaderWriter<ResponseType, + RequestType> *, void *)> request_method_; + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method_; + grpc::ServerAsyncReaderWriter<ResponseType,RequestType> stream_; + }; + + static Status ProcessRPC(const SimpleRequest *request, + SimpleResponse *response) { + if (request->response_size() > 0) { if (!SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); @@ -200,6 +287,9 @@ class AsyncQpsServerTest : public Server { std::function<void(ServerContext *, SimpleRequest *, grpc::ServerAsyncResponseWriter<SimpleResponse> *, void *)> request_unary_; + std::function<void(ServerContext *, grpc::ServerAsyncReaderWriter< + SimpleResponse,SimpleRequest> *, void *)> + request_streaming_; std::forward_list<ServerRpcContext *> contexts_; }; diff --git a/test/cpp/qps/server_sync.cc b/test/cpp/qps/server_sync.cc index 5c6541989c..1dbdd64a34 100644 --- a/test/cpp/qps/server_sync.cc +++ b/test/cpp/qps/server_sync.cc @@ -62,7 +62,7 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { public: Status UnaryCall(ServerContext* context, const SimpleRequest* request, SimpleResponse* response) GRPC_OVERRIDE { - if (request->has_response_size() && request->response_size() > 0) { + if (request->response_size() > 0) { if (!Server::SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { @@ -71,6 +71,23 @@ class TestServiceImpl GRPC_FINAL : public TestService::Service { } return Status::OK; } + Status StreamingCall(ServerContext *context, + ServerReaderWriter<SimpleResponse, SimpleRequest>* + stream) GRPC_OVERRIDE { + SimpleRequest request; + while (stream->Read(&request)) { + SimpleResponse response; + if (request.response_size() > 0) { + if (!Server::SetPayload(request.response_type(), + request.response_size(), + response.mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + stream->Write(response); + } + return Status::OK; + } }; class SynchronousServer GRPC_FINAL : public grpc::testing::Server { diff --git a/test/cpp/qps/worker.cc b/test/cpp/qps/worker.cc index faabfd1147..4c8c7cfea9 100644 --- a/test/cpp/qps/worker.cc +++ b/test/cpp/qps/worker.cc @@ -77,9 +77,12 @@ namespace testing { std::unique_ptr<Client> CreateClient(const ClientConfig& config) { switch (config.client_type()) { case ClientType::SYNCHRONOUS_CLIENT: - return CreateSynchronousClient(config); + return (config.test_type() == TestType::UNARY_TEST) ? + CreateSynchronousUnaryClient(config) : + CreateSynchronousStreamingClient(config); case ClientType::ASYNC_CLIENT: - return CreateAsyncClient(config); + return (config.test_type() == TestType::UNARY_TEST) ? + CreateAsyncUnaryClient(config) : CreateAsyncStreamingClient(config); } abort(); } |