diff options
author | vjpai <vpai@google.com> | 2016-01-07 10:20:46 -0800 |
---|---|---|
committer | vjpai <vpai@google.com> | 2016-01-07 10:20:46 -0800 |
commit | de332dfcac51e080e4c294d183906e5969672133 (patch) | |
tree | 0cab8680a3c7cd108eebc139cb207fbcd912f49d /test | |
parent | 18c0477528169c2032a57b5da094964a6d4beb2f (diff) |
Refactor server side to support generic tests.
Diffstat (limited to 'test')
-rw-r--r-- | test/cpp/qps/server_async.cc | 108 |
1 files changed, 70 insertions, 38 deletions
diff --git a/test/cpp/qps/server_async.cc b/test/cpp/qps/server_async.cc index c151918ce4..85a47ff71d 100644 --- a/test/cpp/qps/server_async.cc +++ b/test/cpp/qps/server_async.cc @@ -42,6 +42,7 @@ #include <grpc/support/alloc.h> #include <grpc/support/host_port.h> #include <grpc/support/log.h> +#include <grpc++/generic/async_generic_service.h> #include <grpc++/support/config.h> #include <grpc++/server.h> #include <grpc++/server_builder.h> @@ -55,9 +56,15 @@ namespace grpc { namespace testing { +template <class RequestType, class ResponseType, class ServiceType, class ServerContextType> class AsyncQpsServerTest : public Server { public: - explicit AsyncQpsServerTest(const ServerConfig &config) : Server(config) { + AsyncQpsServerTest(const ServerConfig &config, + std::function<void(ServerBuilder *, ServiceType *)> register_service, + std::function<void(ServiceType *, ServerContextType *, RequestType *, ServerAsyncResponseWriter<ResponseType>*, CompletionQueue *, ServerCompletionQueue *, void *)> request_unary_function, + std::function<void(ServiceType *, ServerContextType *, ServerAsyncReaderWriter<ResponseType, RequestType>*, CompletionQueue *, ServerCompletionQueue *, void *)> request_streaming_function, + std::function<grpc::Status(const ServerConfig&, const RequestType *, ResponseType *)> process_rpc) + : Server(config) { char *server_address = NULL; gpr_join_host_port(&server_address, config.host().c_str(), port()); @@ -67,7 +74,8 @@ class AsyncQpsServerTest : public Server { Server::CreateServerCredentials(config)); gpr_free(server_address); - builder.RegisterAsyncService(&async_service_); + register_service(&builder, &async_service_); + for (int i = 0; i < config.async_server_threads(); i++) { srv_cqs_.emplace_back(builder.AddCompletionQueue()); } @@ -75,22 +83,27 @@ class AsyncQpsServerTest : public Server { server_ = builder.BuildAndStart(); using namespace std::placeholders; + + auto process_rpc_bound = std::bind(process_rpc, config, _1, _2); + for (int i = 0; i < 10000 / config.async_server_threads(); i++) { for (int j = 0; j < config.async_server_threads(); j++) { - auto request_unary = std::bind( - &BenchmarkService::AsyncService::RequestUnaryCall, &async_service_, - _1, _2, _3, srv_cqs_[j].get(), srv_cqs_[j].get(), _4); - auto request_streaming = std::bind( - &BenchmarkService::AsyncService::RequestStreamingCall, - &async_service_, _1, _2, srv_cqs_[j].get(), srv_cqs_[j].get(), _3); - contexts_.push_front( - new ServerRpcContextUnaryImpl<SimpleRequest, SimpleResponse>( - request_unary, ProcessRPC)); - contexts_.push_front( - new ServerRpcContextStreamingImpl<SimpleRequest, SimpleResponse>( - request_streaming, ProcessRPC)); + if (request_unary_function) { + auto request_unary = std::bind( + request_unary_function, &async_service_, + _1, _2, _3, srv_cqs_[j].get(), srv_cqs_[j].get(), _4); + contexts_.push_front(new ServerRpcContextUnaryImpl(request_unary, process_rpc_bound)); + } + if (request_streaming_function) { + auto request_streaming = std::bind( + request_streaming_function, + &async_service_, _1, _2, srv_cqs_[j].get(), srv_cqs_[j].get(), _3); + contexts_.push_front(new ServerRpcContextStreamingImpl( + request_streaming, process_rpc_bound)); + } } } + for (int i = 0; i < config.async_server_threads(); i++) { shutdown_state_.emplace_back(new PerThreadShutdownState()); } @@ -155,16 +168,15 @@ class AsyncQpsServerTest : public Server { return reinterpret_cast<ServerRpcContext *>(tag); } - template <class RequestType, class ResponseType> class ServerRpcContextUnaryImpl GRPC_FINAL : public ServerRpcContext { public: ServerRpcContextUnaryImpl( - std::function<void(ServerContext *, RequestType *, + std::function<void(ServerContextType *, RequestType *, grpc::ServerAsyncResponseWriter<ResponseType> *, void *)> request_method, std::function<grpc::Status(const RequestType *, ResponseType *)> invoke_method) - : srv_ctx_(new ServerContext), + : srv_ctx_(new ServerContextType), next_state_(&ServerRpcContextUnaryImpl::invoker), request_method_(request_method), invoke_method_(invoke_method), @@ -177,7 +189,7 @@ class AsyncQpsServerTest : public Server { return (this->*next_state_)(ok); } void Reset() GRPC_OVERRIDE { - srv_ctx_.reset(new ServerContext); + srv_ctx_.reset(new ServerContextType); req_ = RequestType(); response_writer_ = grpc::ServerAsyncResponseWriter<ResponseType>(srv_ctx_.get()); @@ -205,10 +217,10 @@ class AsyncQpsServerTest : public Server { response_writer_.Finish(response, status, AsyncQpsServerTest::tag(this)); return true; } - std::unique_ptr<ServerContext> srv_ctx_; + std::unique_ptr<ServerContextType> srv_ctx_; RequestType req_; bool (ServerRpcContextUnaryImpl::*next_state_)(bool); - std::function<void(ServerContext *, RequestType *, + std::function<void(ServerContextType *, RequestType *, grpc::ServerAsyncResponseWriter<ResponseType> *, void *)> request_method_; std::function<grpc::Status(const RequestType *, ResponseType *)> @@ -216,16 +228,15 @@ class AsyncQpsServerTest : public Server { grpc::ServerAsyncResponseWriter<ResponseType> response_writer_; }; - template <class RequestType, class ResponseType> class ServerRpcContextStreamingImpl GRPC_FINAL : public ServerRpcContext { public: ServerRpcContextStreamingImpl( - std::function<void(ServerContext *, grpc::ServerAsyncReaderWriter< + std::function<void(ServerContextType *, grpc::ServerAsyncReaderWriter< ResponseType, RequestType> *, void *)> request_method, std::function<grpc::Status(const RequestType *, ResponseType *)> invoke_method) - : srv_ctx_(new ServerContext), + : srv_ctx_(new ServerContextType), next_state_(&ServerRpcContextStreamingImpl::request_done), request_method_(request_method), invoke_method_(invoke_method), @@ -237,7 +248,7 @@ class AsyncQpsServerTest : public Server { return (this->*next_state_)(ok); } void Reset() GRPC_OVERRIDE { - srv_ctx_.reset(new ServerContext); + srv_ctx_.reset(new ServerContextType); req_ = RequestType(); stream_ = grpc::ServerAsyncReaderWriter<ResponseType, RequestType>( srv_ctx_.get()); @@ -286,11 +297,11 @@ class AsyncQpsServerTest : public Server { } bool finish_done(bool ok) { return false; /* reset the context */ } - std::unique_ptr<ServerContext> srv_ctx_; + std::unique_ptr<ServerContextType> srv_ctx_; RequestType req_; bool (ServerRpcContextStreamingImpl::*next_state_)(bool); std::function<void( - ServerContext *, + ServerContextType *, grpc::ServerAsyncReaderWriter<ResponseType, RequestType> *, void *)> request_method_; std::function<grpc::Status(const RequestType *, ResponseType *)> @@ -298,20 +309,10 @@ class AsyncQpsServerTest : public Server { grpc::ServerAsyncReaderWriter<ResponseType, RequestType> stream_; }; - static Status ProcessRPC(const SimpleRequest *request, - SimpleResponse *response) { - if (request->response_size() > 0) { - if (!SetPayload(request->response_type(), request->response_size(), - response->mutable_payload())) { - return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); - } - } - return Status::OK; - } std::vector<std::thread> threads_; std::unique_ptr<grpc::Server> server_; std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> srv_cqs_; - BenchmarkService::AsyncService async_service_; + ServiceType async_service_; std::forward_list<ServerRpcContext *> contexts_; class PerThreadShutdownState { @@ -335,8 +336,39 @@ class AsyncQpsServerTest : public Server { std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_; }; +static void RegisterBenchmarkService(ServerBuilder *builder, + BenchmarkService::AsyncService *service) { + builder->RegisterAsyncService(service); +} +static void RegisterGenericService(ServerBuilder *builder, + grpc::AsyncGenericService *service) { + builder->RegisterAsyncGenericService(service); +} + +template<class RequestType, class ResponseType> +Status ProcessRPC(const ServerConfig &config, const RequestType *request, + ResponseType *response) { + if (request->response_size() > 0) { + if (!Server::SetPayload(request->response_type(), request->response_size(), + response->mutable_payload())) { + return Status(grpc::StatusCode::INTERNAL, "Error creating payload."); + } + } + return Status::OK; +} + +template<> +Status ProcessRPC(const ServerConfig &config, const ByteBuffer *request, + ByteBuffer *response) { + return Status::OK; +} + + std::unique_ptr<Server> CreateAsyncServer(const ServerConfig &config) { - return std::unique_ptr<Server>(new AsyncQpsServerTest(config)); + return std::unique_ptr<Server>(new AsyncQpsServerTest<SimpleRequest,SimpleResponse,BenchmarkService::AsyncService,grpc::ServerContext>(config, RegisterBenchmarkService, &BenchmarkService::AsyncService::RequestUnaryCall, &BenchmarkService::AsyncService::RequestStreamingCall, ProcessRPC<SimpleRequest,SimpleResponse>)); +} +std::unique_ptr<Server> CreateAsyncGenericServer(const ServerConfig &config) { + return std::unique_ptr<Server>(new AsyncQpsServerTest<ByteBuffer, ByteBuffer, grpc::AsyncGenericService,grpc::GenericServerContext>(config, RegisterGenericService, nullptr, &grpc::AsyncGenericService::RequestCall, ProcessRPC<ByteBuffer, ByteBuffer>)); } } // namespace testing |