diff options
Diffstat (limited to 'src/cpp')
-rw-r--r-- | src/cpp/server/health/default_health_check_service.cc | 54 | ||||
-rw-r--r-- | src/cpp/server/health/default_health_check_service.h | 16 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 82 |
3 files changed, 133 insertions, 19 deletions
diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc index cec7ecce70..26d8fd999f 100644 --- a/src/cpp/server/health/default_health_check_service.cc +++ b/src/cpp/server/health/default_health_check_service.cc @@ -48,21 +48,9 @@ namespace { const char kHealthCheckMethodName[] = "/grpc.health.v1.Health/Check"; -} // namespace - -DefaultHealthCheckService::SyncHealthCheckServiceImpl:: - SyncHealthCheckServiceImpl(DefaultHealthCheckService* service) - : service_(service) { - auto* handler = - new RpcMethodHandler<SyncHealthCheckServiceImpl, ByteBuffer, ByteBuffer>( - std::mem_fn(&SyncHealthCheckServiceImpl::Check), this); - auto* method = new RpcServiceMethod(kHealthCheckMethodName, - RpcMethod::NORMAL_RPC, handler); - AddMethod(method); -} - -Status DefaultHealthCheckService::SyncHealthCheckServiceImpl::Check( - ServerContext* context, const ByteBuffer* request, ByteBuffer* response) { +Status CheckHealth(const DefaultHealthCheckService* service, + ServerContext* context, const ByteBuffer* request, + ByteBuffer* response) { // Decode request. std::vector<Slice> slices; request->Dump(&slices); @@ -99,7 +87,7 @@ Status DefaultHealthCheckService::SyncHealthCheckServiceImpl::Check( // Check status from the associated default health checking service. DefaultHealthCheckService::ServingStatus serving_status = - service_->GetServingStatus( + service->GetServingStatus( request_struct.has_service ? request_struct.service : ""); if (serving_status == DefaultHealthCheckService::NOT_FOUND) { return Status(StatusCode::NOT_FOUND, ""); @@ -129,9 +117,41 @@ Status DefaultHealthCheckService::SyncHealthCheckServiceImpl::Check( response->Swap(&response_buffer); return Status::OK; } +} // namespace + +DefaultHealthCheckService::SyncHealthCheckServiceImpl:: + SyncHealthCheckServiceImpl(DefaultHealthCheckService* service) + : service_(service) { + auto* handler = + new RpcMethodHandler<SyncHealthCheckServiceImpl, ByteBuffer, ByteBuffer>( + std::mem_fn(&SyncHealthCheckServiceImpl::Check), this); + auto* method = new RpcServiceMethod(kHealthCheckMethodName, + RpcMethod::NORMAL_RPC, handler); + AddMethod(method); +} + +Status DefaultHealthCheckService::SyncHealthCheckServiceImpl::Check( + ServerContext* context, const ByteBuffer* request, ByteBuffer* response) { + return CheckHealth(service_, context, request, response); +} + +DefaultHealthCheckService::AsyncHealthCheckServiceImpl:: + AsyncHealthCheckServiceImpl(DefaultHealthCheckService* service) + : service_(service) { + auto* method = new RpcServiceMethod(kHealthCheckMethodName, + RpcMethod::NORMAL_RPC, nullptr); + AddMethod(method); + method_ = method; +} + +Status DefaultHealthCheckService::AsyncHealthCheckServiceImpl::Check( + ServerContext* context, const ByteBuffer* request, ByteBuffer* response) { + return CheckHealth(service_, context, request, response); +} DefaultHealthCheckService::DefaultHealthCheckService() - : sync_service_(new SyncHealthCheckServiceImpl(this)) { + : sync_service_(new SyncHealthCheckServiceImpl(this)), + async_service_(new AsyncHealthCheckServiceImpl(this)) { services_map_.emplace("", true); } diff --git a/src/cpp/server/health/default_health_check_service.h b/src/cpp/server/health/default_health_check_service.h index 541c720aaa..411aac9713 100644 --- a/src/cpp/server/health/default_health_check_service.h +++ b/src/cpp/server/health/default_health_check_service.h @@ -56,6 +56,18 @@ class DefaultHealthCheckService : public HealthCheckServiceInterface { const DefaultHealthCheckService* service_; }; + class AsyncHealthCheckServiceImpl : public Service { + public: + explicit AsyncHealthCheckServiceImpl(DefaultHealthCheckService* service); + Status Check(ServerContext* context, const ByteBuffer* request, + ByteBuffer* response); + const RpcServiceMethod* method() const { return method_; } + + private: + const DefaultHealthCheckService* service_; + const RpcServiceMethod* method_; + }; + DefaultHealthCheckService(); void SetServingStatus(const grpc::string& service_name, bool serving) final; void SetServingStatus(bool serving) final; @@ -64,11 +76,15 @@ class DefaultHealthCheckService : public HealthCheckServiceInterface { SyncHealthCheckServiceImpl* GetSyncHealthCheckService() const { return sync_service_.get(); } + AsyncHealthCheckServiceImpl* GetAsyncHealthCheckService() const { + return async_service_.get(); + } private: mutable std::mutex mu_; std::map<grpc::string, bool> services_map_; std::unique_ptr<SyncHealthCheckServiceImpl> sync_service_; + std::unique_ptr<AsyncHealthCheckServiceImpl> async_service_; }; } // namespace grpc diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index c50c076bdc..20641aeea8 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -37,6 +37,7 @@ #include <grpc++/completion_queue.h> #include <grpc++/generic/async_generic_service.h> +#include <grpc++/impl/codegen/async_unary_call.h> #include <grpc++/impl/codegen/completion_queue_tag.h> #include <grpc++/impl/grpc_library.h> #include <grpc++/impl/method_handler_impl.h> @@ -118,6 +119,67 @@ class Server::UnimplementedAsyncResponse final UnimplementedAsyncRequest* const request_; }; +class Server::HealthCheckAsyncRequestContext { + protected: + HealthCheckAsyncRequestContext() : rpc_(&server_context_) {} + ServerContext server_context_; + ServerAsyncResponseWriter<ByteBuffer> rpc_; +}; + +class Server::HealthCheckAsyncRequest final + : public HealthCheckAsyncRequestContext, + public RegisteredAsyncRequest { + public: + HealthCheckAsyncRequest( + DefaultHealthCheckService::AsyncHealthCheckServiceImpl* service, + Server* server, ServerCompletionQueue* cq) + : RegisteredAsyncRequest(server, &server_context_, &rpc_, cq, this, + false), + service_(service), + server_(server), + cq_(cq), + had_request_(false) { + IssueRequest(service->method()->server_tag(), &payload_, cq); + } + + bool FinalizeResult(void** tag, bool* status) override; + + private: + DefaultHealthCheckService::AsyncHealthCheckServiceImpl* service_; + Server* const server_; + ServerCompletionQueue* const cq_; + grpc_byte_buffer* payload_; + bool had_request_; + ByteBuffer request_; + ByteBuffer response_; +}; + +bool Server::HealthCheckAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (!had_request_) { + had_request_ = true; + bool serialization_status = + *status && payload_ && + SerializationTraits<ByteBuffer>::Deserialize( + payload_, &request_, server_->max_receive_message_size()) + .ok(); + RegisteredAsyncRequest::FinalizeResult(tag, status); + *status = serialization_status && *status; + if (*status) { + new HealthCheckAsyncRequest(service_, server_, cq_); + Status s = service_->Check(&server_context_, &request_, &response_); + rpc_.Finish(response_, s, this); + return false; + } else { + // TODO what to do here + delete this; + return false; + } + } else { + delete this; + return false; + } +} + class ShutdownTag : public CompletionQueueTag { public: bool FinalizeResult(void** tag, bool* status) { return false; } @@ -498,6 +560,8 @@ bool Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { // Only create default health check service when user did not provide an // explicit one. + DefaultHealthCheckService::AsyncHealthCheckServiceImpl* async_health_service = + nullptr; if (health_check_service_ == nullptr && !health_check_service_disabled_ && DefaultHealthCheckServiceEnabled()) { auto* default_hc_service = new DefaultHealthCheckService; @@ -505,6 +569,10 @@ bool Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { if (!sync_server_cqs_->empty()) { // Has sync methods. RegisterService(nullptr, default_hc_service->GetSyncHealthCheckService()); } + if (sync_server_cqs_->empty()) { // No sync methods. + async_health_service = default_hc_service->GetAsyncHealthCheckService(); + RegisterService(nullptr, async_health_service); + } } grpc_server_start(server_); @@ -521,6 +589,14 @@ bool Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { } } + if (async_health_service) { + for (size_t i = 0; i < num_cqs; i++) { + if (cqs[i]->IsFrequentlyPolled()) { + new HealthCheckAsyncRequest(async_health_service, this, cqs[i]); + } + } + } + for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) { (*it)->Start(); } @@ -641,8 +717,10 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, - ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, void* tag, + bool delete_on_finalize) + : BaseAsyncRequest(server, context, stream, call_cq, tag, + delete_on_finalize) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, |