diff options
Diffstat (limited to 'src/cpp/server/server.cc')
-rw-r--r-- | src/cpp/server/server.cc | 304 |
1 files changed, 257 insertions, 47 deletions
diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index 1abdf702e2..ee9a1daa8e 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -37,25 +37,25 @@ #include <grpc/grpc.h> #include <grpc/grpc_security.h> #include <grpc/support/log.h> -#include "src/cpp/server/server_rpc_handler.h" -#include "src/cpp/server/thread_pool.h" -#include <grpc++/async_server_context.h> #include <grpc++/completion_queue.h> #include <grpc++/impl/rpc_service_method.h> +#include <grpc++/impl/service_type.h> +#include <grpc++/server_context.h> #include <grpc++/server_credentials.h> +#include <grpc++/thread_pool_interface.h> -namespace grpc { +#include "src/cpp/proto/proto_utils.h" +#include "src/cpp/util/time.h" -// TODO(rocking): consider a better default value like num of cores. -static const int kNumThreads = 4; +namespace grpc { -Server::Server(ThreadPoolInterface *thread_pool, ServerCredentials *creds) +Server::Server(ThreadPoolInterface* thread_pool, bool thread_pool_owned, + ServerCredentials* creds) : started_(false), shutdown_(false), num_running_cb_(0), - thread_pool_(thread_pool == nullptr ? new ThreadPool(kNumThreads) - : thread_pool), - thread_pool_owned_(thread_pool == nullptr), + thread_pool_(thread_pool), + thread_pool_owned_(thread_pool_owned), secure_(creds != nullptr) { if (creds) { server_ = @@ -75,6 +75,8 @@ Server::~Server() { if (started_ && !shutdown_) { lock.unlock(); Shutdown(); + } else { + lock.unlock(); } grpc_server_destroy(server_); if (thread_pool_owned_) { @@ -82,37 +84,180 @@ Server::~Server() { } } -void Server::RegisterService(RpcService *service) { +bool Server::RegisterService(RpcService* service) { for (int i = 0; i < service->GetMethodCount(); ++i) { - RpcServiceMethod *method = service->GetMethod(i); - method_map_.insert(std::make_pair(method->name(), method)); + RpcServiceMethod* method = service->GetMethod(i); + void* tag = + grpc_server_register_method(server_, method->name(), nullptr, cq_.cq()); + if (!tag) { + gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", + method->name()); + return false; + } + sync_methods_.emplace_back(method, tag); } + return true; } -void Server::AddPort(const grpc::string &addr) { +bool Server::RegisterAsyncService(AsynchronousService* service) { + GPR_ASSERT(service->dispatch_impl_ == nullptr && + "Can only register an asynchronous service against one server."); + service->dispatch_impl_ = this; + service->request_args_ = new void* [service->method_count_]; + for (size_t i = 0; i < service->method_count_; ++i) { + void* tag = + grpc_server_register_method(server_, service->method_names_[i], nullptr, + service->completion_queue()->cq()); + if (!tag) { + gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", + service->method_names_[i]); + return false; + } + service->request_args_[i] = tag; + } + return true; +} + +int Server::AddPort(const grpc::string& addr) { GPR_ASSERT(!started_); - int success; if (secure_) { - success = grpc_server_add_secure_http2_port(server_, addr.c_str()); + return grpc_server_add_secure_http2_port(server_, addr.c_str()); } else { - success = grpc_server_add_http2_port(server_, addr.c_str()); + return grpc_server_add_http2_port(server_, addr.c_str()); } - GPR_ASSERT(success); } -void Server::Start() { +class Server::SyncRequest final : public CompletionQueueTag { + public: + SyncRequest(RpcServiceMethod* method, void* tag) + : method_(method), + tag_(tag), + has_request_payload_(method->method_type() == RpcMethod::NORMAL_RPC || + method->method_type() == + RpcMethod::SERVER_STREAMING), + has_response_payload_(method->method_type() == RpcMethod::NORMAL_RPC || + method->method_type() == + RpcMethod::CLIENT_STREAMING) { + grpc_metadata_array_init(&request_metadata_); + } + + static SyncRequest* Wait(CompletionQueue* cq, bool* ok) { + void* tag = nullptr; + *ok = false; + if (!cq->Next(&tag, ok)) { + return nullptr; + } + auto* mrd = static_cast<SyncRequest*>(tag); + GPR_ASSERT(mrd->in_flight_); + return mrd; + } + + void Request(grpc_server* server) { + GPR_ASSERT(!in_flight_); + in_flight_ = true; + cq_ = grpc_completion_queue_create(); + GPR_ASSERT(GRPC_CALL_OK == + grpc_server_request_registered_call( + server, tag_, &call_, &deadline_, &request_metadata_, + has_request_payload_ ? &request_payload_ : nullptr, cq_, + this)); + } + + void FinalizeResult(void** tag, bool* status) override { + if (!*status) { + grpc_completion_queue_destroy(cq_); + } + } + + class CallData final { + public: + explicit CallData(Server* server, SyncRequest* mrd) + : cq_(mrd->cq_), + call_(mrd->call_, server, &cq_), + ctx_(mrd->deadline_, mrd->request_metadata_.metadata, + mrd->request_metadata_.count), + has_request_payload_(mrd->has_request_payload_), + has_response_payload_(mrd->has_response_payload_), + request_payload_(mrd->request_payload_), + method_(mrd->method_) { + ctx_.call_ = mrd->call_; + GPR_ASSERT(mrd->in_flight_); + mrd->in_flight_ = false; + mrd->request_metadata_.count = 0; + } + + ~CallData() { + if (has_request_payload_ && request_payload_) { + grpc_byte_buffer_destroy(request_payload_); + } + } + + void Run() { + std::unique_ptr<google::protobuf::Message> req; + std::unique_ptr<google::protobuf::Message> res; + if (has_request_payload_) { + req.reset(method_->AllocateRequestProto()); + if (!DeserializeProto(request_payload_, req.get())) { + abort(); // for now + } + } + if (has_response_payload_) { + res.reset(method_->AllocateResponseProto()); + } + auto status = method_->handler()->RunHandler( + MethodHandler::HandlerParameter(&call_, &ctx_, req.get(), res.get())); + CallOpBuffer buf; + if (!ctx_.sent_initial_metadata_) { + buf.AddSendInitialMetadata(&ctx_.initial_metadata_); + } + if (has_response_payload_) { + buf.AddSendMessage(*res); + } + buf.AddServerSendStatus(&ctx_.trailing_metadata_, status); + bool cancelled; + buf.AddServerRecvClose(&cancelled); + call_.PerformOps(&buf); + GPR_ASSERT(cq_.Pluck(&buf)); + } + + private: + CompletionQueue cq_; + Call call_; + ServerContext ctx_; + const bool has_request_payload_; + const bool has_response_payload_; + grpc_byte_buffer* request_payload_; + RpcServiceMethod* const method_; + }; + + private: + RpcServiceMethod* const method_; + void* const tag_; + bool in_flight_ = false; + const bool has_request_payload_; + const bool has_response_payload_; + grpc_call* call_; + gpr_timespec deadline_; + grpc_metadata_array request_metadata_; + grpc_byte_buffer* request_payload_; + grpc_completion_queue* cq_; +}; + +bool Server::Start() { GPR_ASSERT(!started_); started_ = true; grpc_server_start(server_); // Start processing rpcs. - ScheduleCallback(); -} + if (!sync_methods_.empty()) { + for (auto& m : sync_methods_) { + m.Request(server_); + } + + ScheduleCallback(); + } -void Server::AllowOneRpc() { - GPR_ASSERT(started_); - grpc_call_error err = grpc_server_request_call_old(server_, nullptr); - GPR_ASSERT(err == GRPC_CALL_OK); + return true; } void Server::Shutdown() { @@ -121,6 +266,7 @@ void Server::Shutdown() { if (started_ && !shutdown_) { shutdown_ = true; grpc_server_shutdown(server_); + cq_.Shutdown(); // Wait for running callbacks to finish. while (num_running_cb_ != 0) { @@ -128,12 +274,85 @@ void Server::Shutdown() { } } } +} + +void Server::PerformOpsOnCall(CallOpBuffer* buf, Call* call) { + static const size_t MAX_OPS = 8; + size_t nops = MAX_OPS; + grpc_op ops[MAX_OPS]; + buf->FillOps(ops, &nops); + GPR_ASSERT(GRPC_CALL_OK == + grpc_call_start_batch(call->call(), ops, nops, buf)); +} + +class Server::AsyncRequest final : public CompletionQueueTag { + public: + AsyncRequest(Server* server, void* registered_method, ServerContext* ctx, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, CompletionQueue* cq, + void* tag) + : tag_(tag), + request_(request), + stream_(stream), + cq_(cq), + ctx_(ctx), + server_(server) { + memset(&array_, 0, sizeof(array_)); + grpc_server_request_registered_call( + server->server_, registered_method, &call_, &deadline_, &array_, + request ? &payload_ : nullptr, cq->cq(), this); + } + + ~AsyncRequest() { + if (payload_) { + grpc_byte_buffer_destroy(payload_); + } + grpc_metadata_array_destroy(&array_); + } + + void FinalizeResult(void** tag, bool* status) override { + *tag = tag_; + if (*status && request_) { + if (payload_) { + *status = *status && DeserializeProto(payload_, request_); + } else { + *status = false; + } + } + if (*status) { + ctx_->deadline_ = Timespec2Timepoint(deadline_); + for (size_t i = 0; i < array_.count; i++) { + ctx_->client_metadata_.insert(std::make_pair( + grpc::string(array_.metadata[i].key), + grpc::string( + array_.metadata[i].value, + array_.metadata[i].value + array_.metadata[i].value_length))); + } + } + ctx_->call_ = call_; + Call call(call_, server_, cq_); + stream_->BindCall(&call); + delete this; + } + + private: + void* const tag_; + ::google::protobuf::Message* const request_; + ServerAsyncStreamingInterface* const stream_; + CompletionQueue* const cq_; + ServerContext* const ctx_; + Server* const server_; + grpc_call* call_ = nullptr; + gpr_timespec deadline_; + grpc_metadata_array array_; + grpc_byte_buffer* payload_ = nullptr; +}; - // Shutdown the completion queue. - cq_.Shutdown(); - void *tag = nullptr; - CompletionQueue::CompletionType t = cq_.Next(&tag); - GPR_ASSERT(t == CompletionQueue::QUEUE_CLOSED); +void Server::RequestAsyncCall(void* registered_method, ServerContext* context, + ::google::protobuf::Message* request, + ServerAsyncStreamingInterface* stream, + CompletionQueue* cq, void* tag) { + new AsyncRequest(this, registered_method, context, request, stream, cq, tag); } void Server::ScheduleCallback() { @@ -141,30 +360,21 @@ void Server::ScheduleCallback() { std::unique_lock<std::mutex> lock(mu_); num_running_cb_++; } - std::function<void()> callback = std::bind(&Server::RunRpc, this); - thread_pool_->ScheduleCallback(callback); + thread_pool_->ScheduleCallback(std::bind(&Server::RunRpc, this)); } void Server::RunRpc() { // Wait for one more incoming rpc. - void *tag = nullptr; - AllowOneRpc(); - CompletionQueue::CompletionType t = cq_.Next(&tag); - GPR_ASSERT(t == CompletionQueue::SERVER_RPC_NEW); - - AsyncServerContext *server_context = static_cast<AsyncServerContext *>(tag); - // server_context could be nullptr during server shutdown. - if (server_context != nullptr) { - // Schedule a new callback to handle more rpcs. + bool ok; + auto* mrd = SyncRequest::Wait(&cq_, &ok); + if (mrd) { ScheduleCallback(); + if (ok) { + SyncRequest::CallData cd(this, mrd); + mrd->Request(server_); - RpcServiceMethod *method = nullptr; - auto iter = method_map_.find(server_context->method()); - if (iter != method_map_.end()) { - method = iter->second; + cd.Run(); } - ServerRpcHandler rpc_handler(server_context, method); - rpc_handler.StartRpc(); } { |