Diffstat (limited to 'src/cpp')
-rw-r--r--   src/cpp/client/channel_cc.cc                           |  40
-rw-r--r--   src/cpp/client/client_context.cc                       |  12
-rw-r--r--   src/cpp/client/client_interceptor.cc                   |  34
-rw-r--r--   src/cpp/client/secure_credentials.cc                   |   7
-rw-r--r--   src/cpp/common/completion_queue_cc.cc                  |  10
-rw-r--r--   src/cpp/common/core_codegen.cc                         |   7
-rw-r--r--   src/cpp/common/version_cc.cc                           |   2
-rw-r--r--   src/cpp/server/channelz/channelz_service.cc            |  92
-rw-r--r--   src/cpp/server/channelz/channelz_service.h             |   9
-rw-r--r--   src/cpp/server/health/default_health_check_service.cc  | 475
-rw-r--r--   src/cpp/server/health/default_health_check_service.h   | 234
-rw-r--r--   src/cpp/server/health/health.pb.c                      |  24
-rw-r--r--   src/cpp/server/health/health.pb.h                      |  72
-rw-r--r--   src/cpp/server/secure_server_credentials.cc            |   7
-rw-r--r--   src/cpp/server/server_builder.cc                       |  38
-rw-r--r--   src/cpp/server/server_cc.cc                            | 495
-rw-r--r--   src/cpp/server/server_context.cc                       | 231
17 files changed, 1425 insertions, 364 deletions
diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc
index 2cab41b3f5..8e1cea0269 100644
--- a/src/cpp/client/channel_cc.cc
+++ b/src/cpp/client/channel_cc.cc
@@ -33,6 +33,7 @@
 #include <grpcpp/client_context.h>
 #include <grpcpp/completion_queue.h>
 #include <grpcpp/impl/call.h>
+#include <grpcpp/impl/codegen/call_op_set.h>
 #include <grpcpp/impl/codegen/completion_queue_tag.h>
 #include <grpcpp/impl/grpc_library.h>
 #include <grpcpp/impl/rpc_method.h>
@@ -57,9 +58,8 @@ Channel::Channel(
         std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>
         interceptor_creators)
     : host_(host), c_channel_(channel) {
-  auto* vector = interceptor_creators.release();
-  if (vector != nullptr) {
-    interceptor_creators_ = std::move(*vector);
+  if (interceptor_creators != nullptr) {
+    interceptor_creators_ = std::move(*interceptor_creators);
   }
   g_gli_initializer.summon();
 }
@@ -112,9 +112,10 @@ void ChannelResetConnectionBackoff(Channel* channel) {
 
 }  // namespace experimental
 
-internal::Call Channel::CreateCall(const internal::RpcMethod& method,
-                                   ClientContext* context,
-                                   CompletionQueue* cq) {
+internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method,
+                                           ClientContext* context,
+                                           CompletionQueue* cq,
+                                           size_t interceptor_pos) {
   const bool kRegistered = method.channel_tag() && context->authority().empty();
   grpc_call* c_call = nullptr;
   if (kRegistered) {
@@ -146,18 +147,27 @@ internal::Call Channel::CreateCall(const internal::RpcMethod& method,
     }
   }
   grpc_census_call_set_context(c_call, context->census_context());
+
+  // ClientRpcInfo should be set before call because set_call also checks
+  // whether the call has been cancelled, and if the call was cancelled, we
+  // should notify the interceptors too.
+  auto* info = context->set_client_rpc_info(
+      method.name(), this, interceptor_creators_, interceptor_pos);
   context->set_call(c_call, shared_from_this());
-  return internal::Call(c_call, this, cq);
+
+  return internal::Call(c_call, this, cq, info);
+}
+
+internal::Call Channel::CreateCall(const internal::RpcMethod& method,
+                                   ClientContext* context,
+                                   CompletionQueue* cq) {
+  return CreateCallInternal(method, context, cq, 0);
 }
 
 void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops,
                                internal::Call* call) {
-  static const size_t MAX_OPS = 8;
-  size_t nops = 0;
-  grpc_op cops[MAX_OPS];
-  ops->FillOps(call->call(), cops, &nops);
-  GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), cops, nops,
-                                                   ops->cq_tag(), nullptr));
+  ops->FillOps(call);  // Make a copy of call. It's fine since Call just has
+                       // pointers
 }
 
 void* Channel::RegisterMethod(const char* method) {
@@ -219,7 +229,7 @@ class ShutdownCallback : public grpc_experimental_completion_queue_functor {
   static void Run(grpc_experimental_completion_queue_functor* cb, int) {
     auto* callback = static_cast<ShutdownCallback*>(cb);
     delete callback->cq_;
-    grpc_core::Delete(callback);
+    delete callback;
   }
 
  private:
@@ -232,7 +242,7 @@ CompletionQueue* Channel::CallbackCQ() {
   // if there is no explicit per-channel CQ registered
   std::lock_guard<std::mutex> l(mu_);
   if (callback_cq_ == nullptr) {
-    auto* shutdown_callback = grpc_core::New<ShutdownCallback>();
+    auto* shutdown_callback = new ShutdownCallback;
     callback_cq_ = new CompletionQueue(grpc_completion_queue_attributes{
         GRPC_CQ_CURRENT_VERSION, GRPC_CQ_CALLBACK, GRPC_CQ_DEFAULT_POLLING,
         shutdown_callback});
diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc
index 07a04e4268..50da75f09c 100644
--- a/src/cpp/client/client_context.cc
+++ b/src/cpp/client/client_context.cc
@@ -24,6 +24,7 @@
 #include <grpc/support/log.h>
 #include <grpc/support/string_util.h>
 
+#include <grpcpp/impl/codegen/interceptor_common.h>
 #include <grpcpp/impl/grpc_library.h>
 #include <grpcpp/security/credentials.h>
 #include <grpcpp/server_context.h>
@@ -86,10 +87,13 @@ void ClientContext::set_call(grpc_call* call,
   call_ = call;
   channel_ = channel;
   if (creds_ && !creds_->ApplyToCall(call_)) {
+    // TODO(yashykt): should interceptors also see this status?
+    SendCancelToInterceptors();
     grpc_call_cancel_with_status(call, GRPC_STATUS_CANCELLED,
                                  "Failed to set credentials to rpc.", nullptr);
   }
   if (call_canceled_) {
+    SendCancelToInterceptors();
     grpc_call_cancel(call_, nullptr);
   }
 }
@@ -110,12 +114,20 @@ void ClientContext::set_compression_algorithm(
 void ClientContext::TryCancel() {
   std::unique_lock<std::mutex> lock(mu_);
   if (call_) {
+    SendCancelToInterceptors();
     grpc_call_cancel(call_, nullptr);
   } else {
     call_canceled_ = true;
   }
 }
 
+void ClientContext::SendCancelToInterceptors() {
+  internal::CancelInterceptorBatchMethods cancel_methods;
+  for (size_t i = 0; i < rpc_info_.interceptors_.size(); i++) {
+    rpc_info_.RunInterceptor(&cancel_methods, i);
+  }
+}
+
 grpc::string ClientContext::peer() const {
   grpc::string peer;
   if (call_) {
diff --git a/src/cpp/client/client_interceptor.cc b/src/cpp/client/client_interceptor.cc
new file mode 100644
index 0000000000..3a5cac9830
--- /dev/null
+++ b/src/cpp/client/client_interceptor.cc
@@ -0,0 +1,34 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <grpcpp/impl/codegen/client_interceptor.h>
+
+namespace grpc {
+
+namespace internal {
+experimental::ClientInterceptorFactoryInterface*
+    g_global_client_interceptor_factory = nullptr;
+}
+
+namespace experimental {
+void RegisterGlobalClientInterceptorFactory(
+    ClientInterceptorFactoryInterface* factory) {
+  internal::g_global_client_interceptor_factory = factory;
+}
+}  // namespace experimental
+}  // namespace grpc
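For context, a minimal sketch of the client-side interceptor API that CreateCallInternal() and RegisterGlobalClientInterceptorFactory() above wire up. This is illustrative rather than part of the patch; LoggingInterceptor and its factory are hypothetical names, and the experimental headers are assumed to match this release:

    #include <grpcpp/grpcpp.h>
    #include <grpcpp/impl/codegen/client_interceptor.h>

    class LoggingInterceptor : public grpc::experimental::Interceptor {
     public:
      void Intercept(grpc::experimental::InterceptorBatchMethods* methods) override {
        if (methods->QueryInterceptionHookPoint(
                grpc::experimental::InterceptionHookPoints::
                    PRE_SEND_INITIAL_METADATA)) {
          // Outgoing metadata can be inspected or mutated here.
        }
        methods->Proceed();  // Hand control back to the interceptor chain.
      }
    };

    class LoggingInterceptorFactory
        : public grpc::experimental::ClientInterceptorFactoryInterface {
     public:
      grpc::experimental::Interceptor* CreateClientInterceptor(
          grpc::experimental::ClientRpcInfo* info) override {
        return new LoggingInterceptor;  // Ownership passes to the RPC's ClientRpcInfo.
      }
    };

A vector of such factories is what populates the interceptor_creators parameter of the Channel constructor above (via experimental channel-creation helpers such as CreateCustomChannelWithInterceptors, assumed available in this release); alternatively, a single factory can be installed process-wide with RegisterGlobalClientInterceptorFactory().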
diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc
index d1cd78e755..7faaa20e78 100644
--- a/src/cpp/client/secure_credentials.cc
+++ b/src/cpp/client/secure_credentials.cc
@@ -228,9 +228,10 @@ int MetadataCredentialsPluginWrapper::GetMetadata(
   }
   if (w->plugin_->IsBlocking()) {
     // Asynchronous return.
-    w->thread_pool_->Add(
-        std::bind(&MetadataCredentialsPluginWrapper::InvokePlugin, w, context,
-                  cb, user_data, nullptr, nullptr, nullptr, nullptr));
+    w->thread_pool_->Add([w, context, cb, user_data] {
+      w->MetadataCredentialsPluginWrapper::InvokePlugin(
+          context, cb, user_data, nullptr, nullptr, nullptr, nullptr);
+    });
     return 0;
   } else {
     // Synchronous return.
diff --git a/src/cpp/common/completion_queue_cc.cc b/src/cpp/common/completion_queue_cc.cc
index 6893201e2e..d93a54aed7 100644
--- a/src/cpp/common/completion_queue_cc.cc
+++ b/src/cpp/common/completion_queue_cc.cc
@@ -60,10 +60,10 @@ CompletionQueue::NextStatus CompletionQueue::AsyncNextInternal(
       case GRPC_QUEUE_SHUTDOWN:
         return SHUTDOWN;
       case GRPC_OP_COMPLETE:
-        auto cq_tag = static_cast<internal::CompletionQueueTag*>(ev.tag);
+        auto core_cq_tag = static_cast<internal::CompletionQueueTag*>(ev.tag);
         *ok = ev.success != 0;
-        *tag = cq_tag;
-        if (cq_tag->FinalizeResult(tag, ok)) {
+        *tag = core_cq_tag;
+        if (core_cq_tag->FinalizeResult(tag, ok)) {
           return GOT_EVENT;
         }
         break;
@@ -87,9 +87,9 @@ bool CompletionQueue::CompletionQueueTLSCache::Flush(void** tag, bool* ok) {
     flushed_ = true;
     if (grpc_completion_queue_thread_local_cache_flush(cq_->cq_, &res_tag,
                                                        &res)) {
-      auto cq_tag = static_cast<internal::CompletionQueueTag*>(res_tag);
+      auto core_cq_tag = static_cast<internal::CompletionQueueTag*>(res_tag);
       *ok = res == 1;
-      if (cq_tag->FinalizeResult(tag, ok)) {
+      if (core_cq_tag->FinalizeResult(tag, ok)) {
         return true;
       }
     }
diff --git a/src/cpp/common/core_codegen.cc b/src/cpp/common/core_codegen.cc
index 619aacadaa..cfaa2e7b19 100644
--- a/src/cpp/common/core_codegen.cc
+++ b/src/cpp/common/core_codegen.cc
@@ -102,6 +102,13 @@ size_t CoreCodegen::grpc_byte_buffer_length(grpc_byte_buffer* bb) {
   return ::grpc_byte_buffer_length(bb);
 }
 
+grpc_call_error CoreCodegen::grpc_call_start_batch(grpc_call* call,
+                                                   const grpc_op* ops,
+                                                   size_t nops, void* tag,
+                                                   void* reserved) {
+  return ::grpc_call_start_batch(call, ops, nops, tag, reserved);
+}
+
 grpc_call_error CoreCodegen::grpc_call_cancel_with_status(
     grpc_call* call, grpc_status_code status, const char* description,
     void* reserved) {
diff --git a/src/cpp/common/version_cc.cc b/src/cpp/common/version_cc.cc
index cc797f1546..8abd45efb7 100644
--- a/src/cpp/common/version_cc.cc
+++ b/src/cpp/common/version_cc.cc
@@ -22,5 +22,5 @@
 #include <grpcpp/grpcpp.h>
 
 namespace grpc {
-grpc::string Version() { return "1.16.0-dev"; }
+grpc::string Version() { return "1.17.0-dev"; }
 }  // namespace grpc
diff --git a/src/cpp/server/channelz/channelz_service.cc b/src/cpp/server/channelz/channelz_service.cc
index 4e3fe8c1c9..9ecb9de7e4 100644
--- a/src/cpp/server/channelz/channelz_service.cc
+++ b/src/cpp/server/channelz/channelz_service.cc
@@ -20,9 +20,6 @@
 #include "src/cpp/server/channelz/channelz_service.h"
 
-#include <google/protobuf/text_format.h>
-#include <google/protobuf/util/json_util.h>
-
 #include <grpc/grpc.h>
 #include <grpc/support/alloc.h>
 
@@ -33,13 +30,14 @@ Status ChannelzService::GetTopChannels(
     channelz::v1::GetTopChannelsResponse* response) {
   char* json_str = grpc_channelz_get_top_channels(request->start_channel_id());
   if (json_str == nullptr) {
-    return Status(INTERNAL, "grpc_channelz_get_top_channels returned null");
+    return Status(StatusCode::INTERNAL,
+                  "grpc_channelz_get_top_channels returned null");
   }
-  google::protobuf::util::Status s =
-      google::protobuf::util::JsonStringToMessage(json_str, response);
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
   gpr_free(json_str);
-  if (s != google::protobuf::util::Status::OK) {
-    return Status(INTERNAL, s.ToString());
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
   }
   return Status::OK;
 }
@@ -49,13 +47,49 @@ Status ChannelzService::GetServers(
     channelz::v1::GetServersResponse* response) {
   char* json_str = grpc_channelz_get_servers(request->start_server_id());
   if (json_str == nullptr) {
-    return Status(INTERNAL, "grpc_channelz_get_servers returned null");
+    return Status(StatusCode::INTERNAL,
+                  "grpc_channelz_get_servers returned null");
+  }
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
+  gpr_free(json_str);
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
+  }
+  return Status::OK;
+}
+
+Status ChannelzService::GetServer(ServerContext* unused,
+                                  const channelz::v1::GetServerRequest* request,
+                                  channelz::v1::GetServerResponse* response) {
+  char* json_str = grpc_channelz_get_server(request->server_id());
+  if (json_str == nullptr) {
+    return Status(StatusCode::INTERNAL,
+                  "grpc_channelz_get_server returned null");
+  }
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
+  gpr_free(json_str);
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
+  }
+  return Status::OK;
+}
+
+Status ChannelzService::GetServerSockets(
+    ServerContext* unused, const channelz::v1::GetServerSocketsRequest* request,
+    channelz::v1::GetServerSocketsResponse* response) {
+  char* json_str = grpc_channelz_get_server_sockets(request->server_id(),
+                                                    request->start_socket_id());
+  if (json_str == nullptr) {
+    return Status(StatusCode::INTERNAL,
+                  "grpc_channelz_get_server_sockets returned null");
   }
-  google::protobuf::util::Status s =
-      google::protobuf::util::JsonStringToMessage(json_str, response);
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
   gpr_free(json_str);
-  if (s != google::protobuf::util::Status::OK) {
-    return Status(INTERNAL, s.ToString());
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
   }
   return Status::OK;
 }
@@ -65,13 +99,13 @@ Status ChannelzService::GetChannel(
     channelz::v1::GetChannelResponse* response) {
   char* json_str = grpc_channelz_get_channel(request->channel_id());
   if (json_str == nullptr) {
-    return Status(NOT_FOUND, "No object found for that ChannelId");
+    return Status(StatusCode::NOT_FOUND, "No object found for that ChannelId");
   }
-  google::protobuf::util::Status s =
-      google::protobuf::util::JsonStringToMessage(json_str, response);
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
   gpr_free(json_str);
-  if (s != google::protobuf::util::Status::OK) {
-    return Status(INTERNAL, s.ToString());
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
   }
   return Status::OK;
 }
@@ -81,13 +115,14 @@ Status ChannelzService::GetSubchannel(
     channelz::v1::GetSubchannelResponse* response) {
   char* json_str = grpc_channelz_get_subchannel(request->subchannel_id());
   if (json_str == nullptr) {
-    return Status(NOT_FOUND, "No object found for that SubchannelId");
+    return Status(StatusCode::NOT_FOUND,
+                  "No object found for that SubchannelId");
   }
-  google::protobuf::util::Status s =
-      google::protobuf::util::JsonStringToMessage(json_str, response);
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
   gpr_free(json_str);
-  if (s != google::protobuf::util::Status::OK) {
-    return Status(INTERNAL, s.ToString());
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
   }
   return Status::OK;
 }
@@ -96,15 +131,14 @@ Status ChannelzService::GetSocket(ServerContext* unused,
                                   const channelz::v1::GetSocketRequest* request,
                                   channelz::v1::GetSocketResponse* response) {
   char* json_str = grpc_channelz_get_socket(request->socket_id());
-  gpr_log(GPR_ERROR, "%s", json_str);
   if (json_str == nullptr) {
-    return Status(NOT_FOUND, "No object found for that SocketId");
+    return Status(StatusCode::NOT_FOUND, "No object found for that SocketId");
   }
-  google::protobuf::util::Status s =
-      google::protobuf::util::JsonStringToMessage(json_str, response);
+  grpc::protobuf::util::Status s =
+      grpc::protobuf::json::JsonStringToMessage(json_str, response);
   gpr_free(json_str);
-  if (s != google::protobuf::util::Status::OK) {
-    return Status(INTERNAL, s.ToString());
+  if (!s.ok()) {
+    return Status(StatusCode::INTERNAL, s.ToString());
   }
   return Status::OK;
 }
diff --git a/src/cpp/server/channelz/channelz_service.h b/src/cpp/server/channelz/channelz_service.h
index 1be4e01c73..b4a66ba1c6 100644
--- a/src/cpp/server/channelz/channelz_service.h
+++ b/src/cpp/server/channelz/channelz_service.h
@@ -36,6 +36,15 @@ class ChannelzService final : public channelz::v1::Channelz::Service {
   Status GetServers(ServerContext* unused,
                     const channelz::v1::GetServersRequest* request,
                     channelz::v1::GetServersResponse* response) override;
+  // implementation of GetServer rpc
+  Status GetServer(ServerContext* unused,
+                   const channelz::v1::GetServerRequest* request,
+                   channelz::v1::GetServerResponse* response) override;
+  // implementation of GetServerSockets rpc
+  Status GetServerSockets(
+      ServerContext* unused,
+      const channelz::v1::GetServerSocketsRequest* request,
+      channelz::v1::GetServerSocketsResponse* response) override;
   // implementation of GetChannel rpc
   Status GetChannel(ServerContext* unused,
                     const channelz::v1::GetChannelRequest* request,
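Every channelz handler above follows the same shape: call the core channelz accessor, convert the returned JSON into the response proto, and free the core-owned string. A sketch of that shared pattern as a hypothetical helper (not part of the patch; the name CoreJsonToProto is invented), using the same grpc::protobuf aliases the file already uses:

    #include <grpc/support/alloc.h>
    #include <grpcpp/grpcpp.h>

    // json_str comes from a grpc_channelz_get_*() call; we take ownership.
    template <typename ResponseT>
    grpc::Status CoreJsonToProto(char* json_str, const char* source,
                                 ResponseT* response) {
      if (json_str == nullptr) {
        return grpc::Status(grpc::StatusCode::INTERNAL,
                            grpc::string(source) + " returned null");
      }
      grpc::protobuf::util::Status s =
          grpc::protobuf::json::JsonStringToMessage(json_str, response);
      gpr_free(json_str);  // Core transfers ownership of the JSON string.
      if (!s.ok()) {
        return grpc::Status(grpc::StatusCode::INTERNAL, s.ToString());
      }
      return grpc::Status::OK;
    }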
diff --git a/src/cpp/server/health/default_health_check_service.cc b/src/cpp/server/health/default_health_check_service.cc
index bfda67d086..c951c69d51 100644
--- a/src/cpp/server/health/default_health_check_service.cc
+++ b/src/cpp/server/health/default_health_check_service.cc
@@ -26,79 +26,199 @@
 #include "pb_decode.h"
 #include "pb_encode.h"
+#include "src/core/ext/filters/client_channel/health/health.pb.h"
 #include "src/cpp/server/health/default_health_check_service.h"
-#include "src/cpp/server/health/health.pb.h"
 
 namespace grpc {
+
+//
+// DefaultHealthCheckService
+//
+
+DefaultHealthCheckService::DefaultHealthCheckService() {
+  services_map_[""].SetServingStatus(SERVING);
+}
+
+void DefaultHealthCheckService::SetServingStatus(
+    const grpc::string& service_name, bool serving) {
+  std::unique_lock<std::mutex> lock(mu_);
+  services_map_[service_name].SetServingStatus(serving ? SERVING : NOT_SERVING);
+}
+
+void DefaultHealthCheckService::SetServingStatus(bool serving) {
+  const ServingStatus status = serving ? SERVING : NOT_SERVING;
+  std::unique_lock<std::mutex> lock(mu_);
+  for (auto& p : services_map_) {
+    ServiceData& service_data = p.second;
+    service_data.SetServingStatus(status);
+  }
+}
+
+DefaultHealthCheckService::ServingStatus
+DefaultHealthCheckService::GetServingStatus(
+    const grpc::string& service_name) const {
+  std::lock_guard<std::mutex> lock(mu_);
+  auto it = services_map_.find(service_name);
+  if (it == services_map_.end()) {
+    return NOT_FOUND;
+  }
+  const ServiceData& service_data = it->second;
+  return service_data.GetServingStatus();
+}
+
+void DefaultHealthCheckService::RegisterCallHandler(
+    const grpc::string& service_name,
+    std::shared_ptr<HealthCheckServiceImpl::CallHandler> handler) {
+  std::unique_lock<std::mutex> lock(mu_);
+  ServiceData& service_data = services_map_[service_name];
+  service_data.AddCallHandler(handler /* copies ref */);
+  HealthCheckServiceImpl::CallHandler* h = handler.get();
+  h->SendHealth(std::move(handler), service_data.GetServingStatus());
+}
+
+void DefaultHealthCheckService::UnregisterCallHandler(
+    const grpc::string& service_name,
+    const std::shared_ptr<HealthCheckServiceImpl::CallHandler>& handler) {
+  std::unique_lock<std::mutex> lock(mu_);
+  auto it = services_map_.find(service_name);
+  if (it == services_map_.end()) return;
+  ServiceData& service_data = it->second;
+  service_data.RemoveCallHandler(handler);
+  if (service_data.Unused()) {
+    services_map_.erase(it);
+  }
+}
+
+DefaultHealthCheckService::HealthCheckServiceImpl*
+DefaultHealthCheckService::GetHealthCheckService(
+    std::unique_ptr<ServerCompletionQueue> cq) {
+  GPR_ASSERT(impl_ == nullptr);
+  impl_.reset(new HealthCheckServiceImpl(this, std::move(cq)));
+  return impl_.get();
+}
+
+//
+// DefaultHealthCheckService::ServiceData
+//
+
+void DefaultHealthCheckService::ServiceData::SetServingStatus(
+    ServingStatus status) {
+  status_ = status;
+  for (auto& call_handler : call_handlers_) {
+    call_handler->SendHealth(call_handler /* copies ref */, status);
+  }
+}
+
+void DefaultHealthCheckService::ServiceData::AddCallHandler(
+    std::shared_ptr<HealthCheckServiceImpl::CallHandler> handler) {
+  call_handlers_.insert(std::move(handler));
+}
+
+void DefaultHealthCheckService::ServiceData::RemoveCallHandler(
+    const std::shared_ptr<HealthCheckServiceImpl::CallHandler>& handler) {
+  call_handlers_.erase(handler);
+}
+
+//
+// DefaultHealthCheckService::HealthCheckServiceImpl
+//
+
 namespace {
 const char kHealthCheckMethodName[] = "/grpc.health.v1.Health/Check";
+const char kHealthWatchMethodName[] = "/grpc.health.v1.Health/Watch";
 }  // namespace
 
 DefaultHealthCheckService::HealthCheckServiceImpl::HealthCheckServiceImpl(
-    DefaultHealthCheckService* service)
-    : service_(service), method_(nullptr) {
-  internal::MethodHandler* handler =
-      new internal::RpcMethodHandler<HealthCheckServiceImpl, ByteBuffer,
-                                     ByteBuffer>(
-          std::mem_fn(&HealthCheckServiceImpl::Check), this);
-  method_ = new internal::RpcServiceMethod(
-      kHealthCheckMethodName, internal::RpcMethod::NORMAL_RPC, handler);
-  AddMethod(method_);
-}
-
-Status DefaultHealthCheckService::HealthCheckServiceImpl::Check(
-    ServerContext* context, const ByteBuffer* request, ByteBuffer* response) {
-  // Decode request.
-  std::vector<Slice> slices;
-  if (!request->Dump(&slices).ok()) {
-    return Status(StatusCode::INVALID_ARGUMENT, "");
+    DefaultHealthCheckService* database,
+    std::unique_ptr<ServerCompletionQueue> cq)
+    : database_(database), cq_(std::move(cq)) {
+  // Add Check() method.
+  AddMethod(new internal::RpcServiceMethod(
+      kHealthCheckMethodName, internal::RpcMethod::NORMAL_RPC, nullptr));
+  // Add Watch() method.
+  AddMethod(new internal::RpcServiceMethod(
+      kHealthWatchMethodName, internal::RpcMethod::SERVER_STREAMING, nullptr));
+  // Create serving thread.
+  thread_ = std::unique_ptr<::grpc_core::Thread>(
+      new ::grpc_core::Thread("grpc_health_check_service", Serve, this));
+}
+
+DefaultHealthCheckService::HealthCheckServiceImpl::~HealthCheckServiceImpl() {
+  // We will reach here after the server starts shutting down.
+  shutdown_ = true;
+  {
+    std::unique_lock<std::mutex> lock(cq_shutdown_mu_);
+    cq_->Shutdown();
   }
+  thread_->Join();
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::StartServingThread() {
+  // Request the calls we're interested in.
+  // We do this before starting the serving thread, so that we know it's
+  // done before server startup is complete.
+  CheckCallHandler::CreateAndStart(cq_.get(), database_, this);
+  WatchCallHandler::CreateAndStart(cq_.get(), database_, this);
+  // Start serving thread.
+  thread_->Start();
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::Serve(void* arg) {
+  HealthCheckServiceImpl* service =
+      reinterpret_cast<HealthCheckServiceImpl*>(arg);
+  void* tag;
+  bool ok;
+  while (true) {
+    if (!service->cq_->Next(&tag, &ok)) {
+      // The completion queue is shutting down.
+      GPR_ASSERT(service->shutdown_);
+      break;
+    }
+    auto* next_step = static_cast<CallableTag*>(tag);
+    next_step->Run(ok);
+  }
+}
+
+bool DefaultHealthCheckService::HealthCheckServiceImpl::DecodeRequest(
+    const ByteBuffer& request, grpc::string* service_name) {
+  std::vector<Slice> slices;
+  if (!request.Dump(&slices).ok()) return false;
   uint8_t* request_bytes = nullptr;
-  bool request_bytes_owned = false;
   size_t request_size = 0;
   grpc_health_v1_HealthCheckRequest request_struct;
-  if (slices.empty()) {
-    request_struct.has_service = false;
-  } else if (slices.size() == 1) {
+  request_struct.has_service = false;
+  if (slices.size() == 1) {
     request_bytes = const_cast<uint8_t*>(slices[0].begin());
     request_size = slices[0].size();
-  } else {
-    request_bytes_owned = true;
-    request_bytes = static_cast<uint8_t*>(gpr_malloc(request->Length()));
+  } else if (slices.size() > 1) {
+    request_bytes = static_cast<uint8_t*>(gpr_malloc(request.Length()));
     uint8_t* copy_to = request_bytes;
     for (size_t i = 0; i < slices.size(); i++) {
       memcpy(copy_to, slices[i].begin(), slices[i].size());
       copy_to += slices[i].size();
     }
   }
-
-  if (request_bytes != nullptr) {
-    pb_istream_t istream = pb_istream_from_buffer(request_bytes, request_size);
-    bool decode_status = pb_decode(
-        &istream, grpc_health_v1_HealthCheckRequest_fields, &request_struct);
-    if (request_bytes_owned) {
-      gpr_free(request_bytes);
-    }
-    if (!decode_status) {
-      return Status(StatusCode::INVALID_ARGUMENT, "");
-    }
-  }
-
-  // Check status from the associated default health checking service.
-  DefaultHealthCheckService::ServingStatus serving_status =
-      service_->GetServingStatus(
-          request_struct.has_service ? request_struct.service : "");
-  if (serving_status == DefaultHealthCheckService::NOT_FOUND) {
-    return Status(StatusCode::NOT_FOUND, "");
+  pb_istream_t istream = pb_istream_from_buffer(request_bytes, request_size);
+  bool decode_status = pb_decode(
+      &istream, grpc_health_v1_HealthCheckRequest_fields, &request_struct);
+  if (slices.size() > 1) {
+    gpr_free(request_bytes);
   }
+  if (!decode_status) return false;
+  *service_name = request_struct.has_service ? request_struct.service : "";
+  return true;
+}
 
-  // Encode response
+bool DefaultHealthCheckService::HealthCheckServiceImpl::EncodeResponse(
+    ServingStatus status, ByteBuffer* response) {
   grpc_health_v1_HealthCheckResponse response_struct;
   response_struct.has_status = true;
   response_struct.status =
-      serving_status == DefaultHealthCheckService::SERVING
-          ? grpc_health_v1_HealthCheckResponse_ServingStatus_SERVING
-          : grpc_health_v1_HealthCheckResponse_ServingStatus_NOT_SERVING;
+      status == NOT_FOUND
+          ? grpc_health_v1_HealthCheckResponse_ServingStatus_SERVICE_UNKNOWN
+          : status == SERVING
+                ? grpc_health_v1_HealthCheckResponse_ServingStatus_SERVING
+                : grpc_health_v1_HealthCheckResponse_ServingStatus_NOT_SERVING;
   pb_ostream_t ostream;
   memset(&ostream, 0, sizeof(ostream));
   pb_encode(&ostream, grpc_health_v1_HealthCheckResponse_fields,
@@ -108,48 +228,253 @@ Status DefaultHealthCheckService::HealthCheckServiceImpl::Check(
                                GRPC_SLICE_LENGTH(response_slice));
   bool encode_status = pb_encode(
       &ostream, grpc_health_v1_HealthCheckResponse_fields, &response_struct);
-  if (!encode_status) {
-    return Status(StatusCode::INTERNAL, "Failed to encode response.");
-  }
+  if (!encode_status) return false;
   Slice encoded_response(response_slice, Slice::STEAL_REF);
   ByteBuffer response_buffer(&encoded_response, 1);
   response->Swap(&response_buffer);
-  return Status::OK;
+  return true;
 }
 
-DefaultHealthCheckService::DefaultHealthCheckService() {
-  services_map_.emplace("", true);
+//
+// DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler
+//
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler::
+    CreateAndStart(ServerCompletionQueue* cq,
+                   DefaultHealthCheckService* database,
+                   HealthCheckServiceImpl* service) {
+  std::shared_ptr<CallHandler> self =
+      std::make_shared<CheckCallHandler>(cq, database, service);
+  CheckCallHandler* handler = static_cast<CheckCallHandler*>(self.get());
+  {
+    std::unique_lock<std::mutex> lock(service->cq_shutdown_mu_);
+    if (service->shutdown_) return;
+    // Request a Check() call.
+    handler->next_ =
+        CallableTag(std::bind(&CheckCallHandler::OnCallReceived, handler,
+                              std::placeholders::_1, std::placeholders::_2),
+                    std::move(self));
+    service->RequestAsyncUnary(0, &handler->ctx_, &handler->request_,
+                               &handler->writer_, cq, cq, &handler->next_);
+  }
 }
 
-void DefaultHealthCheckService::SetServingStatus(
-    const grpc::string& service_name, bool serving) {
-  std::lock_guard<std::mutex> lock(mu_);
-  services_map_[service_name] = serving;
+DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler::
+    CheckCallHandler(ServerCompletionQueue* cq,
+                     DefaultHealthCheckService* database,
+                     HealthCheckServiceImpl* service)
+    : cq_(cq), database_(database), service_(service), writer_(&ctx_) {}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler::
+    OnCallReceived(std::shared_ptr<CallHandler> self, bool ok) {
+  if (!ok) {
+    // The value of ok being false means that the server is shutting down.
+    return;
+  }
+  // Spawn a new handler instance to serve the next new client. Every handler
+  // instance will deallocate itself when it's done.
+  CreateAndStart(cq_, database_, service_);
+  // Process request.
+  gpr_log(GPR_DEBUG, "[HCS %p] Health check started for handler %p", service_,
+          this);
+  grpc::string service_name;
+  grpc::Status status = Status::OK;
+  ByteBuffer response;
+  if (!service_->DecodeRequest(request_, &service_name)) {
+    status = Status(StatusCode::INVALID_ARGUMENT, "could not parse request");
+  } else {
+    ServingStatus serving_status = database_->GetServingStatus(service_name);
+    if (serving_status == NOT_FOUND) {
+      status = Status(StatusCode::NOT_FOUND, "service name unknown");
+    } else if (!service_->EncodeResponse(serving_status, &response)) {
+      status = Status(StatusCode::INTERNAL, "could not encode response");
+    }
+  }
+  // Send response.
+  {
+    std::unique_lock<std::mutex> lock(service_->cq_shutdown_mu_);
+    if (!service_->shutdown_) {
+      next_ =
+          CallableTag(std::bind(&CheckCallHandler::OnFinishDone, this,
+                                std::placeholders::_1, std::placeholders::_2),
+                      std::move(self));
+      if (status.ok()) {
+        writer_.Finish(response, status, &next_);
+      } else {
+        writer_.FinishWithError(status, &next_);
+      }
+    }
+  }
 }
 
-void DefaultHealthCheckService::SetServingStatus(bool serving) {
-  std::lock_guard<std::mutex> lock(mu_);
-  for (auto iter = services_map_.begin(); iter != services_map_.end(); ++iter) {
-    iter->second = serving;
+void DefaultHealthCheckService::HealthCheckServiceImpl::CheckCallHandler::
+    OnFinishDone(std::shared_ptr<CallHandler> self, bool ok) {
+  if (ok) {
+    gpr_log(GPR_DEBUG, "[HCS %p] Health check call finished for handler %p",
+            service_, this);
   }
+  self.reset();  // To appease clang-tidy.
 }
 
-DefaultHealthCheckService::ServingStatus
-DefaultHealthCheckService::GetServingStatus(
-    const grpc::string& service_name) const {
-  std::lock_guard<std::mutex> lock(mu_);
-  const auto& iter = services_map_.find(service_name);
-  if (iter == services_map_.end()) {
-    return NOT_FOUND;
+//
+// DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler
+//
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    CreateAndStart(ServerCompletionQueue* cq,
+                   DefaultHealthCheckService* database,
+                   HealthCheckServiceImpl* service) {
+  std::shared_ptr<CallHandler> self =
+      std::make_shared<WatchCallHandler>(cq, database, service);
+  WatchCallHandler* handler = static_cast<WatchCallHandler*>(self.get());
+  {
+    std::unique_lock<std::mutex> lock(service->cq_shutdown_mu_);
+    if (service->shutdown_) return;
+    // Request AsyncNotifyWhenDone().
+    handler->on_done_notified_ =
+        CallableTag(std::bind(&WatchCallHandler::OnDoneNotified, handler,
+                              std::placeholders::_1, std::placeholders::_2),
+                    self /* copies ref */);
+    handler->ctx_.AsyncNotifyWhenDone(&handler->on_done_notified_);
+    // Request a Watch() call.
+    handler->next_ =
+        CallableTag(std::bind(&WatchCallHandler::OnCallReceived, handler,
+                              std::placeholders::_1, std::placeholders::_2),
+                    std::move(self));
+    service->RequestAsyncServerStreaming(1, &handler->ctx_, &handler->request_,
+                                         &handler->stream_, cq, cq,
+                                         &handler->next_);
   }
-  return iter->second ? SERVING : NOT_SERVING;
 }
 
-DefaultHealthCheckService::HealthCheckServiceImpl*
-DefaultHealthCheckService::GetHealthCheckService() {
-  GPR_ASSERT(impl_ == nullptr);
-  impl_.reset(new HealthCheckServiceImpl(this));
-  return impl_.get();
+DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    WatchCallHandler(ServerCompletionQueue* cq,
+                     DefaultHealthCheckService* database,
+                     HealthCheckServiceImpl* service)
+    : cq_(cq), database_(database), service_(service), stream_(&ctx_) {}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    OnCallReceived(std::shared_ptr<CallHandler> self, bool ok) {
+  if (!ok) {
+    // Server shutting down.
+    //
+    // AsyncNotifyWhenDone() needs to be called before the call starts, but the
+    // tag will not pop out if the call never starts (
+    // https://github.com/grpc/grpc/issues/10136). So we need to manually
+    // release the ownership of the handler in this case.
+    GPR_ASSERT(on_done_notified_.ReleaseHandler() != nullptr);
+    return;
+  }
+  // Spawn a new handler instance to serve the next new client. Every handler
+  // instance will deallocate itself when it's done.
+  CreateAndStart(cq_, database_, service_);
+  // Parse request.
+  if (!service_->DecodeRequest(request_, &service_name_)) {
+    SendFinish(std::move(self),
+               Status(StatusCode::INVALID_ARGUMENT, "could not parse request"));
+    return;
+  }
+  // Register the call for updates to the service.
+  gpr_log(GPR_DEBUG,
+          "[HCS %p] Health watch started for service \"%s\" (handler: %p)",
+          service_, service_name_.c_str(), this);
+  database_->RegisterCallHandler(service_name_, std::move(self));
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    SendHealth(std::shared_ptr<CallHandler> self, ServingStatus status) {
+  std::unique_lock<std::mutex> lock(send_mu_);
+  // If there's already a send in flight, cache the new status, and
+  // we'll start a new send for it when the one in flight completes.
+  if (send_in_flight_) {
+    pending_status_ = status;
+    return;
+  }
+  // Start a send.
+  SendHealthLocked(std::move(self), status);
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    SendHealthLocked(std::shared_ptr<CallHandler> self, ServingStatus status) {
+  send_in_flight_ = true;
+  // Construct response.
+  ByteBuffer response;
+  bool success = service_->EncodeResponse(status, &response);
+  // Grab shutdown lock and send response.
+  std::unique_lock<std::mutex> cq_lock(service_->cq_shutdown_mu_);
+  if (service_->shutdown_) {
+    SendFinishLocked(std::move(self), Status::CANCELLED);
+    return;
+  }
+  if (!success) {
+    SendFinishLocked(std::move(self),
+                     Status(StatusCode::INTERNAL, "could not encode response"));
+    return;
+  }
+  next_ = CallableTag(std::bind(&WatchCallHandler::OnSendHealthDone, this,
+                                std::placeholders::_1, std::placeholders::_2),
+                      std::move(self));
+  stream_.Write(response, &next_);
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    OnSendHealthDone(std::shared_ptr<CallHandler> self, bool ok) {
+  if (!ok) {
+    SendFinish(std::move(self), Status::CANCELLED);
+    return;
+  }
+  std::unique_lock<std::mutex> lock(send_mu_);
+  send_in_flight_ = false;
+  // If we got a new status since we started the last send, start a
+  // new send for it.
+  if (pending_status_ != NOT_FOUND) {
+    auto status = pending_status_;
+    pending_status_ = NOT_FOUND;
+    SendHealthLocked(std::move(self), status);
+  }
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    SendFinish(std::shared_ptr<CallHandler> self, const Status& status) {
+  if (finish_called_) return;
+  std::unique_lock<std::mutex> cq_lock(service_->cq_shutdown_mu_);
+  if (service_->shutdown_) return;
+  SendFinishLocked(std::move(self), status);
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    SendFinishLocked(std::shared_ptr<CallHandler> self, const Status& status) {
+  on_finish_done_ =
+      CallableTag(std::bind(&WatchCallHandler::OnFinishDone, this,
+                            std::placeholders::_1, std::placeholders::_2),
+                  std::move(self));
+  stream_.Finish(status, &on_finish_done_);
+  finish_called_ = true;
+}
+
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    OnFinishDone(std::shared_ptr<CallHandler> self, bool ok) {
+  if (ok) {
+    gpr_log(GPR_DEBUG,
+            "[HCS %p] Health watch call finished (service_name: \"%s\", "
+            "handler: %p).",
+            service_, service_name_.c_str(), this);
+  }
+  self.reset();  // To appease clang-tidy.
+}
+
+// TODO(roth): This method currently assumes that there will be only one
+// thread polling the cq and invoking the corresponding callbacks. If
+// that changes, we will need to add synchronization here.
+void DefaultHealthCheckService::HealthCheckServiceImpl::WatchCallHandler::
+    OnDoneNotified(std::shared_ptr<CallHandler> self, bool ok) {
+  GPR_ASSERT(ok);
+  gpr_log(GPR_DEBUG,
+          "[HCS %p] Health watch call is notified done (handler: %p, "
+          "is_cancelled: %d).",
+          service_, this, static_cast<int>(ctx_.IsCancelled()));
+  database_->UnregisterCallHandler(service_name_, self);
+  SendFinish(std::move(self), Status::CANCELLED);
 }
 
 }  // namespace grpc
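The Watch support added above streams a HealthCheckResponse each time the watched service's status changes. A client-side usage sketch, assuming the stubs generated from health.proto (health.grpc.pb.h) are available; this is illustrative, not code from the patch:

    #include <grpcpp/grpcpp.h>
    #include "src/proto/grpc/health/v1/health.grpc.pb.h"

    using grpc::health::v1::Health;
    using grpc::health::v1::HealthCheckRequest;
    using grpc::health::v1::HealthCheckResponse;

    void WatchHealth(const std::shared_ptr<grpc::Channel>& channel,
                     const std::string& service_name) {
      auto stub = Health::NewStub(channel);
      grpc::ClientContext ctx;
      HealthCheckRequest req;
      req.set_service(service_name);
      std::unique_ptr<grpc::ClientReader<HealthCheckResponse>> reader(
          stub->Watch(&ctx, req));
      HealthCheckResponse resp;
      // The first message reports the current status; later messages report
      // changes, delivered through WatchCallHandler::SendHealth() above.
      while (reader->Read(&resp)) {
        // React to resp.status() here.
      }
      grpc::Status status = reader->Finish();
    }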
diff --git a/src/cpp/server/health/default_health_check_service.h b/src/cpp/server/health/default_health_check_service.h
index a1ce5aa64e..450bd543f5 100644
--- a/src/cpp/server/health/default_health_check_service.h
+++ b/src/cpp/server/health/default_health_check_service.h
@@ -19,42 +19,260 @@
 #ifndef GRPC_INTERNAL_CPP_SERVER_DEFAULT_HEALTH_CHECK_SERVICE_H
 #define GRPC_INTERNAL_CPP_SERVER_DEFAULT_HEALTH_CHECK_SERVICE_H
 
+#include <atomic>
 #include <mutex>
+#include <set>
 
+#include <grpc/support/log.h>
+#include <grpcpp/grpcpp.h>
 #include <grpcpp/health_check_service_interface.h>
+#include <grpcpp/impl/codegen/async_generic_service.h>
+#include <grpcpp/impl/codegen/async_unary_call.h>
 #include <grpcpp/impl/codegen/service_type.h>
 #include <grpcpp/support/byte_buffer.h>
 
+#include "src/core/lib/gprpp/thd.h"
+
 namespace grpc {
 
 // Default implementation of HealthCheckServiceInterface. Server will create and
 // own it.
 class DefaultHealthCheckService final : public HealthCheckServiceInterface {
  public:
+  enum ServingStatus { NOT_FOUND, SERVING, NOT_SERVING };
+
   // The service impl to register with the server.
   class HealthCheckServiceImpl : public Service {
    public:
-    explicit HealthCheckServiceImpl(DefaultHealthCheckService* service);
+    // Base class for call handlers.
+    class CallHandler {
+     public:
+      virtual ~CallHandler() = default;
+      virtual void SendHealth(std::shared_ptr<CallHandler> self,
+                              ServingStatus status) = 0;
+    };
 
-    Status Check(ServerContext* context, const ByteBuffer* request,
-                 ByteBuffer* response);
+    HealthCheckServiceImpl(DefaultHealthCheckService* database,
+                           std::unique_ptr<ServerCompletionQueue> cq);
+
+    ~HealthCheckServiceImpl();
+
+    void StartServingThread();
 
    private:
-    const DefaultHealthCheckService* const service_;
-    internal::RpcServiceMethod* method_;
+    // A tag that can be called with a bool argument. It's tailored for
+    // CallHandler's use. Before being used, it should be constructed with a
+    // method of CallHandler and a shared pointer to the handler. The
+    // shared pointer will be moved to the invoked function and the function
+    // can only be invoked once. That makes ref counting of the handler easier,
+    // because the shared pointer is not bound to the function and can be gone
+    // once the invoked function returns (if not used any more).
+    class CallableTag {
+     public:
+      using HandlerFunction =
+          std::function<void(std::shared_ptr<CallHandler>, bool)>;
+
+      CallableTag() {}
+
+      CallableTag(HandlerFunction func, std::shared_ptr<CallHandler> handler)
+          : handler_function_(std::move(func)), handler_(std::move(handler)) {
+        GPR_ASSERT(handler_function_ != nullptr);
+        GPR_ASSERT(handler_ != nullptr);
+      }
+
+      // Runs the tag. This should be called only once. The handler is no
+      // longer owned by this tag after this method is invoked.
+      void Run(bool ok) {
+        GPR_ASSERT(handler_function_ != nullptr);
+        GPR_ASSERT(handler_ != nullptr);
+        handler_function_(std::move(handler_), ok);
+      }
+
+      // Releases and returns the shared pointer to the handler.
+      std::shared_ptr<CallHandler> ReleaseHandler() {
+        return std::move(handler_);
+      }
+
+     private:
+      HandlerFunction handler_function_ = nullptr;
+      std::shared_ptr<CallHandler> handler_;
+    };
+
+    // Call handler for Check method.
+    // Each handler takes care of one call. It contains per-call data and it
+    // will access the members of the parent class (i.e.,
+    // DefaultHealthCheckService) for per-service health data.
+    class CheckCallHandler : public CallHandler {
+     public:
+      // Instantiates a CheckCallHandler and requests the next health check
+      // call. The handler object will manage its own lifetime, so no action is
+      // needed from the caller any more regarding that object.
+      static void CreateAndStart(ServerCompletionQueue* cq,
+                                 DefaultHealthCheckService* database,
+                                 HealthCheckServiceImpl* service);
+
+      // This ctor is public because we want to use std::make_shared<> in
+      // CreateAndStart(). This ctor shouldn't be used elsewhere.
+      CheckCallHandler(ServerCompletionQueue* cq,
+                       DefaultHealthCheckService* database,
+                       HealthCheckServiceImpl* service);
+
+      // Not used for Check.
+      void SendHealth(std::shared_ptr<CallHandler> self,
+                      ServingStatus status) override {}
+
+     private:
+      // Called when we receive a call.
+      // Spawns a new handler so that we can keep servicing future calls.
+      void OnCallReceived(std::shared_ptr<CallHandler> self, bool ok);
+
+      // Called when Finish() is done.
+      void OnFinishDone(std::shared_ptr<CallHandler> self, bool ok);
+
+      // The members passed down from HealthCheckServiceImpl.
+      ServerCompletionQueue* cq_;
+      DefaultHealthCheckService* database_;
+      HealthCheckServiceImpl* service_;
+
+      ByteBuffer request_;
+      GenericServerAsyncResponseWriter writer_;
+      ServerContext ctx_;
+
+      CallableTag next_;
+    };
+
+    // Call handler for Watch method.
+    // Each handler takes care of one call. It contains per-call data and it
+    // will access the members of the parent class (i.e.,
+    // DefaultHealthCheckService) for per-service health data.
+    class WatchCallHandler : public CallHandler {
+     public:
+      // Instantiates a WatchCallHandler and requests the next health check
+      // call. The handler object will manage its own lifetime, so no action is
+      // needed from the caller any more regarding that object.
+      static void CreateAndStart(ServerCompletionQueue* cq,
+                                 DefaultHealthCheckService* database,
+                                 HealthCheckServiceImpl* service);
+
+      // This ctor is public because we want to use std::make_shared<> in
+      // CreateAndStart(). This ctor shouldn't be used elsewhere.
+      WatchCallHandler(ServerCompletionQueue* cq,
+                       DefaultHealthCheckService* database,
+                       HealthCheckServiceImpl* service);
+
+      void SendHealth(std::shared_ptr<CallHandler> self,
+                      ServingStatus status) override;
+
+     private:
+      // Called when we receive a call.
+      // Spawns a new handler so that we can keep servicing future calls.
+      void OnCallReceived(std::shared_ptr<CallHandler> self, bool ok);
+
+      // Requires holding send_mu_.
+      void SendHealthLocked(std::shared_ptr<CallHandler> self,
+                            ServingStatus status);
+
+      // When sending a health result finishes.
+      void OnSendHealthDone(std::shared_ptr<CallHandler> self, bool ok);
+
+      void SendFinish(std::shared_ptr<CallHandler> self, const Status& status);
+
+      // Requires holding service_->cq_shutdown_mu_.
+      void SendFinishLocked(std::shared_ptr<CallHandler> self,
+                            const Status& status);
+
+      // Called when Finish() is done.
+      void OnFinishDone(std::shared_ptr<CallHandler> self, bool ok);
+
+      // Called when AsyncNotifyWhenDone() notifies us.
+      void OnDoneNotified(std::shared_ptr<CallHandler> self, bool ok);
+
+      // The members passed down from HealthCheckServiceImpl.
+      ServerCompletionQueue* cq_;
+      DefaultHealthCheckService* database_;
+      HealthCheckServiceImpl* service_;
+
+      ByteBuffer request_;
+      grpc::string service_name_;
+      GenericServerAsyncWriter stream_;
+      ServerContext ctx_;
+
+      std::mutex send_mu_;
+      bool send_in_flight_ = false;               // Guarded by send_mu_.
+      ServingStatus pending_status_ = NOT_FOUND;  // Guarded by send_mu_.
+
+      bool finish_called_ = false;
+      CallableTag next_;
+      CallableTag on_done_notified_;
+      CallableTag on_finish_done_;
+    };
+
+    // Handles the incoming requests and drives the completion queue in a loop.
+    static void Serve(void* arg);
+
+    // Returns true on success.
+    static bool DecodeRequest(const ByteBuffer& request,
+                              grpc::string* service_name);
+    static bool EncodeResponse(ServingStatus status, ByteBuffer* response);
+
+    // Needed to appease Windows compilers, which don't seem to allow
+    // nested classes to access protected members in the parent's
+    // superclass.
+    using Service::RequestAsyncServerStreaming;
+    using Service::RequestAsyncUnary;
+
+    DefaultHealthCheckService* database_;
+    std::unique_ptr<ServerCompletionQueue> cq_;
+
+    // To synchronize the operations related to shutdown state of cq_, so that
+    // we don't enqueue new tags into cq_ after it is already shut down.
+    std::mutex cq_shutdown_mu_;
+    std::atomic_bool shutdown_{false};
+    std::unique_ptr<::grpc_core::Thread> thread_;
   };
 
   DefaultHealthCheckService();
+
   void SetServingStatus(const grpc::string& service_name,
                         bool serving) override;
   void SetServingStatus(bool serving) override;
-  enum ServingStatus { NOT_FOUND, SERVING, NOT_SERVING };
+
   ServingStatus GetServingStatus(const grpc::string& service_name) const;
-  HealthCheckServiceImpl* GetHealthCheckService();
+
+  HealthCheckServiceImpl* GetHealthCheckService(
+      std::unique_ptr<ServerCompletionQueue> cq);
 
  private:
+  // Stores the current serving status of a service and any call
+  // handlers registered for updates when the service's status changes.
+  class ServiceData {
+   public:
+    void SetServingStatus(ServingStatus status);
+    ServingStatus GetServingStatus() const { return status_; }
+    void AddCallHandler(
+        std::shared_ptr<HealthCheckServiceImpl::CallHandler> handler);
+    void RemoveCallHandler(
+        const std::shared_ptr<HealthCheckServiceImpl::CallHandler>& handler);
+    bool Unused() const {
+      return call_handlers_.empty() && status_ == NOT_FOUND;
+    }
+
+   private:
+    ServingStatus status_ = NOT_FOUND;
+    std::set<std::shared_ptr<HealthCheckServiceImpl::CallHandler>>
+        call_handlers_;
+  };
+
+  void RegisterCallHandler(
+      const grpc::string& service_name,
+      std::shared_ptr<HealthCheckServiceImpl::CallHandler> handler);
+
+  void UnregisterCallHandler(
+      const grpc::string& service_name,
+      const std::shared_ptr<HealthCheckServiceImpl::CallHandler>& handler);
+
   mutable std::mutex mu_;
-  std::map<grpc::string, bool> services_map_;
+  std::map<grpc::string, ServiceData> services_map_;  // Guarded by mu_.
   std::unique_ptr<HealthCheckServiceImpl> impl_;
 };
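The CallableTag described in the header above is the glue between the raw void* tags coming off the completion queue and the shared_ptr-managed handlers. A stripped-down sketch of the same pattern outside this class, assuming a plain grpc::CompletionQueue loop; the names here are illustrative:

    #include <functional>
    #include <memory>
    #include <grpcpp/grpcpp.h>

    struct Handler { /* per-call state, kept alive by shared_ptr */ };

    // One-shot tag: holds a ref to the handler and the function to invoke.
    struct Tag {
      std::function<void(std::shared_ptr<Handler>, bool)> fn;
      std::shared_ptr<Handler> handler;
      void Run(bool ok) { fn(std::move(handler), ok); }  // Moves the ref out.
    };

    void Loop(grpc::CompletionQueue* cq) {
      void* raw_tag;
      bool ok;
      while (cq->Next(&raw_tag, &ok)) {
        static_cast<Tag*>(raw_tag)->Run(ok);  // Handler freed with last ref.
      }
    }

Because the shared_ptr travels through the tag rather than being captured inside the bound function, the handler's refcount drops as soon as the callback returns, which is exactly the property the comment block above calls out.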
diff --git a/src/cpp/server/health/health.pb.c b/src/cpp/server/health/health.pb.c
deleted file mode 100644
index 09bd98a3d9..0000000000
--- a/src/cpp/server/health/health.pb.c
+++ /dev/null
@@ -1,24 +0,0 @@
-/* Automatically generated nanopb constant definitions */
-/* Generated by nanopb-0.3.7-dev */
-
-#include "src/cpp/server/health/health.pb.h"
-
-/* @@protoc_insertion_point(includes) */
-#if PB_PROTO_HEADER_VERSION != 30
-#error Regenerate this file with the current version of nanopb generator.
-#endif
-
-
-
-const pb_field_t grpc_health_v1_HealthCheckRequest_fields[2] = {
-    PB_FIELD(  1, STRING  , OPTIONAL, STATIC  , FIRST, grpc_health_v1_HealthCheckRequest, service, service, 0),
-    PB_LAST_FIELD
-};
-
-const pb_field_t grpc_health_v1_HealthCheckResponse_fields[2] = {
-    PB_FIELD(  1, UENUM   , OPTIONAL, STATIC  , FIRST, grpc_health_v1_HealthCheckResponse, status, status, 0),
-    PB_LAST_FIELD
-};
-
-
-/* @@protoc_insertion_point(eof) */
diff --git a/src/cpp/server/health/health.pb.h b/src/cpp/server/health/health.pb.h
deleted file mode 100644
index 29e1f3bacb..0000000000
--- a/src/cpp/server/health/health.pb.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Automatically generated nanopb header */
-/* Generated by nanopb-0.3.7-dev */
-
-#ifndef PB_GRPC_HEALTH_V1_HEALTH_PB_H_INCLUDED
-#define PB_GRPC_HEALTH_V1_HEALTH_PB_H_INCLUDED
-#include "pb.h"
-/* @@protoc_insertion_point(includes) */
-#if PB_PROTO_HEADER_VERSION != 30
-#error Regenerate this file with the current version of nanopb generator.
-#endif
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/* Enum definitions */
-typedef enum _grpc_health_v1_HealthCheckResponse_ServingStatus {
-    grpc_health_v1_HealthCheckResponse_ServingStatus_UNKNOWN = 0,
-    grpc_health_v1_HealthCheckResponse_ServingStatus_SERVING = 1,
-    grpc_health_v1_HealthCheckResponse_ServingStatus_NOT_SERVING = 2
-} grpc_health_v1_HealthCheckResponse_ServingStatus;
-#define _grpc_health_v1_HealthCheckResponse_ServingStatus_MIN grpc_health_v1_HealthCheckResponse_ServingStatus_UNKNOWN
-#define _grpc_health_v1_HealthCheckResponse_ServingStatus_MAX grpc_health_v1_HealthCheckResponse_ServingStatus_NOT_SERVING
-#define _grpc_health_v1_HealthCheckResponse_ServingStatus_ARRAYSIZE ((grpc_health_v1_HealthCheckResponse_ServingStatus)(grpc_health_v1_HealthCheckResponse_ServingStatus_NOT_SERVING+1))
-
-/* Struct definitions */
-typedef struct _grpc_health_v1_HealthCheckRequest {
-    bool has_service;
-    char service[200];
-/* @@protoc_insertion_point(struct:grpc_health_v1_HealthCheckRequest) */
-} grpc_health_v1_HealthCheckRequest;
-
-typedef struct _grpc_health_v1_HealthCheckResponse {
-    bool has_status;
-    grpc_health_v1_HealthCheckResponse_ServingStatus status;
-/* @@protoc_insertion_point(struct:grpc_health_v1_HealthCheckResponse) */
-} grpc_health_v1_HealthCheckResponse;
-
-/* Default values for struct fields */
-
-/* Initializer values for message structs */
-#define grpc_health_v1_HealthCheckRequest_init_default {false, ""}
-#define grpc_health_v1_HealthCheckResponse_init_default {false, (grpc_health_v1_HealthCheckResponse_ServingStatus)0}
-#define grpc_health_v1_HealthCheckRequest_init_zero {false, ""}
-#define grpc_health_v1_HealthCheckResponse_init_zero {false, (grpc_health_v1_HealthCheckResponse_ServingStatus)0}
-
-/* Field tags (for use in manual encoding/decoding) */
-#define grpc_health_v1_HealthCheckRequest_service_tag 1
-#define grpc_health_v1_HealthCheckResponse_status_tag 1
-
-/* Struct field encoding specification for nanopb */
-extern const pb_field_t grpc_health_v1_HealthCheckRequest_fields[2];
-extern const pb_field_t grpc_health_v1_HealthCheckResponse_fields[2];
-
-/* Maximum encoded size of messages (where known) */
-#define grpc_health_v1_HealthCheckRequest_size 203
-#define grpc_health_v1_HealthCheckResponse_size 2
-
-/* Message IDs (where set with "msgid" option) */
-#ifdef PB_MSGID
-
-#define HEALTH_MESSAGES \
-
-
-#endif
-
-#ifdef __cplusplus
-} /* extern "C" */
-#endif
-/* @@protoc_insertion_point(eof) */
-
-#endif
diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc
index 536bf022dd..ebb17def32 100644
--- a/src/cpp/server/secure_server_credentials.cc
+++ b/src/cpp/server/secure_server_credentials.cc
@@ -43,9 +43,10 @@ void AuthMetadataProcessorAyncWrapper::Process(
     return;
   }
   if (w->processor_->IsBlocking()) {
-    w->thread_pool_->Add(
-        std::bind(&AuthMetadataProcessorAyncWrapper::InvokeProcessor, w,
-                  context, md, num_md, cb, user_data));
+    w->thread_pool_->Add([w, context, md, num_md, cb, user_data] {
+      w->AuthMetadataProcessorAyncWrapper::InvokeProcessor(context, md, num_md,
+                                                           cb, user_data);
+    });
   } else {
     // invoke directly.
     w->InvokeProcessor(context, md, num_md, cb, user_data);
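Both credential wrappers in this commit swap std::bind for a lambda when posting work to the thread pool. The two are equivalent here, but the lambda spells out the by-value captures and the call at the posting site instead of hiding them behind bound arguments; a minimal generic illustration (not gRPC code):

    #include <functional>

    struct Worker {
      void Do(int x, void* extra) { /* ... */ }
    };

    void Post(Worker* w, int x) {
      // std::bind form: the bound object and argument order are easy to misread.
      std::function<void()> f1 = std::bind(&Worker::Do, w, x, nullptr);
      // Lambda form: captures and the actual call are explicit.
      std::function<void()> f2 = [w, x] { w->Do(x, nullptr); };
    }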
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc
index 8417c45e64..0dc03b6876 100644
--- a/src/cpp/server/server_builder.cc
+++ b/src/cpp/server/server_builder.cc
@@ -71,7 +71,9 @@ ServerBuilder::~ServerBuilder() {
 std::unique_ptr<ServerCompletionQueue> ServerBuilder::AddCompletionQueue(
     bool is_frequently_polled) {
   ServerCompletionQueue* cq = new ServerCompletionQueue(
-      is_frequently_polled ? GRPC_CQ_DEFAULT_POLLING : GRPC_CQ_NON_LISTENING);
+      GRPC_CQ_NEXT,
+      is_frequently_polled ? GRPC_CQ_DEFAULT_POLLING : GRPC_CQ_NON_LISTENING,
+      nullptr);
   cqs_.push_back(cq);
   return std::unique_ptr<ServerCompletionQueue>(cq);
 }
@@ -256,14 +258,22 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
 
     // Create completion queues to listen to incoming rpc requests
     for (int i = 0; i < sync_server_settings_.num_cqs; i++) {
-      sync_server_cqs->emplace_back(new ServerCompletionQueue(polling_type));
+      sync_server_cqs->emplace_back(
+          new ServerCompletionQueue(GRPC_CQ_NEXT, polling_type, nullptr));
     }
   }
 
-  std::unique_ptr<Server> server(new Server(
-      max_receive_message_size_, &args, sync_server_cqs,
-      sync_server_settings_.min_pollers, sync_server_settings_.max_pollers,
-      sync_server_settings_.cq_timeout_msec, resource_quota_));
+  // == Determine if the server has any callback methods ==
+  bool has_callback_methods = false;
+  for (auto it = services_.begin(); it != services_.end(); ++it) {
+    if ((*it)->service->has_callback_methods()) {
+      has_callback_methods = true;
+      break;
+    }
+  }
+
+  // TODO(vjpai): Add a section here for plugins once they can support callback
+  // methods
 
   if (has_sync_methods) {
     // This is a Sync server
@@ -275,6 +285,16 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
                 sync_server_settings_.cq_timeout_msec);
   }
 
+  if (has_callback_methods) {
+    gpr_log(GPR_INFO, "Callback server.");
+  }
+
+  std::unique_ptr<Server> server(new Server(
+      max_receive_message_size_, &args, sync_server_cqs,
+      sync_server_settings_.min_pollers, sync_server_settings_.max_pollers,
+      sync_server_settings_.cq_timeout_msec, resource_quota_,
+      std::move(interceptor_creators_)));
+
   ServerInitializer* initializer = server->initializer();
 
   // Register all the completion queues with the server. i.e
@@ -288,6 +308,12 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() {
     num_frequently_polled_cqs++;
   }
 
+  if (has_callback_methods) {
+    auto* cq = server->CallbackCQ();
+    grpc_server_register_completion_queue(server->server_, cq->cq(), nullptr);
+    num_frequently_polled_cqs++;
+  }
+
   // cqs_ contains the completion queue added by calling the ServerBuilder's
   // AddCompletionQueue() API. Some of them may not be frequently polled (i.e by
   // calling Next() or AsyncNext()) and hence are not safe to be used for
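BuildAndStart() now forwards interceptor_creators_ into the Server. The server-side factory interface mirrors the client one; a sketch, assuming the experimental ServerBuilder registration API from this release (SetInterceptorCreators) and using hypothetical names:

    #include <grpcpp/grpcpp.h>
    #include <grpcpp/impl/codegen/server_interceptor.h>

    class AuditInterceptor : public grpc::experimental::Interceptor {
     public:
      void Intercept(grpc::experimental::InterceptorBatchMethods* methods) override {
        methods->Proceed();  // Pass-through interceptor.
      }
    };

    class AuditInterceptorFactory
        : public grpc::experimental::ServerInterceptorFactoryInterface {
     public:
      grpc::experimental::Interceptor* CreateServerInterceptor(
          grpc::experimental::ServerRpcInfo* info) override {
        return new AuditInterceptor;  // Owned by the ServerRpcInfo.
      }
    };

    void Configure(grpc::ServerBuilder* builder) {
      std::vector<
          std::unique_ptr<grpc::experimental::ServerInterceptorFactoryInterface>>
          creators;
      creators.push_back(
          std::unique_ptr<grpc::experimental::ServerInterceptorFactoryInterface>(
              new AuditInterceptorFactory));
      builder->experimental().SetInterceptorCreators(std::move(creators));
    }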
mrd->request_payload_ : nullptr), + request_(nullptr), method_(mrd->method_), - server_(server) { + call_(mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), + server->interceptor_creators_)), + server_(server), + global_callbacks_(nullptr), + resources_(false) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -230,38 +245,79 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Run(const std::shared_ptr<GlobalCallbacks>& global_callbacks, bool resources) { - ctx_.BeginCompletionOp(&call_); - global_callbacks->PreSynchronousRequest(&ctx_); - auto* handler = resources ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_payload_)); - global_callbacks->PostSynchronousRequest(&ctx_); - request_payload_ = nullptr; - - cq_.Shutdown(); + global_callbacks_ = global_callbacks; + resources_ = resources; + + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); + + if (has_request_payload_) { + // Set interception point for RECV MESSAGE + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + request_ = handler->Deserialize(call_.call(), request_payload_, + &request_status_); + + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); + } - internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); - cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + if (interceptor_methods_.RunInterceptors( + [this]() { ContinueRunAfterInterception(); })) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } - /* Ensure the cq_ is shutdown */ - DummyTag ignored_tag; - GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + void ContinueRunAfterInterception() { + { + ctx_.BeginCompletionOp(&call_, false); + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? 
method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( + &call_, &ctx_, request_, request_status_, nullptr)); + request_ = nullptr; + global_callbacks_->PostSynchronousRequest(&ctx_); + + cq_.Shutdown(); + + internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + /* Ensure the cq_ is shut down */ + DummyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + } + delete this; + } private: CompletionQueue cq_; - internal::Call call_; ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; internal::RpcServiceMethod* const method_; + internal::Call call_; Server* server_; + std::shared_ptr<GlobalCallbacks> global_callbacks_; + bool resources_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; private: internal::RpcServiceMethod* const method_; - void* const tag_; + void* const method_tag_; bool in_flight_; const bool has_request_payload_; grpc_call* call_; @@ -272,6 +328,176 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { grpc_completion_queue* cq_; }; +class Server::CallbackRequest final : public internal::CompletionQueueTag { + public: + CallbackRequest(Server* server, internal::RpcServiceMethod* method, + void* method_tag) + : server_(server), + method_(method), + method_tag_(method_tag), + has_request_payload_( + method->method_type() == internal::RpcMethod::NORMAL_RPC || + method->method_type() == internal::RpcMethod::SERVER_STREAMING), + cq_(server->CallbackCQ()), + tag_(this) { + Setup(); + } + + ~CallbackRequest() { Clear(); } + + void Request() { + if (method_tag_) { + if (GRPC_CALL_OK != + grpc_server_request_registered_call( + server_->c_server(), method_tag_, &call_, &deadline_, + &request_metadata_, + has_request_payload_ ? &request_payload_ : nullptr, cq_->cq(), + cq_->cq(), static_cast<void*>(&tag_))) { + return; + } + } else { + if (!call_details_) { + call_details_ = new grpc_call_details; + grpc_call_details_init(call_details_); + } + if (grpc_server_request_call(server_->c_server(), &call_, call_details_, + &request_metadata_, cq_->cq(), cq_->cq(), + static_cast<void*>(&tag_)) != GRPC_CALL_OK) { + return; + } + } + } + + bool FinalizeResult(void** tag, bool* status) override { return false; } + + private: + class CallbackCallTag : public grpc_experimental_completion_queue_functor { + public: + CallbackCallTag(Server::CallbackRequest* req) : req_(req) { + functor_run = &CallbackCallTag::StaticRun; + } + + // force_run cannot be performed on a tag if operations using this tag + // have been sent to PerformOpsOnCall. It is intended for error conditions + // that are detected before the operations are internally processed.
+ void force_run(bool ok) { Run(ok); } + + private: + Server::CallbackRequest* req_; + internal::Call* call_; + + static void StaticRun(grpc_experimental_completion_queue_functor* cb, + int ok) { + static_cast<CallbackCallTag*>(cb)->Run(static_cast<bool>(ok)); + } + void Run(bool ok) { + void* ignored = req_; + bool new_ok = ok; + GPR_ASSERT(!req_->FinalizeResult(&ignored, &new_ok)); + GPR_ASSERT(ignored == req_); + + if (!ok) { + // The call has been shut down. + req_->Clear(); + return; + } + + // Bind the call, deadline, and metadata from what we got + req_->ctx_.set_call(req_->call_); + req_->ctx_.cq_ = req_->cq_; + req_->ctx_.BindDeadlineAndMetadata(req_->deadline_, + &req_->request_metadata_); + req_->request_metadata_.count = 0; + + // Create a C++ Call to control the underlying core call + call_ = new (grpc_call_arena_alloc(req_->call_, sizeof(internal::Call))) + internal::Call( + req_->call_, req_->server_, req_->cq_, + req_->server_->max_receive_message_size(), + req_->ctx_.set_server_rpc_info( + req_->method_->name(), req_->server_->interceptor_creators_)); + + req_->interceptor_methods_.SetCall(call_); + req_->interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + req_->interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + req_->interceptor_methods_.SetRecvInitialMetadata( + &req_->ctx_.client_metadata_); + + if (req_->has_request_payload_) { + // Set interception point for RECV MESSAGE + req_->request_ = req_->method_->handler()->Deserialize( + req_->call_, req_->request_payload_, &req_->request_status_); + req_->request_payload_ = nullptr; + req_->interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + req_->interceptor_methods_.SetRecvMessage(req_->request_); + } + + if (req_->interceptor_methods_.RunInterceptors( + [this] { ContinueRunAfterInterception(); })) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } + void ContinueRunAfterInterception() { + req_->ctx_.BeginCompletionOp(call_, true); + req_->method_->handler()->RunHandler( + internal::MethodHandler::HandlerParameter( + call_, &req_->ctx_, req_->request_, req_->request_status_, + [this] { + req_->Reset(); + req_->Request(); + })); + } + }; + + void Reset() { + Clear(); + Setup(); + } + + void Clear() { + if (call_details_) { + delete call_details_; + call_details_ = nullptr; + } + grpc_metadata_array_destroy(&request_metadata_); + if (has_request_payload_ && request_payload_) { + grpc_byte_buffer_destroy(request_payload_); + } + ctx_.Clear(); + interceptor_methods_.ClearState(); + } + + void Setup() { + grpc_metadata_array_init(&request_metadata_); + ctx_.Setup(gpr_inf_future(GPR_CLOCK_REALTIME)); + request_payload_ = nullptr; + request_ = nullptr; + request_status_ = Status(); + } + + Server* const server_; + internal::RpcServiceMethod* const method_; + void* const method_tag_; + const bool has_request_payload_; + grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; + grpc_call_details* call_details_ = nullptr; + grpc_call* call_; + gpr_timespec deadline_; + grpc_metadata_array request_metadata_; + CompletionQueue* cq_; + CallbackCallTag tag_; + ServerContext ctx_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; +}; +
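CallbackCallTag above is built on the C core's experimental functor-based completion queues: the tag derives from grpc_experimental_completion_queue_functor and points functor_run at a static trampoline, so the core can "call" a C++ object through a plain C struct. A minimal sketch of that idiom; MyTag and the body of Run() are invented for illustration:

#include <grpc/grpc.h>

class MyTag : public grpc_experimental_completion_queue_functor {
 public:
  MyTag() { functor_run = &MyTag::StaticRun; }

 private:
  // The core invokes functor_run(self, ok) when the tagged batch completes;
  // the static trampoline downcasts back to the owning C++ object.
  static void StaticRun(grpc_experimental_completion_queue_functor* cb,
                        int ok) {
    static_cast<MyTag*>(cb)->Run(ok != 0);
  }
  void Run(bool ok) {
    // React to the completion here, e.g. dispatch to a user callback.
  }
};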
// Implementation of ThreadManager. Each instance of SyncRequestThreadManager // manages a pool of threads that poll for incoming Sync RPCs and call the // appropriate RPC handlers @@ -318,8 +544,9 @@ class Server::SyncRequestThreadManager : public ThreadManager { } if (ok) { - // Calldata takes ownership of the completion queue inside sync_req - SyncRequest::CallData cd(server_, sync_req); + // CallData takes ownership of the completion queue and interceptors + // inside sync_req. + auto* cd = new SyncRequest::CallData(server_, sync_req); // Prepare for the next request if (!IsShutdown()) { sync_req->SetupRequest(); // Create new completion queue for sync_req } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_, resources); + cd->Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -380,7 +607,6 @@ class Server::SyncRequestThreadManager : public ThreadManager { int cq_timeout_msec_; std::vector<std::unique_ptr<SyncRequest>> sync_requests_; std::unique_ptr<internal::RpcServiceMethod> unknown_method_; - std::unique_ptr<internal::RpcServiceMethod> health_check_; std::shared_ptr<Server::GlobalCallbacks> global_callbacks_; }; @@ -390,8 +616,12 @@ Server::Server( std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq) - : max_receive_message_size_(max_receive_message_size), + grpc_resource_quota* server_rq, + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators) + : interceptor_creators_(std::move(interceptor_creators)), + max_receive_message_size_(max_receive_message_size), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), shutdown_(false), @@ -447,6 +677,9 @@ Server::Server( Server::~Server() { { std::unique_lock<std::mutex> lock(mu_); + if (callback_cq_ != nullptr) { + callback_cq_->Shutdown(); + } if (started_ && !shutdown_) { lock.unlock(); Shutdown(); @@ -519,21 +752,28 @@ bool Server::RegisterService(const grpc::string* host, Service* service) { } internal::RpcServiceMethod* method = it->get(); - void* tag = grpc_server_register_method( + void* method_registration_tag = grpc_server_register_method( server_, method->name(), host ?
host->c_str() : nullptr, PayloadHandlingForMethod(method), 0); - if (tag == nullptr) { + if (method_registration_tag == nullptr) { gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", method->name()); return false; } - if (method->handler() == nullptr) { // Async method - method->set_server_tag(tag); - } else { + if (method->handler() == nullptr) { // Async method without handler + method->set_server_tag(method_registration_tag); + } else if (method->api_type() == + internal::RpcServiceMethod::ApiType::SYNC) { for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) { - (*it)->AddSyncMethod(method, tag); + (*it)->AddSyncMethod(method, method_registration_tag); } + } else { + // A callback method. + auto* req = new CallbackRequest(this, method, method_registration_tag); + callback_reqs_.emplace_back(req); + // Enqueue it so that it will be Request'ed later once + // all request matchers are created at core server startup. } method_name = method->name(); @@ -573,16 +813,25 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { // Only create the default health check service when the user did not provide an // explicit one. + ServerCompletionQueue* health_check_cq = nullptr; + DefaultHealthCheckService::HealthCheckServiceImpl* + default_health_check_service_impl = nullptr; if (health_check_service_ == nullptr && !health_check_service_disabled_ && DefaultHealthCheckServiceEnabled()) { - if (sync_server_cqs_ == nullptr || sync_server_cqs_->empty()) { - gpr_log(GPR_INFO, - "Default health check service disabled at async-only server."); - } else { - auto* default_hc_service = new DefaultHealthCheckService; - health_check_service_.reset(default_hc_service); - RegisterService(nullptr, default_hc_service->GetHealthCheckService()); - } + auto* default_hc_service = new DefaultHealthCheckService; + health_check_service_.reset(default_hc_service); + // We create a non-polling CQ to avoid impacting application + // performance. This ensures that we don't introduce thread hops + // for application requests that wind up on this CQ, which is polled + // in its own thread.
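With this change, Start() always wires up the default health check service on the dedicated non-polling queue described above, where previously it was silently disabled for async-only servers. For context, a sketch of how an application opts in through the public API; the port and the use of insecure credentials are placeholders:

#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>

int main() {
  // Must be set before the server is built; this is what
  // DefaultHealthCheckServiceEnabled() consults in Start().
  grpc::EnableDefaultHealthCheckService(true);
  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051", grpc::InsecureServerCredentials());
  auto server = builder.BuildAndStart();
  // /grpc.health.v1.Health is now served off the internal non-polling CQ.
  server->Wait();
  return 0;
}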
+ health_check_cq = + new ServerCompletionQueue(GRPC_CQ_NEXT, GRPC_CQ_NON_POLLING, nullptr); + grpc_server_register_completion_queue(server_, health_check_cq->cq(), + nullptr); + default_health_check_service_impl = + default_hc_service->GetHealthCheckService( + std::unique_ptr<ServerCompletionQueue>(health_check_cq)); + RegisterService(nullptr, default_health_check_service_impl); } grpc_server_start(server_); @@ -597,6 +846,9 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { new UnimplementedAsyncRequest(this, cqs[i]); } } + if (health_check_cq != nullptr) { + new UnimplementedAsyncRequest(this, health_check_cq); + } } // If this server has any support for synchronous methods (has any sync @@ -609,6 +861,14 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) { (*it)->Start(); } + + for (auto& cbreq : callback_reqs_) { + cbreq->Request(); + } + + if (default_health_check_service_impl != nullptr) { + default_health_check_service_impl->StartServingThread(); + } } void Server::ShutdownInternal(gpr_timespec deadline) { @@ -667,31 +927,27 @@ void Server::Wait() { void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - // TODO(vjpai): Use ops->cq_tag once this case supports callbacks - auto result = grpc_call_start_batch(call->call(), cops, nops, ops, nullptr); - if (result != GRPC_CALL_OK) { - gpr_log(GPR_ERROR, "Fatal: grpc_call_start_batch returned %d", result); - grpc_call_log_batch(__FILE__, __LINE__, GPR_LOG_SEVERITY_ERROR, - call->call(), cops, nops, ops); - abort(); - } + ops->FillOps(call); } ServerInterface::BaseAsyncRequest::BaseAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag, bool delete_on_finalize) + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) : server_(server), context_(context), stream_(stream), call_cq_(call_cq), + notification_cq_(notification_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), - call_(nullptr) { + call_(nullptr), + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops } @@ -701,15 +957,43 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (done_intercepting_) { + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } context_->set_call(call_); context_->cq_ = call_cq_; - internal::Call call(call_, server_, call_cq_, - server_->max_receive_message_size()); - if (*status && call_) { - context_->BeginCompletionOp(&call); + if (call_wrapper_.call() == nullptr) { + // Fill it since it is empty. 
+ call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); } + // Just the pointers inside call are copied here. - stream_->BindCall(&call); + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + if (interceptor_methods_.RunInterceptors( + [this]() { ContinueFinalizeResultAfterInterception(); })) { + // There are no interceptors to run. Continue. + } else { + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. + return false; + } + } + if (*status && call_) { + context_->BeginCompletionOp(&call_wrapper_, false); + } *tag = tag_; if (delete_on_finalize_) { delete this; @@ -717,11 +1001,25 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, return true; } +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_, false); + // Queue a tag which will be returned immediately. + grpc_core::ExecCtx exec_ctx; + grpc_cq_begin_op(notification_cq_->cq(), this); + grpc_cq_end_op( + notification_cq_->cq(), this, GRPC_ERROR_NONE, + [](void* arg, grpc_cq_completion* completion) { delete completion; }, + nullptr, new grpc_cq_completion()); +} + ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerCompletionQueue* notification_cq, void* tag, const char* name) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -737,7 +1035,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( ServerInterface* server, GenericServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) - : BaseAsyncRequest(server, context, stream, call_cq, tag, + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, delete_on_finalize) { grpc_call_details_init(&call_details_); GPR_ASSERT(notification_cq);
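The POST_RECV_* hook points threaded through these request paths are what server-side interceptors observe. A sketch of the application-facing half, assuming the experimental interceptor API introduced alongside this change; LoggingInterceptor and its factory are invented names, and factories would be handed to the ServerBuilder so that interceptor_creators_ above can build one interceptor chain per RPC:

#include <grpcpp/impl/codegen/interceptor.h>
#include <grpcpp/impl/codegen/server_interceptor.h>

class LoggingInterceptor : public grpc::experimental::Interceptor {
 public:
  void Intercept(grpc::experimental::InterceptorBatchMethods* methods) override {
    if (methods->QueryInterceptionHookPoint(
            grpc::experimental::InterceptionHookPoints::
                POST_RECV_INITIAL_METADATA)) {
      // Client metadata is available via methods->GetRecvInitialMetadata().
    }
    methods->Proceed();  // Resume the intercepted batch.
  }
};

class LoggingInterceptorFactory
    : public grpc::experimental::ServerInterceptorFactoryInterface {
 public:
  grpc::experimental::Interceptor* CreateServerInterceptor(
      grpc::experimental::ServerRpcInfo* info) override {
    return new LoggingInterceptor;  // One instance per RPC.
  }
};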
@@ -750,6 +1048,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + // If we are done intercepting, there is nothing more for us to do. + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } // TODO(yangg) remove the copy here. if (*status) { static_cast<GenericServerContext*>(context_)->method_ = @@ -760,16 +1062,26 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, } grpc_slice_unref(call_details_.method); grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info( + static_cast<GenericServerContext*>(context_)->method_.c_str(), + *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { - if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) { - new UnimplementedAsyncRequest(server_, cq_); - new UnimplementedAsyncResponse(this); + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + // We either had no interceptors run or we are done intercepting. + if (*status) { + new UnimplementedAsyncRequest(server_, cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } } else { - delete this; + // The tag was swallowed due to interception. We will see it again. } return false; } @@ -784,4 +1096,41 @@ Server::UnimplementedAsyncResponse::UnimplementedAsyncResponse( ServerInitializer* Server::initializer() { return server_initializer_.get(); } +namespace { +class ShutdownCallback : public grpc_experimental_completion_queue_functor { + public: + ShutdownCallback() { functor_run = &ShutdownCallback::Run; } + // TakeCQ takes ownership of the cq into the shutdown callback + // so that the shutdown callback will be responsible for destroying it. + void TakeCQ(CompletionQueue* cq) { cq_ = cq; } + + // The Run function will get invoked by the completion queue library + // when the shutdown is actually complete. + static void Run(grpc_experimental_completion_queue_functor* cb, int) { + auto* callback = static_cast<ShutdownCallback*>(cb); + delete callback->cq_; + delete callback; + } + + private: + CompletionQueue* cq_ = nullptr; +}; +} // namespace + +CompletionQueue* Server::CallbackCQ() { + // TODO(vjpai): Consider using a single global CQ for the default CQ + // if there is no explicit per-server CQ registered. + std::lock_guard<std::mutex> l(mu_); + if (callback_cq_ == nullptr) { + auto* shutdown_callback = new ShutdownCallback; + callback_cq_ = new CompletionQueue(grpc_completion_queue_attributes{ + GRPC_CQ_CURRENT_VERSION, GRPC_CQ_CALLBACK, GRPC_CQ_DEFAULT_POLLING, + shutdown_callback}); + + // Transfer ownership of the new cq to its own shutdown callback. + shutdown_callback->TakeCQ(callback_cq_); + } + return callback_cq_; +} + } // namespace grpc diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index b7254b6bb9..396996e5bc 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -40,14 +40,45 @@ namespace grpc { class ServerContext::CompletionOp final : public internal::CallOpSetInterface { public: // initial refs: one in the server context, one in the cq - CompletionOp() - : has_tag_(false), + // must ref the call before calling the constructor and unref it after deleting this + CompletionOp(internal::Call* call) + : call_(*call), + has_tag_(false), tag_(nullptr), + core_cq_tag_(this), refs_(2), finalized_(false), - cancelled_(0) {} + cancelled_(0), + done_intercepting_(false) {} + + // CompletionOp isn't copyable or movable + CompletionOp(const CompletionOp&) = delete; + CompletionOp& operator=(const CompletionOp&) = delete; + CompletionOp(CompletionOp&&) = delete; + CompletionOp&
operator=(CompletionOp&&) = delete; + + ~CompletionOp() { + if (call_.server_rpc_info()) { + call_.server_rpc_info()->Unref(); + } + } + + void FillOps(internal::Call* call) override; + + // This should always be arena allocated in the call, so override delete. + // But this class is not trivially destructible, so we must actually call delete + // before allowing the arena to be freed. + static void operator delete(void* ptr, std::size_t size) { + assert(size == sizeof(CompletionOp)); + } + + // This operator should never be called as the memory should be freed as part + // of the arena destruction. It only exists to provide a matching operator + // delete to the operator new so that some compilers will not complain (see + // https://github.com/grpc/grpc/issues/11301). Note: at the time of adding this + // there are no tests catching the compiler warning. + static void operator delete(void*, void*) { assert(0); } - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -61,97 +92,188 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { tag_ = tag; } - /// TODO(vjpai): Allow override of cq_tag if appropriate for callback API - void* cq_tag() override { return this; } + void set_core_cq_tag(void* core_cq_tag) { core_cq_tag_ = core_cq_tag; } + + void* core_cq_tag() override { return core_cq_tag_; } void Unref(); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_CODEGEN_ASSERT(false); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + /* We don't have a tag to return. */ + std::unique_lock<std::mutex> lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return; + } + /* Start a dummy op so that we can return the tag */ + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, this, nullptr)); + } + private: bool CheckCancelledNoPluck() { std::lock_guard<std::mutex> g(mu_); return finalized_ ?
(cancelled_ != 0) : false; } + internal::Call call_; bool has_tag_; void* tag_; + void* core_cq_tag_; std::mutex mu_; int refs_; bool finalized_; int cancelled_; + bool done_intercepting_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; void ServerContext::CompletionOp::Unref() { std::unique_lock<std::mutex> lock(mu_); if (--refs_ == 0) { lock.unlock(); + grpc_call* call = call_.call(); delete this; + grpc_call_unref(call); } } -void ServerContext::CompletionOp::FillOps(grpc_call* call, grpc_op* ops, - size_t* nops) { - ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; - ops->data.recv_close_on_server.cancelled = &cancelled_; - ops->flags = 0; - ops->reserved = nullptr; - *nops = 1; +void ServerContext::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); + GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), &ops, 1, + core_cq_tag_, nullptr)); + /* No interceptors to run here */ } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { - std::unique_lock<std::mutex> lock(mu_); - finalized_ = true; bool ret = false; - if (has_tag_) { - *tag = tag_; - ret = true; + std::unique_lock<std::mutex> lock(mu_); + if (done_intercepting_) { + /* We are done intercepting. */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } + finalized_ = true; + if (!*status) cancelled_ = 1; - if (--refs_ == 0) { - lock.unlock(); - delete this; + /* Release the lock since we are going to be running through interceptors now + */ + lock.unlock(); + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + /* No interceptors were run */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + lock.lock(); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } - return ret; + /* There are interceptors to be run. 
Return false for now */ + return false; } // ServerContext body -ServerContext::ServerContext() - : completion_op_(nullptr), - has_notify_when_done_tag_(false), - async_notify_when_done_tag_(nullptr), - deadline_(gpr_inf_future(GPR_CLOCK_REALTIME)), - call_(nullptr), - cq_(nullptr), - sent_initial_metadata_(false), - compression_level_set_(false), - has_pending_ops_(false) {} - -ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata_array* arr) - : completion_op_(nullptr), - has_notify_when_done_tag_(false), - async_notify_when_done_tag_(nullptr), - deadline_(deadline), - call_(nullptr), - cq_(nullptr), - sent_initial_metadata_(false), - compression_level_set_(false), - has_pending_ops_(false) { +ServerContext::ServerContext() { Setup(gpr_inf_future(GPR_CLOCK_REALTIME)); } + +ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata_array* arr) { + Setup(deadline); + std::swap(*client_metadata_.arr(), *arr); +} + +void ServerContext::Setup(gpr_timespec deadline) { + completion_op_ = nullptr; + has_notify_when_done_tag_ = false; + async_notify_when_done_tag_ = nullptr; + deadline_ = deadline; + call_ = nullptr; + cq_ = nullptr; + sent_initial_metadata_ = false; + compression_level_set_ = false; + has_pending_ops_ = false; + rpc_info_ = nullptr; +} + +void ServerContext::BindDeadlineAndMetadata(gpr_timespec deadline, + grpc_metadata_array* arr) { + deadline_ = deadline; std::swap(*client_metadata_.arr(), *arr); } -ServerContext::~ServerContext() { +ServerContext::~ServerContext() { Clear(); } + +void ServerContext::Clear() { if (call_) { grpc_call_unref(call_); } if (completion_op_) { completion_op_->Unref(); + completion_tag_.Clear(); } + if (rpc_info_) { + rpc_info_->Unref(); + } + // Don't need to clear out call_, completion_op_, or rpc_info_ because this is + // either called from the destructor or just before Setup. } -void ServerContext::BeginCompletionOp(internal::Call* call) { +void ServerContext::BeginCompletionOp(internal::Call* call, bool callback) { GPR_ASSERT(!completion_op_); - completion_op_ = new CompletionOp(); - if (has_notify_when_done_tag_) { + if (rpc_info_) { + rpc_info_->Ref(); + } + grpc_call_ref(call->call()); + completion_op_ = + new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) + CompletionOp(call); + if (callback) { + completion_tag_.Set(call->call(), nullptr, completion_op_); + completion_op_->set_core_cq_tag(&completion_tag_); + } else if (has_notify_when_done_tag_) { completion_op_->set_tag(async_notify_when_done_tag_); } call->PerformOps(completion_op_); @@ -172,6 +294,12 @@ void ServerContext::AddTrailingMetadata(const grpc::string& key, } void ServerContext::TryCancel() const { + internal::CancelInterceptorBatchMethods cancel_methods; + if (rpc_info_) { + for (size_t i = 0; i < rpc_info_->interceptors_.size(); i++) { + rpc_info_->RunInterceptor(&cancel_methods, i); + } + } grpc_call_error err = grpc_call_cancel_with_status( call_, GRPC_STATUS_CANCELLED, "Cancelled on the server side", nullptr); if (err != GRPC_CALL_OK) { @@ -180,12 +308,15 @@ }
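TryCancel() now runs every registered interceptor through the dedicated cancel path before cancelling the core call, and IsCancelled() below gains a callback-API branch. The handler-side contract is unchanged; a sketch of typical usage, where DoWork, the message types, and kSteps are invented stand-ins:

#include <grpcpp/grpcpp.h>

struct Request {};   // stand-ins for generated proto messages
struct Response {};
constexpr int kSteps = 100;

grpc::Status DoWork(grpc::ServerContext* ctx, const Request* req,
                    Response* resp) {
  for (int i = 0; i < kSteps; ++i) {
    if (ctx->IsCancelled()) {
      // Sync API: consults the completion op via the call's CQ.
      // Callback/async API: reflects the completion tag state (see below).
      return grpc::Status(grpc::StatusCode::CANCELLED, "client went away");
    }
    // ... perform one bounded slice of work ...
  }
  return grpc::Status::OK;
}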
bool ServerContext::IsCancelled() const { - if (has_notify_when_done_tag_) { - // when using async API, but the result is only valid + if (completion_tag_) { + // When using callback API, this result is always valid. + return completion_op_->CheckCancelledAsync(); + } else if (has_notify_when_done_tag_) { + // When using async API, the result is only valid // if the tag has already been delivered at the completion queue. return completion_op_ && completion_op_->CheckCancelledAsync(); } else { - // when using sync API + // When using sync API, the result is always valid. return completion_op_ && completion_op_->CheckCancelled(cq_); } }
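For completeness, the attribute-driven construction used by CallbackCQ() earlier in this patch maps directly onto the C core's completion queue factory API. A sketch of creating the same kind of queue without the C++ wrapper; MakeCallbackCQ is an invented helper and error handling is omitted:

#include <grpc/grpc.h>

grpc_completion_queue* MakeCallbackCQ(
    grpc_experimental_completion_queue_functor* shutdown_cb) {
  grpc_completion_queue_attributes attr;
  attr.version = GRPC_CQ_CURRENT_VERSION;
  attr.cq_completion_type = GRPC_CQ_CALLBACK;  // functor tags, no Next() loop
  attr.cq_polling_type = GRPC_CQ_DEFAULT_POLLING;
  attr.cq_shutdown_cb = shutdown_cb;  // runs once shutdown actually completes
  return grpc_completion_queue_create(
      grpc_completion_queue_factory_lookup(&attr), &attr, nullptr);
}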