diff options
Diffstat (limited to 'test/cpp/qps/server_async.cc')
-rw-r--r-- | test/cpp/qps/server_async.cc | 138 |
1 files changed, 120 insertions, 18 deletions
diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index 7b81bd35a2..83bb08cd49 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -33,6 +33,7 @@ #include <forward_list> #include <functional> +#include <mutex> #include <sys/time.h> #include <sys/resource.h> #include <sys/signal.h> @@ -48,6 +49,7 @@ #include <grpc++/server_context.h> #include <grpc++/server_credentials.h> #include <grpc++/status.h> +#include <grpc++/stream.h> #include <gtest/gtest.h> #include "src/cpp/server/thread_pool.h" #include "test/core/util/grpc_profiler.h" @@ -63,7 +65,8 @@ namespace testing { class AsyncQpsServerTest : public Server { public: AsyncQpsServerTest(const ServerConfig& config, int port) - : srv_cq_(), async_service_(&srv_cq_), server_(nullptr) { + : srv_cq_(), async_service_(&srv_cq_), server_(nullptr), + shutdown_(false) { char* server_address = NULL; gpr_join_host_port(&server_address, "::", port); @@ -78,10 +81,16 @@ class AsyncQpsServerTest : public Server { using namespace std::placeholders; request_unary_ = std::bind(&TestService::AsyncService::RequestUnaryCall, &async_service_, _1, _2, _3, &srv_cq_, _4); + request_streaming_ = + std::bind(&TestService::AsyncService::RequestStreamingCall, + &async_service_, _1, _2, &srv_cq_, _3); for (int i = 0; i < 100; i++) { contexts_.push_front( new ServerRpcContextUnaryImpl<SimpleRequest, SimpleResponse>( - request_unary_, UnaryCall)); + request_unary_, ProcessRPC)); + contexts_.push_front( + new ServerRpcContextStreamingImpl<SimpleRequest, SimpleResponse>( + request_streaming_, ProcessRPC)); } for (int i = 0; i < config.threads(); i++) { threads_.push_back(std::thread([=]() { @@ -89,14 +98,15 @@ class AsyncQpsServerTest : public Server { bool ok; void* got_tag; while (srv_cq_.Next(&got_tag, &ok)) { - if (ok) { - ServerRpcContext* ctx = detag(got_tag); - // The tag is a pointer to an RPC context to invoke - if (ctx->RunNextState() == false) { - // this RPC context is done, so refresh it + ServerRpcContext* ctx = detag(got_tag); + // The tag is a pointer to an RPC context to invoke + if (ctx->RunNextState(ok) == false) { + // this RPC context is done, so refresh it + std::lock_guard<std::mutex> g(shutdown_mutex_); + if (!shutdown_) { ctx->Reset(); } - } + } } return; })); @@ -104,7 +114,11 @@ class AsyncQpsServerTest : public Server { } ~AsyncQpsServerTest() { server_->Shutdown(); - srv_cq_.Shutdown(); + { + std::lock_guard<std::mutex> g(shutdown_mutex_); + shutdown_ = true; + srv_cq_.Shutdown(); + } for (auto thr = threads_.begin(); thr != threads_.end(); thr++) { thr->join(); } @@ -119,7 +133,7 @@ class AsyncQpsServerTest : public Server { public: ServerRpcContext() {} virtual ~ServerRpcContext(){}; - virtual bool RunNextState() = 0; // do next state, return false if all done + virtual bool RunNextState(bool) = 0; // next state, return false if done virtual void Reset() = 0; // start this back at a clean state }; static void* tag(ServerRpcContext* func) { @@ -130,7 +144,7 @@ class AsyncQpsServerTest : public Server { } template <class RequestType, class ResponseType> - class ServerRpcContextUnaryImpl : public ServerRpcContext { + class ServerRpcContextUnaryImpl GRPC_FINAL : public ServerRpcContext { public: ServerRpcContextUnaryImpl( std::function<void(ServerContext*, RequestType*, @@ -146,7 +160,7 @@ class AsyncQpsServerTest : public Server { AsyncQpsServerTest::tag(this)); } ~ServerRpcContextUnaryImpl() GRPC_OVERRIDE {} - bool RunNextState() GRPC_OVERRIDE { return (this->*next_state_)(); } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} void Reset() GRPC_OVERRIDE { srv_ctx_ = ServerContext(); req_ = RequestType(); @@ -160,8 +174,11 @@ class AsyncQpsServerTest : public Server { } private: - bool finisher() { return false; } - bool invoker() { + bool finisher(bool) { return false; } + bool invoker(bool ok) { + if (!ok) + return false; + ResponseType response; // Call the RPC processing function @@ -174,7 +191,7 @@ class AsyncQpsServerTest : public Server { } ServerContext srv_ctx_; RequestType req_; - bool (ServerRpcContextUnaryImpl::*next_state_)(); + bool (ServerRpcContextUnaryImpl::*next_state_)(bool); std::function<void(ServerContext*, RequestType*, grpc::ServerAsyncResponseWriter<ResponseType>*, void*)> request_method_; @@ -183,9 +200,88 @@ class AsyncQpsServerTest : public Server { grpc::ServerAsyncResponseWriter<ResponseType> response_writer_; }; - static Status UnaryCall(const SimpleRequest* request, - SimpleResponse* response) { - if (request->has_response_size() && request->response_size() > 0) { + template <class RequestType, class ResponseType> + class ServerRpcContextStreamingImpl GRPC_FINAL : public ServerRpcContext { + public: + ServerRpcContextStreamingImpl( + std::function<void(ServerContext *, + grpc::ServerAsyncReaderWriter<ResponseType, + RequestType> *, void *)> request_method, + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method) + : next_state_(&ServerRpcContextStreamingImpl::request_done), + request_method_(request_method), + invoke_method_(invoke_method), + stream_(&srv_ctx_) { + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + ~ServerRpcContextStreamingImpl() GRPC_OVERRIDE { + } + bool RunNextState(bool ok) GRPC_OVERRIDE {return (this->*next_state_)(ok);} + void Reset() GRPC_OVERRIDE { + srv_ctx_ = ServerContext(); + req_ = RequestType(); + stream_ = grpc::ServerAsyncReaderWriter<ResponseType, + RequestType>(&srv_ctx_); + + // Then request the method + next_state_ = &ServerRpcContextStreamingImpl::request_done; + request_method_(&srv_ctx_, &stream_, AsyncQpsServerTest::tag(this)); + } + + private: + bool request_done(bool ok) { + if (!ok) + return false; + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + return true; + } + + bool read_done(bool ok) { + if (ok) { + // invoke the method + ResponseType response; + // Call the RPC processing function + grpc::Status status = invoke_method_(&req_, &response); + // initiate the write + stream_.Write(response, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::write_done; + } else { // client has sent writes done + // finish the stream + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool write_done(bool ok) { + // now go back and get another streaming read! + if (ok) { + stream_.Read(&req_, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::read_done; + } + else { + stream_.Finish(Status::OK, AsyncQpsServerTest::tag(this)); + next_state_ = &ServerRpcContextStreamingImpl::finish_done; + } + return true; + } + bool finish_done(bool ok) {return false; /* reset the context */ } + + ServerContext srv_ctx_; + RequestType req_; + bool (ServerRpcContextStreamingImpl::*next_state_)(bool); + std::function<void(ServerContext *, + grpc::ServerAsyncReaderWriter<ResponseType, + RequestType> *, void *)> request_method_; + std::function<grpc::Status(const RequestType *, ResponseType *)> + invoke_method_; + grpc::ServerAsyncReaderWriter<ResponseType,RequestType> stream_; + }; + + static Status ProcessRPC(const SimpleRequest* request, + SimpleResponse* response) { + if (request->response_size() > 0) { if (!SetPayload(request->response_type(), request->response_size(), response->mutable_payload())) { return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); @@ -200,7 +296,13 @@ class AsyncQpsServerTest : public Server { std::function<void(ServerContext*, SimpleRequest*, grpc::ServerAsyncResponseWriter<SimpleResponse>*, void*)> request_unary_; + std::function<void(ServerContext*, grpc::ServerAsyncReaderWriter< + SimpleResponse,SimpleRequest>*, void*)> + request_streaming_; std::forward_list<ServerRpcContext*> contexts_; + + std::mutex shutdown_mutex_; + bool shutdown_; }; std::unique_ptr<Server> CreateAsyncServer(const ServerConfig& config, |