From 84e763f10a1e10d36c7de35970f9d25958ee2e16 Mon Sep 17 00:00:00 2001 From: Vijay Pai Date: Mon, 8 Oct 2018 13:28:10 -0700 Subject: Experimental C++ server callback unary API --- BUILD | 2 + CMakeLists.txt | 8 + Makefile | 8 + build.yaml | 2 + gRPC-C++.podspec | 2 + include/grpcpp/impl/codegen/byte_buffer.h | 4 + include/grpcpp/impl/codegen/callback_common.h | 11 +- include/grpcpp/impl/codegen/channel_interface.h | 2 +- include/grpcpp/impl/codegen/completion_queue.h | 10 +- include/grpcpp/impl/codegen/rpc_service_method.h | 43 ++-- include/grpcpp/impl/codegen/server_callback.h | 204 +++++++++++++++++ include/grpcpp/impl/codegen/server_context.h | 11 +- include/grpcpp/impl/codegen/server_interface.h | 10 + include/grpcpp/impl/codegen/service_type.h | 58 ++++- include/grpcpp/server.h | 13 ++ include/grpcpp/support/server_callback.h | 24 ++ src/compiler/cpp_generator.cc | 186 ++++++++++++++- src/cpp/server/server_builder.cc | 39 +++- src/cpp/server/server_cc.cc | 251 +++++++++++++++++++-- src/cpp/server/server_context.cc | 50 ++-- test/cpp/codegen/compiler_test_golden | 199 ++++++++++++++++ test/cpp/end2end/client_callback_end2end_test.cc | 53 ++++- test/cpp/end2end/test_service_impl.cc | 150 +++++++++++- test/cpp/end2end/test_service_impl.h | 33 +++ tools/doxygen/Doxyfile.c++ | 2 + tools/doxygen/Doxyfile.c++.internal | 2 + tools/run_tests/generated/sources_and_headers.json | 4 + 27 files changed, 1298 insertions(+), 83 deletions(-) create mode 100644 include/grpcpp/impl/codegen/server_callback.h create mode 100644 include/grpcpp/support/server_callback.h diff --git a/BUILD b/BUILD index c8ad45a788..08948e32e1 100644 --- a/BUILD +++ b/BUILD @@ -245,6 +245,7 @@ GRPCXX_PUBLIC_HDRS = [ "include/grpcpp/support/config.h", "include/grpcpp/support/proto_buffer_reader.h", "include/grpcpp/support/proto_buffer_writer.h", + "include/grpcpp/support/server_callback.h", "include/grpcpp/support/slice.h", "include/grpcpp/support/status.h", "include/grpcpp/support/status_code_enum.h", @@ -2088,6 +2089,7 @@ grpc_cc_library( "include/grpcpp/impl/codegen/rpc_service_method.h", "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", + "include/grpcpp/impl/codegen/server_callback.h", "include/grpcpp/impl/codegen/server_context.h", "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 700fa48abc..cfb163d6b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3020,6 +3020,7 @@ foreach(_hdr include/grpcpp/support/config.h include/grpcpp/support/proto_buffer_reader.h include/grpcpp/support/proto_buffer_writer.h + include/grpcpp/support/server_callback.h include/grpcpp/support/slice.h include/grpcpp/support/status.h include/grpcpp/support/status_code_enum.h @@ -3137,6 +3138,7 @@ foreach(_hdr include/grpcpp/impl/codegen/rpc_service_method.h include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h + include/grpcpp/impl/codegen/server_callback.h include/grpcpp/impl/codegen/server_context.h include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h @@ -3600,6 +3602,7 @@ foreach(_hdr include/grpcpp/support/config.h include/grpcpp/support/proto_buffer_reader.h include/grpcpp/support/proto_buffer_writer.h + include/grpcpp/support/server_callback.h include/grpcpp/support/slice.h include/grpcpp/support/status.h include/grpcpp/support/status_code_enum.h @@ -3717,6 +3720,7 @@ foreach(_hdr include/grpcpp/impl/codegen/rpc_service_method.h include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h + include/grpcpp/impl/codegen/server_callback.h include/grpcpp/impl/codegen/server_context.h include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h @@ -4131,6 +4135,7 @@ foreach(_hdr include/grpcpp/impl/codegen/rpc_service_method.h include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h + include/grpcpp/impl/codegen/server_callback.h include/grpcpp/impl/codegen/server_context.h include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h @@ -4317,6 +4322,7 @@ foreach(_hdr include/grpcpp/impl/codegen/rpc_service_method.h include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h + include/grpcpp/impl/codegen/server_callback.h include/grpcpp/impl/codegen/server_context.h include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h @@ -4530,6 +4536,7 @@ foreach(_hdr include/grpcpp/support/config.h include/grpcpp/support/proto_buffer_reader.h include/grpcpp/support/proto_buffer_writer.h + include/grpcpp/support/server_callback.h include/grpcpp/support/slice.h include/grpcpp/support/status.h include/grpcpp/support/status_code_enum.h @@ -4647,6 +4654,7 @@ foreach(_hdr include/grpcpp/impl/codegen/rpc_service_method.h include/grpcpp/impl/codegen/security/auth_context.h include/grpcpp/impl/codegen/serialization_traits.h + include/grpcpp/impl/codegen/server_callback.h include/grpcpp/impl/codegen/server_context.h include/grpcpp/impl/codegen/server_interceptor.h include/grpcpp/impl/codegen/server_interface.h diff --git a/Makefile b/Makefile index 19c518427f..d1f8762018 100644 --- a/Makefile +++ b/Makefile @@ -5371,6 +5371,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/support/config.h \ include/grpcpp/support/proto_buffer_reader.h \ include/grpcpp/support/proto_buffer_writer.h \ + include/grpcpp/support/server_callback.h \ include/grpcpp/support/slice.h \ include/grpcpp/support/status.h \ include/grpcpp/support/status_code_enum.h \ @@ -5488,6 +5489,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ + include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -5960,6 +5962,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/support/config.h \ include/grpcpp/support/proto_buffer_reader.h \ include/grpcpp/support/proto_buffer_writer.h \ + include/grpcpp/support/server_callback.h \ include/grpcpp/support/slice.h \ include/grpcpp/support/status.h \ include/grpcpp/support/status_code_enum.h \ @@ -6077,6 +6080,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ + include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -6476,6 +6480,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ + include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -6639,6 +6644,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ + include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -6857,6 +6863,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/support/config.h \ include/grpcpp/support/proto_buffer_reader.h \ include/grpcpp/support/proto_buffer_writer.h \ + include/grpcpp/support/server_callback.h \ include/grpcpp/support/slice.h \ include/grpcpp/support/status.h \ include/grpcpp/support/status_code_enum.h \ @@ -6974,6 +6981,7 @@ PUBLIC_HEADERS_CXX += \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ + include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ diff --git a/build.yaml b/build.yaml index e7e92d280d..75860076bd 100644 --- a/build.yaml +++ b/build.yaml @@ -1245,6 +1245,7 @@ filegroups: - include/grpcpp/impl/codegen/rpc_service_method.h - include/grpcpp/impl/codegen/security/auth_context.h - include/grpcpp/impl/codegen/serialization_traits.h + - include/grpcpp/impl/codegen/server_callback.h - include/grpcpp/impl/codegen/server_context.h - include/grpcpp/impl/codegen/server_interceptor.h - include/grpcpp/impl/codegen/server_interface.h @@ -1363,6 +1364,7 @@ filegroups: - include/grpcpp/support/config.h - include/grpcpp/support/proto_buffer_reader.h - include/grpcpp/support/proto_buffer_writer.h + - include/grpcpp/support/server_callback.h - include/grpcpp/support/slice.h - include/grpcpp/support/status.h - include/grpcpp/support/status_code_enum.h diff --git a/gRPC-C++.podspec b/gRPC-C++.podspec index 485646171c..df2618e760 100644 --- a/gRPC-C++.podspec +++ b/gRPC-C++.podspec @@ -118,6 +118,7 @@ Pod::Spec.new do |s| 'include/grpcpp/support/config.h', 'include/grpcpp/support/proto_buffer_reader.h', 'include/grpcpp/support/proto_buffer_writer.h', + 'include/grpcpp/support/server_callback.h', 'include/grpcpp/support/slice.h', 'include/grpcpp/support/status.h', 'include/grpcpp/support/status_code_enum.h', @@ -154,6 +155,7 @@ Pod::Spec.new do |s| 'include/grpcpp/impl/codegen/rpc_service_method.h', 'include/grpcpp/impl/codegen/security/auth_context.h', 'include/grpcpp/impl/codegen/serialization_traits.h', + 'include/grpcpp/impl/codegen/server_callback.h', 'include/grpcpp/impl/codegen/server_context.h', 'include/grpcpp/impl/codegen/server_interceptor.h', 'include/grpcpp/impl/codegen/server_interface.h', diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index d54ae31852..abba5549b8 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -45,6 +45,8 @@ template class RpcMethodHandler; template class ServerStreamingHandler; +template +class CallbackUnaryHandler; template class ErrorMethodHandler; template @@ -154,6 +156,8 @@ class ByteBuffer final { friend class internal::RpcMethodHandler; template friend class internal::ServerStreamingHandler; + template + friend class internal::CallbackUnaryHandler; template friend class internal::ErrorMethodHandler; template diff --git a/include/grpcpp/impl/codegen/callback_common.h b/include/grpcpp/impl/codegen/callback_common.h index eba9ec6edc..8273ef2f4a 100644 --- a/include/grpcpp/impl/codegen/callback_common.h +++ b/include/grpcpp/impl/codegen/callback_common.h @@ -101,10 +101,11 @@ class CallbackWithStatusTag GPR_CODEGEN_ASSERT(ignored == ops_); // Last use of func_ or status_, so ok to move them out - CatchingCallback(std::move(func_), std::move(status_)); - + auto func = std::move(func_); + auto status = std::move(status_); func_ = nullptr; // reset to clear this out for sure status_ = Status(); // reset to clear this out for sure + CatchingCallback(std::move(func), std::move(status)); g_core_codegen_interface->grpc_call_unref(call_); } }; @@ -124,6 +125,8 @@ class CallbackWithSuccessTag // there are no tests catching the compiler warning. static void operator delete(void*, void*) { assert(0); } + CallbackWithSuccessTag() : call_(nullptr), ops_(nullptr) {} + CallbackWithSuccessTag(grpc_call* call, std::function f, CompletionQueueTag* ops) : call_(call), func_(std::move(f)), ops_(ops) { @@ -154,9 +157,9 @@ class CallbackWithSuccessTag GPR_CODEGEN_ASSERT(ignored == ops_); // Last use of func_, so ok to move it out for rvalue call above - CatchingCallback(std::move(func_), ok); - + auto func = std::move(func_); func_ = nullptr; // reset to clear this out for sure + CatchingCallback(std::move(func), ok); g_core_codegen_interface->grpc_call_unref(call_); } }; diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index 6fd1dd1d9b..0735c96521 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -142,7 +142,7 @@ class ChannelInterface { // channel. If the return value is nullptr, this channel doesn't support // callback operations. // TODO(vjpai): Consider a better default like using a global CQ - // Returns nullptr (rather than being pure) since this is a new method + // Returns nullptr (rather than being pure) since this is a post-1.0 method // and adding a new pure method to an interface would be a breaking change // (even though this is private and non-API) virtual CompletionQueue* CallbackCQ() { return nullptr; } diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h index 5eef2c281f..d603c7c700 100644 --- a/include/grpcpp/impl/codegen/completion_queue.h +++ b/include/grpcpp/impl/codegen/completion_queue.h @@ -380,12 +380,18 @@ class ServerCompletionQueue : public CompletionQueue { ServerCompletionQueue() : polling_type_(GRPC_CQ_DEFAULT_POLLING) {} private: + /// \param completion_type indicates whether this is a NEXT or CALLBACK + /// completion queue. /// \param polling_type Informs the GRPC library about the type of polling /// allowed on this completion queue. See grpc_cq_polling_type's description /// in grpc_types.h for more details. - ServerCompletionQueue(grpc_cq_polling_type polling_type) + /// \param shutdown_cb is the shutdown callback used for CALLBACK api queues + ServerCompletionQueue(grpc_cq_completion_type completion_type, + grpc_cq_polling_type polling_type, + grpc_experimental_completion_queue_functor* shutdown_cb) : CompletionQueue(grpc_completion_queue_attributes{ - GRPC_CQ_CURRENT_VERSION, GRPC_CQ_NEXT, polling_type, nullptr}), + GRPC_CQ_CURRENT_VERSION, completion_type, polling_type, + shutdown_cb}), polling_type_(polling_type) {} grpc_cq_polling_type polling_type_; diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index e77f4046a3..44d9b8ad63 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -41,13 +41,18 @@ class MethodHandler { virtual ~MethodHandler() {} struct HandlerParameter { HandlerParameter(Call* c, ServerContext* context, void* req, - Status req_status) - : call(c), server_context(context), request(req), status(req_status) {} + Status req_status, std::function renew) + : call(c), + server_context(context), + request(req), + status(req_status), + renewer(std::move(renew)) {} ~HandlerParameter() {} Call* call; ServerContext* server_context; void* request; Status status; + std::function renewer; }; virtual void RunHandler(const HandlerParameter& param) = 0; @@ -71,25 +76,29 @@ class RpcServiceMethod : public RpcMethod { MethodHandler* handler) : RpcMethod(name, type), server_tag_(nullptr), - async_type_(AsyncType::UNSET), + api_type_(ApiType::SYNC), handler_(handler) {} - enum class AsyncType { - UNSET, + enum class ApiType { + SYNC, ASYNC, RAW, + CALL_BACK, // not CALLBACK because that is reserved in Windows + RAW_CALL_BACK, }; void set_server_tag(void* tag) { server_tag_ = tag; } void* server_tag() const { return server_tag_; } /// if MethodHandler is nullptr, then this is an async method MethodHandler* handler() const { return handler_.get(); } + ApiType api_type() const { return api_type_; } void SetHandler(MethodHandler* handler) { handler_.reset(handler); } - void SetServerAsyncType(RpcServiceMethod::AsyncType type) { - if (async_type_ == AsyncType::UNSET) { + void SetServerApiType(RpcServiceMethod::ApiType type) { + if ((api_type_ == ApiType::SYNC) && + (type == ApiType::ASYNC || type == ApiType::RAW)) { // this marks this method as async handler_.reset(); - } else { + } else if (api_type_ != ApiType::SYNC) { // this is not an error condition, as it allows users to declare a server // like WithRawMethod_foo. However since it // overwrites behavior, it should be logged. @@ -98,24 +107,28 @@ class RpcServiceMethod : public RpcMethod { "You are marking method %s as '%s', even though it was " "previously marked '%s'. This behavior will overwrite the original " "behavior. If you expected this then ignore this message.", - name(), TypeToString(async_type_), TypeToString(type)); + name(), TypeToString(api_type_), TypeToString(type)); } - async_type_ = type; + api_type_ = type; } private: void* server_tag_; - AsyncType async_type_; + ApiType api_type_; std::unique_ptr handler_; - const char* TypeToString(RpcServiceMethod::AsyncType type) { + const char* TypeToString(RpcServiceMethod::ApiType type) { switch (type) { - case AsyncType::UNSET: + case ApiType::SYNC: return "unset"; - case AsyncType::ASYNC: + case ApiType::ASYNC: return "async"; - case AsyncType::RAW: + case ApiType::RAW: return "raw"; + case ApiType::CALL_BACK: + return "callback"; + case ApiType::RAW_CALL_BACK: + return "raw_callback"; default: GPR_UNREACHABLE_CODE(return "unknown"); } diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h new file mode 100644 index 0000000000..c8f7510ed5 --- /dev/null +++ b/include/grpcpp/impl/codegen/server_callback.h @@ -0,0 +1,204 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H +#define GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace grpc { + +// forward declarations +namespace internal { +template +class CallbackUnaryHandler; +} // namespace internal + +namespace experimental { + +// For unary RPCs, the exposed controller class is only an interface +// and the actual implementation is an internal class. +class ServerCallbackRpcController { + public: + virtual ~ServerCallbackRpcController() {} + + // The method handler must call this function when it is done so that + // the library knows to free its resources + virtual void Finish(Status s) = 0; + virtual void FinishWithError(Status s) = 0; + + // Allow the method handler to push out the initial metadata before + // the response and status are ready + virtual void SendInitialMetadata(std::function) = 0; +}; + +} // namespace experimental + +namespace internal { + +template +class CallbackUnaryHandler : public MethodHandler { + public: + CallbackUnaryHandler( + std::function + func, + ServiceType* service) + : func_(func) {} + void RunHandler(const HandlerParameter& param) final { + // Arena allocate a controller structure (that includes request/response) + g_core_codegen_interface->grpc_call_ref(param.call->call()); + auto* controller = new (g_core_codegen_interface->grpc_call_arena_alloc( + param.call->call(), sizeof(ServerCallbackRpcControllerImpl))) + ServerCallbackRpcControllerImpl( + param.server_context, param.call, + static_cast(param.request), std::move(param.renewer)); + Status status = param.status; + + if (status.ok()) { + // Call the actual function handler and expect the user to call finish + CatchingCallback(std::move(func_), param.server_context, + controller->request(), controller->response(), + controller); + } else { + // if deserialization failed, we need to fail the call + controller->Finish(status); + } + } + + void* Deserialize(grpc_call* call, grpc_byte_buffer* req, + Status* status) final { + ByteBuffer buf; + buf.set_buffer(req); + auto* request = new (g_core_codegen_interface->grpc_call_arena_alloc( + call, sizeof(RequestType))) RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); + buf.Release(); + if (status->ok()) { + return request; + } + request->~RequestType(); + return nullptr; + } + + private: + std::function + func_; + + // The implementation class of ServerCallbackRpcController is a private member + // of CallbackUnaryHandler since it is never exposed anywhere, and this allows + // it to take advantage of CallbackUnaryHandler's friendships. + + class ServerCallbackRpcControllerImpl + : public experimental::ServerCallbackRpcController { + public: + void Finish(Status s) override { FinishInternal(std::move(s), false); } + + void FinishWithError(Status s) override { + FinishInternal(std::move(s), true); + } + + void SendInitialMetadata(std::function f) override { + GPR_CODEGEN_ASSERT(!ctx_->sent_initial_metadata_); + + meta_tag_ = + CallbackWithSuccessTag(call_.call(), std::move(f), &meta_buf_); + meta_buf_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + meta_buf_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + meta_buf_.set_cq_tag(&meta_tag_); + call_.PerformOps(&meta_buf_); + } + + private: + template + friend class CallbackUnaryHandler; + + ServerCallbackRpcControllerImpl(ServerContext* ctx, Call* call, + RequestType* req, + std::function renewer) + : ctx_(ctx), call_(*call), req_(req), renewer_(std::move(renewer)) {} + + ~ServerCallbackRpcControllerImpl() { req_->~RequestType(); } + + void FinishInternal(Status s, bool allow_error) { + finish_tag_ = CallbackWithSuccessTag( + call_.call(), + [this](bool) { + grpc_call* call = call_.call(); + auto renewer = std::move(renewer_); + this->~ServerCallbackRpcControllerImpl(); // explicitly call + // destructor + g_core_codegen_interface->grpc_call_unref(call); + renewer(); + }, + &finish_buf_); + if (!ctx_->sent_initial_metadata_) { + finish_buf_.SendInitialMetadata(&ctx_->initial_metadata_, + ctx_->initial_metadata_flags()); + if (ctx_->compression_level_set()) { + finish_buf_.set_compression_level(ctx_->compression_level()); + } + ctx_->sent_initial_metadata_ = true; + } + // The response may be dropped if the status is not OK. + if (allow_error || s.ok()) { + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, + finish_buf_.SendMessage(resp_)); + } else { + finish_buf_.ServerSendStatus(&ctx_->trailing_metadata_, s); + } + finish_buf_.set_cq_tag(&finish_tag_); + call_.PerformOps(&finish_buf_); + } + + RequestType* request() { return req_; } + ResponseType* response() { return &resp_; } + + CallOpSet meta_buf_; + CallbackWithSuccessTag meta_tag_; + CallOpSet + finish_buf_; + CallbackWithSuccessTag finish_tag_; + + ServerContext* ctx_; + Call call_; + RequestType* req_; + ResponseType resp_; + std::function renewer_; + }; +}; + +} // namespace internal + +} // namespace grpc + +#endif // GRPCPP_IMPL_CODEGEN_SERVER_CALLBACK_H diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index 7559fb3b34..ebbf64bc6d 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -65,6 +65,8 @@ template class ServerStreamingHandler; template class BidiStreamingHandler; +template +class CallbackUnaryHandler; template class TemplatedBidiStreamingHandler; template @@ -267,6 +269,8 @@ class ServerContext { friend class ::grpc::internal::ServerStreamingHandler; template friend class ::grpc::internal::TemplatedBidiStreamingHandler; + template + friend class ::grpc::internal::CallbackUnaryHandler; template friend class internal::ErrorMethodHandler; friend class ::grpc::ClientContext; @@ -285,6 +289,11 @@ class ServerContext { void set_call(grpc_call* call) { call_ = call; } + void BindDeadlineAndMetadata(gpr_timespec deadline, grpc_metadata_array* arr); + + void Clear(); + void Setup(gpr_timespec deadline); + uint32_t initial_metadata_flags() const { return 0; } experimental::ServerRpcInfo* set_server_rpc_info( @@ -321,7 +330,7 @@ class ServerContext { pending_ops_; bool has_pending_ops_; - experimental::ServerRpcInfo* rpc_info_ = nullptr; + experimental::ServerRpcInfo* rpc_info_; }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 92c87a5f7e..bde7740f44 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -338,6 +338,16 @@ class ServerInterface : public internal::CallHook { interceptor_creators() { return nullptr; } + + // EXPERIMENTAL + // A method to get the callbackable completion queue associated with this + // server. If the return value is nullptr, this server doesn't support + // callback operations. + // TODO(vjpai): Consider a better default like using a global CQ + // Returns nullptr (rather than being pure) since this is a post-1.0 method + // and adding a new pure method to an interface would be a breaking change + // (even though this is private and non-API) + virtual CompletionQueue* CallbackCQ() { return nullptr; } }; } // namespace grpc diff --git a/include/grpcpp/impl/codegen/service_type.h b/include/grpcpp/impl/codegen/service_type.h index 9f1a052168..332a04c294 100644 --- a/include/grpcpp/impl/codegen/service_type.h +++ b/include/grpcpp/impl/codegen/service_type.h @@ -71,7 +71,20 @@ class Service { bool has_synchronous_methods() const { for (auto it = methods_.begin(); it != methods_.end(); ++it) { - if (*it && (*it)->handler() != nullptr) { + if (*it && + (*it)->api_type() == internal::RpcServiceMethod::ApiType::SYNC) { + return true; + } + } + return false; + } + + bool has_callback_methods() const { + for (auto it = methods_.begin(); it != methods_.end(); ++it) { + if (*it && ((*it)->api_type() == + internal::RpcServiceMethod::ApiType::CALL_BACK || + (*it)->api_type() == + internal::RpcServiceMethod::ApiType::RAW_CALL_BACK)) { return true; } } @@ -88,6 +101,43 @@ class Service { } protected: + // TODO(vjpai): Promote experimental contents once callback API is accepted + class experimental_type { + public: + explicit experimental_type(Service* service) : service_(service) {} + + void MarkMethodCallback(int index, internal::MethodHandler* handler) { + // This does not have to be a hard error, however no one has approached us + // with a use case yet. Please file an issue if you believe you have one. + size_t idx = static_cast(index); + GPR_CODEGEN_ASSERT( + service_->methods_[idx].get() != nullptr && + "Cannot mark the method as 'callback' because it has already been " + "marked as 'generic'."); + service_->methods_[idx]->SetHandler(handler); + service_->methods_[idx]->SetServerApiType( + internal::RpcServiceMethod::ApiType::CALL_BACK); + } + + void MarkMethodRawCallback(int index, internal::MethodHandler* handler) { + // This does not have to be a hard error, however no one has approached us + // with a use case yet. Please file an issue if you believe you have one. + size_t idx = static_cast(index); + GPR_CODEGEN_ASSERT( + service_->methods_[idx].get() != nullptr && + "Cannot mark the method as 'raw callback' because it has already " + "been marked as 'generic'."); + service_->methods_[idx]->SetHandler(handler); + service_->methods_[idx]->SetServerApiType( + internal::RpcServiceMethod::ApiType::RAW_CALL_BACK); + } + + private: + Service* service_; + }; + + experimental_type experimental() { return experimental_type(this); } + template void RequestAsyncUnary(int index, ServerContext* context, Message* request, internal::ServerAsyncStreamingInterface* stream, @@ -138,8 +188,7 @@ class Service { methods_[idx].get() != nullptr && "Cannot mark the method as 'async' because it has already been " "marked as 'generic'."); - methods_[idx]->SetServerAsyncType( - internal::RpcServiceMethod::AsyncType::ASYNC); + methods_[idx]->SetServerApiType(internal::RpcServiceMethod::ApiType::ASYNC); } void MarkMethodRaw(int index) { @@ -149,8 +198,7 @@ class Service { GPR_CODEGEN_ASSERT(methods_[idx].get() != nullptr && "Cannot mark the method as 'raw' because it has already " "been marked as 'generic'."); - methods_[idx]->SetServerAsyncType( - internal::RpcServiceMethod::AsyncType::RAW); + methods_[idx]->SetServerApiType(internal::RpcServiceMethod::ApiType::RAW); } void MarkMethodGeneric(int index) { diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index 2b89ffd317..ef54a218a7 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -202,6 +202,7 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { friend class ServerInitializer; class SyncRequest; + class CallbackRequest; class UnimplementedAsyncRequest; class UnimplementedAsyncResponse; @@ -224,6 +225,8 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { return max_receive_message_size_; }; + CompletionQueue* CallbackCQ() override; + ServerInitializer* initializer(); const int max_receive_message_size_; @@ -238,6 +241,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { /// the \a sync_server_cqs) std::vector> sync_req_mgrs_; + /// Outstanding callback requests + std::vector> callback_reqs_; + // Server status std::mutex mu_; bool started_; @@ -264,6 +270,13 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { std::vector> interceptor_creators_; + + // callback_cq_ references the callbackable completion queue associated + // with this server (if any). It is set on the first call to CallbackCQ(). + // It is _not owned_ by the server; ownership belongs with its internal + // shutdown callback tag (invoked when the CQ is fully shutdown). + // It is protected by mu_ + CompletionQueue* callback_cq_ = nullptr; }; } // namespace grpc diff --git a/include/grpcpp/support/server_callback.h b/include/grpcpp/support/server_callback.h new file mode 100644 index 0000000000..b0aeeb53c5 --- /dev/null +++ b/include/grpcpp/support/server_callback.h @@ -0,0 +1,24 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPCPP_SUPPORT_SERVER_CALLBACK_H +#define GRPCPP_SUPPORT_SERVER_CALLBACK_H + +#include + +#endif // GRPCPP_SUPPORT_SERVER_CALLBACK_H diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 56716493dc..1e17760092 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -135,6 +135,7 @@ grpc::string GetHeaderIncludes(grpc_generator::File* file, "grpcpp/impl/codegen/method_handler_impl.h", "grpcpp/impl/codegen/proto_utils.h", "grpcpp/impl/codegen/rpc_method.h", + "grpcpp/impl/codegen/server_callback.h", "grpcpp/impl/codegen/service_type.h", "grpcpp/impl/codegen/status.h", "grpcpp/impl/codegen/stub_options.h", @@ -702,7 +703,7 @@ void PrintHeaderServerMethodSync(grpc_generator::Printer* printer, // Helper generator. Disabled the sync API for Request and Response, then adds // in an async API for RealRequest and RealResponse types. This is to be used -// to generate async and raw APIs. +// to generate async and raw async APIs. void PrintHeaderServerAsyncMethodsHelper( grpc_generator::Printer* printer, const grpc_generator::Method* method, std::map* vars) { @@ -829,6 +830,164 @@ void PrintHeaderServerMethodAsync(grpc_generator::Printer* printer, printer->Print(*vars, "};\n"); } +// Helper generator. Disabled the sync API for Request and Response, then adds +// in a callback API for RealRequest and RealResponse types. This is to be used +// to generate callback and raw callback APIs. +void PrintHeaderServerCallbackMethodsHelper( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + if (method->NoStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, const $Request$* request, " + "$Response$* response) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print( + *vars, + "virtual void $Method$(" + "::grpc::ServerContext* context, const $RealRequest$* request, " + "$RealResponse$* response, " + "::grpc::experimental::ServerCallbackRpcController* " + "controller) { controller->Finish(::grpc::Status(" + "::grpc::StatusCode::UNIMPLEMENTED, \"\")); }\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReader< $Request$>* reader, " + "$Response$* response) override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, const $Request$* request, " + "::grpc::ServerWriter< $Response$>* writer) override " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (method->BidiStreaming()) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* stream) " + " override {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } +} + +void PrintHeaderServerMethodCallback( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for the callback API + (*vars)["RealRequest"] = method->input_type_name(); + (*vars)["RealResponse"] = method->output_type_name(); + printer->Print(*vars, "template \n"); + printer->Print( + *vars, + "class ExperimentalWithCallbackMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service *service) {}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, "ExperimentalWithCallbackMethod_$Method$() {\n"); + if (method->NoStreaming()) { + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodCallback($Idx$,\n" + " new ::grpc::internal::CallbackUnaryHandler< " + "ExperimentalWithCallbackMethod_$Method$, $RealRequest$, " + "$RealResponse$>(\n" + " [this](::grpc::ServerContext* context,\n" + " const $RealRequest$* request,\n" + " $RealResponse$* response,\n" + " ::grpc::experimental::ServerCallbackRpcController* " + "controller) {\n" + " this->$" + "Method$(context, request, response, controller);\n" + " }, this));\n"); + } else if (ClientOnlyStreaming(method)) { + } else if (ServerOnlyStreaming(method)) { + } else if (method->BidiStreaming()) { + } + printer->Print(*vars, "}\n"); + printer->Print(*vars, + "~ExperimentalWithCallbackMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerCallbackMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + +void PrintHeaderServerMethodRawCallback( + grpc_generator::Printer* printer, const grpc_generator::Method* method, + std::map* vars) { + (*vars)["Method"] = method->name(); + // These will be disabled + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + // These will be used for raw API + (*vars)["RealRequest"] = "::grpc::ByteBuffer"; + (*vars)["RealResponse"] = "::grpc::ByteBuffer"; + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class ExperimentalWithRawCallbackMethod_$Method$ : public " + "BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service *service) {}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, "ExperimentalWithRawCallbackMethod_$Method$() {\n"); + if (method->NoStreaming()) { + printer->Print( + *vars, + " ::grpc::Service::experimental().MarkMethodRawCallback($Idx$,\n" + " new ::grpc::internal::CallbackUnaryHandler< " + "ExperimentalWithRawCallbackMethod_$Method$, $RealRequest$, " + "$RealResponse$>(\n" + " [this](::grpc::ServerContext* context,\n" + " const $RealRequest$* request,\n" + " $RealResponse$* response,\n" + " ::grpc::experimental::ServerCallbackRpcController* " + "controller) {\n" + " this->$" + "Method$(context, request, response, controller);\n" + " }, this));\n"); + } else if (ClientOnlyStreaming(method)) { + } else if (ServerOnlyStreaming(method)) { + } else if (method->BidiStreaming()) { + } + printer->Print(*vars, "}\n"); + printer->Print(*vars, + "~ExperimentalWithRawCallbackMethod_$Method$() override {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + PrintHeaderServerCallbackMethodsHelper(printer, method, vars); + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + void PrintHeaderServerMethodStreamedUnary( grpc_generator::Printer* printer, const grpc_generator::Method* method, std::map* vars) { @@ -1146,6 +1305,24 @@ void PrintHeaderService(grpc_generator::Printer* printer, } printer->Print(" AsyncService;\n"); + // Server side - Callback + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodCallback(printer, service->method(i).get(), vars); + } + + printer->Print("typedef "); + + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i).get()->name(); + printer->Print(*vars, "ExperimentalWithCallbackMethod_$method_name$<"); + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + printer->Print(" >"); + } + printer->Print(" ExperimentalCallbackService;\n"); + // Server side - Generic for (int i = 0; i < service->method_count(); ++i) { (*vars)["Idx"] = as_string(i); @@ -1158,6 +1335,12 @@ void PrintHeaderService(grpc_generator::Printer* printer, PrintHeaderServerMethodRaw(printer, service->method(i).get(), vars); } + // Server side - Raw Callback + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodRawCallback(printer, service->method(i).get(), vars); + } + // Server side - Streamed Unary for (int i = 0; i < service->method_count(); ++i) { (*vars)["Idx"] = as_string(i); @@ -1333,6 +1516,7 @@ grpc::string GetSourceIncludes(grpc_generator::File* file, "grpcpp/impl/codegen/client_callback.h", "grpcpp/impl/codegen/method_handler_impl.h", "grpcpp/impl/codegen/rpc_service_method.h", + "grpcpp/impl/codegen/server_callback.h", "grpcpp/impl/codegen/service_type.h", "grpcpp/impl/codegen/sync_stream.h"}; std::vector headers(headers_strs, array_end(headers_strs)); diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index fc42b6c886..0dc03b6876 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -71,7 +71,9 @@ ServerBuilder::~ServerBuilder() { std::unique_ptr ServerBuilder::AddCompletionQueue( bool is_frequently_polled) { ServerCompletionQueue* cq = new ServerCompletionQueue( - is_frequently_polled ? GRPC_CQ_DEFAULT_POLLING : GRPC_CQ_NON_LISTENING); + GRPC_CQ_NEXT, + is_frequently_polled ? GRPC_CQ_DEFAULT_POLLING : GRPC_CQ_NON_LISTENING, + nullptr); cqs_.push_back(cq); return std::unique_ptr(cq); } @@ -256,15 +258,22 @@ std::unique_ptr ServerBuilder::BuildAndStart() { // Create completion queues to listen to incoming rpc requests for (int i = 0; i < sync_server_settings_.num_cqs; i++) { - sync_server_cqs->emplace_back(new ServerCompletionQueue(polling_type)); + sync_server_cqs->emplace_back( + new ServerCompletionQueue(GRPC_CQ_NEXT, polling_type, nullptr)); } } - std::unique_ptr server(new Server( - max_receive_message_size_, &args, sync_server_cqs, - sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec, resource_quota_, - std::move(interceptor_creators_))); + // == Determine if the server has any callback methods == + bool has_callback_methods = false; + for (auto it = services_.begin(); it != services_.end(); ++it) { + if ((*it)->service->has_callback_methods()) { + has_callback_methods = true; + break; + } + } + + // TODO(vjpai): Add a section here for plugins once they can support callback + // methods if (has_sync_methods) { // This is a Sync server @@ -276,6 +285,16 @@ std::unique_ptr ServerBuilder::BuildAndStart() { sync_server_settings_.cq_timeout_msec); } + if (has_callback_methods) { + gpr_log(GPR_INFO, "Callback server."); + } + + std::unique_ptr server(new Server( + max_receive_message_size_, &args, sync_server_cqs, + sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, + sync_server_settings_.cq_timeout_msec, resource_quota_, + std::move(interceptor_creators_))); + ServerInitializer* initializer = server->initializer(); // Register all the completion queues with the server. i.e @@ -289,6 +308,12 @@ std::unique_ptr ServerBuilder::BuildAndStart() { num_frequently_polled_cqs++; } + if (has_callback_methods) { + auto* cq = server->CallbackCQ(); + grpc_server_register_completion_queue(server->server_, cq->cq(), nullptr); + num_frequently_polled_cqs++; + } + // cqs_ contains the completion queue added by calling the ServerBuilder's // AddCompletionQueue() API. Some of them may not be frequently polled (i.e by // calling Next() or AsyncNext()) and hence are not safe to be used for diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 82a9d719fa..5d24e01a47 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -147,9 +147,9 @@ class Server::UnimplementedAsyncResponse final class Server::SyncRequest final : public internal::CompletionQueueTag { public: - SyncRequest(internal::RpcServiceMethod* method, void* tag) + SyncRequest(internal::RpcServiceMethod* method, void* method_tag) : method_(method), - tag_(tag), + method_tag_(method_tag), in_flight_(false), has_request_payload_( method->method_type() == internal::RpcMethod::NORMAL_RPC || @@ -176,10 +176,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Request(grpc_server* server, grpc_completion_queue* notify_cq) { GPR_ASSERT(cq_ && !in_flight_); in_flight_ = true; - if (tag_) { + if (method_tag_) { if (GRPC_CALL_OK != grpc_server_request_registered_call( - server, tag_, &call_, &deadline_, &request_metadata_, + server, method_tag_, &call_, &deadline_, &request_metadata_, has_request_payload_ ? &request_payload_ : nullptr, cq_, notify_cq, this)) { TeardownRequest(); @@ -211,6 +211,9 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { return true; } + // The CallData class represents a call that is "active" as opposed + // to just being requested. It wraps and takes ownership of the cq from + // the call request class CallData final { public: explicit CallData(Server* server, SyncRequest* mrd) @@ -281,7 +284,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get(); handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_, request_status_)); + &call_, &ctx_, request_, request_status_, nullptr)); request_ = nullptr; global_callbacks_->PostSynchronousRequest(&ctx_); @@ -314,7 +317,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { private: internal::RpcServiceMethod* const method_; - void* const tag_; + void* const method_tag_; bool in_flight_; const bool has_request_payload_; grpc_call* call_; @@ -325,6 +328,176 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { grpc_completion_queue* cq_; }; +class Server::CallbackRequest final : public internal::CompletionQueueTag { + public: + CallbackRequest(Server* server, internal::RpcServiceMethod* method, + void* method_tag) + : server_(server), + method_(method), + method_tag_(method_tag), + has_request_payload_( + method->method_type() == internal::RpcMethod::NORMAL_RPC || + method->method_type() == internal::RpcMethod::SERVER_STREAMING), + cq_(server->CallbackCQ()), + tag_(this) { + Setup(); + } + + ~CallbackRequest() { Clear(); } + + void Request() { + if (method_tag_) { + if (GRPC_CALL_OK != + grpc_server_request_registered_call( + server_->c_server(), method_tag_, &call_, &deadline_, + &request_metadata_, + has_request_payload_ ? &request_payload_ : nullptr, cq_->cq(), + cq_->cq(), static_cast(&tag_))) { + return; + } + } else { + if (!call_details_) { + call_details_ = new grpc_call_details; + grpc_call_details_init(call_details_); + } + if (grpc_server_request_call(server_->c_server(), &call_, call_details_, + &request_metadata_, cq_->cq(), cq_->cq(), + static_cast(&tag_)) != GRPC_CALL_OK) { + return; + } + } + } + + bool FinalizeResult(void** tag, bool* status) override { return false; } + + private: + class CallbackCallTag : public grpc_experimental_completion_queue_functor { + public: + CallbackCallTag(Server::CallbackRequest* req) : req_(req) { + functor_run = &CallbackCallTag::StaticRun; + } + + // force_run can not be performed on a tag if operations using this tag + // have been sent to PerformOpsOnCall. It is intended for error conditions + // that are detected before the operations are internally processed. + void force_run(bool ok) { Run(ok); } + + private: + Server::CallbackRequest* req_; + internal::Call* call_; + + static void StaticRun(grpc_experimental_completion_queue_functor* cb, + int ok) { + static_cast(cb)->Run(static_cast(ok)); + } + void Run(bool ok) { + void* ignored = req_; + bool new_ok = ok; + GPR_ASSERT(!req_->FinalizeResult(&ignored, &new_ok)); + GPR_ASSERT(ignored == req_); + + if (!ok) { + // The call has been shutdown + req_->Clear(); + return; + } + + // Bind the call, deadline, and metadata from what we got + req_->ctx_.set_call(req_->call_); + req_->ctx_.cq_ = req_->cq_; + req_->ctx_.BindDeadlineAndMetadata(req_->deadline_, + &req_->request_metadata_); + req_->request_metadata_.count = 0; + + // Create a C++ Call to control the underlying core call + call_ = new (grpc_call_arena_alloc(req_->call_, sizeof(internal::Call))) + internal::Call( + req_->call_, req_->server_, req_->cq_, + req_->server_->max_receive_message_size(), + req_->ctx_.set_server_rpc_info( + req_->method_->name(), req_->server_->interceptor_creators_)); + + req_->interceptor_methods_.SetCall(call_); + req_->interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + req_->interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + req_->interceptor_methods_.SetRecvInitialMetadata( + &req_->ctx_.client_metadata_); + + if (req_->has_request_payload_) { + // Set interception point for RECV MESSAGE + req_->request_ = req_->method_->handler()->Deserialize( + req_->call_, req_->request_payload_, &req_->request_status_); + req_->request_payload_ = nullptr; + req_->interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + req_->interceptor_methods_.SetRecvMessage(req_->request_); + } + + if (req_->interceptor_methods_.RunInterceptors( + [this] { ContinueRunAfterInterception(); })) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } + void ContinueRunAfterInterception() { + // req_->ctx_.BeginCompletionOp(call_); + req_->method_->handler()->RunHandler( + internal::MethodHandler::HandlerParameter( + call_, &req_->ctx_, req_->request_, req_->request_status_, + [this] { + req_->Reset(); + req_->Request(); + })); + } + }; + + void Reset() { + Clear(); + Setup(); + } + + void Clear() { + if (call_details_) { + delete call_details_; + call_details_ = nullptr; + } + grpc_metadata_array_destroy(&request_metadata_); + if (has_request_payload_ && request_payload_) { + grpc_byte_buffer_destroy(request_payload_); + } + ctx_.Clear(); + interceptor_methods_.ClearState(); + } + + void Setup() { + grpc_metadata_array_init(&request_metadata_); + ctx_.Setup(gpr_inf_future(GPR_CLOCK_REALTIME)); + request_payload_ = nullptr; + request_ = nullptr; + request_status_ = Status(); + } + + Server* const server_; + internal::RpcServiceMethod* const method_; + void* const method_tag_; + const bool has_request_payload_; + grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; + grpc_call_details* call_details_ = nullptr; + grpc_call* call_; + gpr_timespec deadline_; + grpc_metadata_array request_metadata_; + CompletionQueue* cq_; + CallbackCallTag tag_; + ServerContext ctx_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; +}; + // Implementation of ThreadManager. Each instance of SyncRequestThreadManager // manages a pool of threads that poll for incoming Sync RPCs and call the // appropriate RPC handlers @@ -504,6 +677,9 @@ Server::Server( Server::~Server() { { std::unique_lock lock(mu_); + if (callback_cq_ != nullptr) { + callback_cq_->Shutdown(); + } if (started_ && !shutdown_) { lock.unlock(); Shutdown(); @@ -576,21 +752,28 @@ bool Server::RegisterService(const grpc::string* host, Service* service) { } internal::RpcServiceMethod* method = it->get(); - void* tag = grpc_server_register_method( + void* method_registration_tag = grpc_server_register_method( server_, method->name(), host ? host->c_str() : nullptr, PayloadHandlingForMethod(method), 0); - if (tag == nullptr) { + if (method_registration_tag == nullptr) { gpr_log(GPR_DEBUG, "Attempt to register %s multiple times", method->name()); return false; } - if (method->handler() == nullptr) { // Async method - method->set_server_tag(tag); - } else { + if (method->handler() == nullptr) { // Async method without handler + method->set_server_tag(method_registration_tag); + } else if (method->api_type() == + internal::RpcServiceMethod::ApiType::SYNC) { for (auto it = sync_req_mgrs_.begin(); it != sync_req_mgrs_.end(); it++) { - (*it)->AddSyncMethod(method, tag); + (*it)->AddSyncMethod(method, method_registration_tag); } + } else { + // a callback method + auto* req = new CallbackRequest(this, method, method_registration_tag); + callback_reqs_.emplace_back(req); + // Enqueue it so that it will be Request'ed later once + // all request matchers are created at core server startup } method_name = method->name(); @@ -641,7 +824,8 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { // performance. This ensures that we don't introduce thread hops // for application requests that wind up on this CQ, which is polled // in its own thread. - health_check_cq = new ServerCompletionQueue(GRPC_CQ_NON_POLLING); + health_check_cq = + new ServerCompletionQueue(GRPC_CQ_NEXT, GRPC_CQ_NON_POLLING, nullptr); grpc_server_register_completion_queue(server_, health_check_cq->cq(), nullptr); default_health_check_service_impl = @@ -678,6 +862,10 @@ void Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { (*it)->Start(); } + for (auto& cbreq : callback_reqs_) { + cbreq->Request(); + } + if (default_health_check_service_impl != nullptr) { default_health_check_service_impl->StartServingThread(); } @@ -910,4 +1098,41 @@ Server::UnimplementedAsyncResponse::UnimplementedAsyncResponse( ServerInitializer* Server::initializer() { return server_initializer_.get(); } +namespace { +class ShutdownCallback : public grpc_experimental_completion_queue_functor { + public: + ShutdownCallback() { functor_run = &ShutdownCallback::Run; } + // TakeCQ takes ownership of the cq into the shutdown callback + // so that the shutdown callback will be responsible for destroying it + void TakeCQ(CompletionQueue* cq) { cq_ = cq; } + + // The Run function will get invoked by the completion queue library + // when the shutdown is actually complete + static void Run(grpc_experimental_completion_queue_functor* cb, int) { + auto* callback = static_cast(cb); + delete callback->cq_; + grpc_core::Delete(callback); + } + + private: + CompletionQueue* cq_ = nullptr; +}; +} // namespace + +CompletionQueue* Server::CallbackCQ() { + // TODO(vjpai): Consider using a single global CQ for the default CQ + // if there is no explicit per-server CQ registered + std::lock_guard l(mu_); + if (callback_cq_ == nullptr) { + auto* shutdown_callback = grpc_core::New(); + callback_cq_ = new CompletionQueue(grpc_completion_queue_attributes{ + GRPC_CQ_CURRENT_VERSION, GRPC_CQ_CALLBACK, GRPC_CQ_DEFAULT_POLLING, + shutdown_callback}); + + // Transfer ownership of the new cq to its own shutdown callback + shutdown_callback->TakeCQ(callback_cq_); + } + return callback_cq_; +}; + } // namespace grpc diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 995e787785..51a2689c6a 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -209,31 +209,35 @@ bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { // ServerContext body -ServerContext::ServerContext() - : completion_op_(nullptr), - has_notify_when_done_tag_(false), - async_notify_when_done_tag_(nullptr), - deadline_(gpr_inf_future(GPR_CLOCK_REALTIME)), - call_(nullptr), - cq_(nullptr), - sent_initial_metadata_(false), - compression_level_set_(false), - has_pending_ops_(false) {} - -ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata_array* arr) - : completion_op_(nullptr), - has_notify_when_done_tag_(false), - async_notify_when_done_tag_(nullptr), - deadline_(deadline), - call_(nullptr), - cq_(nullptr), - sent_initial_metadata_(false), - compression_level_set_(false), - has_pending_ops_(false) { +ServerContext::ServerContext() { Setup(gpr_inf_future(GPR_CLOCK_REALTIME)); } + +ServerContext::ServerContext(gpr_timespec deadline, grpc_metadata_array* arr) { + Setup(deadline); std::swap(*client_metadata_.arr(), *arr); } -ServerContext::~ServerContext() { +void ServerContext::Setup(gpr_timespec deadline) { + completion_op_ = nullptr; + has_notify_when_done_tag_ = false; + async_notify_when_done_tag_ = nullptr; + deadline_ = deadline; + call_ = nullptr; + cq_ = nullptr; + sent_initial_metadata_ = false; + compression_level_set_ = false; + has_pending_ops_ = false; + rpc_info_ = nullptr; +} + +void ServerContext::BindDeadlineAndMetadata(gpr_timespec deadline, + grpc_metadata_array* arr) { + deadline_ = deadline; + std::swap(*client_metadata_.arr(), *arr); +} + +ServerContext::~ServerContext() { Clear(); } + +void ServerContext::Clear() { if (call_) { grpc_call_unref(call_); } @@ -243,6 +247,8 @@ ServerContext::~ServerContext() { if (rpc_info_) { rpc_info_->Unref(); } + // Don't need to clear out call_, completion_op_, or rpc_info_ because this is + // either called from destructor or just before Setup } void ServerContext::BeginCompletionOp(internal::Call* call) { diff --git a/test/cpp/codegen/compiler_test_golden b/test/cpp/codegen/compiler_test_golden index 93e1e68654..fdc67969d9 100644 --- a/test/cpp/codegen/compiler_test_golden +++ b/test/cpp/codegen/compiler_test_golden @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -308,6 +309,80 @@ class ServiceA final { }; typedef WithAsyncMethod_MethodA1 > > > AsyncService; template + class ExperimentalWithCallbackMethod_MethodA1 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithCallbackMethod_MethodA1() { + ::grpc::Service::experimental().MarkMethodCallback(0, + new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodA1, ::grpc::testing::Request, ::grpc::testing::Response>( + [this](::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + this->MethodA1(context, request, response, controller); + }, this)); + } + ~ExperimentalWithCallbackMethod_MethodA1() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + virtual void MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + }; + template + class ExperimentalWithCallbackMethod_MethodA2 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithCallbackMethod_MethodA2() { + } + ~ExperimentalWithCallbackMethod_MethodA2() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template + class ExperimentalWithCallbackMethod_MethodA3 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithCallbackMethod_MethodA3() { + } + ~ExperimentalWithCallbackMethod_MethodA3() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template + class ExperimentalWithCallbackMethod_MethodA4 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithCallbackMethod_MethodA4() { + } + ~ExperimentalWithCallbackMethod_MethodA4() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + typedef ExperimentalWithCallbackMethod_MethodA1 > > > ExperimentalCallbackService; + template class WithGenericMethod_MethodA1 : public BaseClass { private: void BaseClassMustBeDerivedFromService(const Service *service) {} @@ -456,6 +531,79 @@ class ServiceA final { } }; template + class ExperimentalWithRawCallbackMethod_MethodA1 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithRawCallbackMethod_MethodA1() { + ::grpc::Service::experimental().MarkMethodRawCallback(0, + new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodA1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this](::grpc::ServerContext* context, + const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + this->MethodA1(context, request, response, controller); + }, this)); + } + ~ExperimentalWithRawCallbackMethod_MethodA1() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + virtual void MethodA1(::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + }; + template + class ExperimentalWithRawCallbackMethod_MethodA2 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithRawCallbackMethod_MethodA2() { + } + ~ExperimentalWithRawCallbackMethod_MethodA2() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA2(::grpc::ServerContext* context, ::grpc::ServerReader< ::grpc::testing::Request>* reader, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template + class ExperimentalWithRawCallbackMethod_MethodA3 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithRawCallbackMethod_MethodA3() { + } + ~ExperimentalWithRawCallbackMethod_MethodA3() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA3(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::ServerWriter< ::grpc::testing::Response>* writer) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template + class ExperimentalWithRawCallbackMethod_MethodA4 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithRawCallbackMethod_MethodA4() { + } + ~ExperimentalWithRawCallbackMethod_MethodA4() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodA4(::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::grpc::testing::Response, ::grpc::testing::Request>* stream) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template class WithStreamedUnaryMethod_MethodA1 : public BaseClass { private: void BaseClassMustBeDerivedFromService(const Service *service) {} @@ -591,6 +739,32 @@ class ServiceB final { }; typedef WithAsyncMethod_MethodB1 AsyncService; template + class ExperimentalWithCallbackMethod_MethodB1 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithCallbackMethod_MethodB1() { + ::grpc::Service::experimental().MarkMethodCallback(0, + new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithCallbackMethod_MethodB1, ::grpc::testing::Request, ::grpc::testing::Response>( + [this](::grpc::ServerContext* context, + const ::grpc::testing::Request* request, + ::grpc::testing::Response* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + this->MethodB1(context, request, response, controller); + }, this)); + } + ~ExperimentalWithCallbackMethod_MethodB1() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + virtual void MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + }; + typedef ExperimentalWithCallbackMethod_MethodB1 ExperimentalCallbackService; + template class WithGenericMethod_MethodB1 : public BaseClass { private: void BaseClassMustBeDerivedFromService(const Service *service) {} @@ -628,6 +802,31 @@ class ServiceB final { } }; template + class ExperimentalWithRawCallbackMethod_MethodB1 : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + ExperimentalWithRawCallbackMethod_MethodB1() { + ::grpc::Service::experimental().MarkMethodRawCallback(0, + new ::grpc::internal::CallbackUnaryHandler< ExperimentalWithRawCallbackMethod_MethodB1, ::grpc::ByteBuffer, ::grpc::ByteBuffer>( + [this](::grpc::ServerContext* context, + const ::grpc::ByteBuffer* request, + ::grpc::ByteBuffer* response, + ::grpc::experimental::ServerCallbackRpcController* controller) { + this->MethodB1(context, request, response, controller); + }, this)); + } + ~ExperimentalWithRawCallbackMethod_MethodB1() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status MethodB1(::grpc::ServerContext* context, const ::grpc::testing::Request* request, ::grpc::testing::Response* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + virtual void MethodB1(::grpc::ServerContext* context, const ::grpc::ByteBuffer* request, ::grpc::ByteBuffer* response, ::grpc::experimental::ServerCallbackRpcController* controller) { controller->Finish(::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "")); } + }; + template class WithStreamedUnaryMethod_MethodB1 : public BaseClass { private: void BaseClassMustBeDerivedFromService(const Service *service) {} diff --git a/test/cpp/end2end/client_callback_end2end_test.cc b/test/cpp/end2end/client_callback_end2end_test.cc index 62a85641c7..7ffc610ce2 100644 --- a/test/cpp/end2end/client_callback_end2end_test.cc +++ b/test/cpp/end2end/client_callback_end2end_test.cc @@ -41,13 +41,38 @@ namespace grpc { namespace testing { namespace { -class ClientCallbackEnd2endTest : public ::testing::Test { +class TestScenario { + public: + TestScenario(bool serve_callback) : callback_server(serve_callback) {} + void Log() const; + bool callback_server; +}; + +static std::ostream& operator<<(std::ostream& out, + const TestScenario& scenario) { + return out << "TestScenario{callback_server=" + << (scenario.callback_server ? "true" : "false") << "}"; +} + +void TestScenario::Log() const { + std::ostringstream out; + out << *this; + gpr_log(GPR_DEBUG, "%s", out.str().c_str()); +} + +class ClientCallbackEnd2endTest + : public ::testing::TestWithParam { protected: - ClientCallbackEnd2endTest() {} + ClientCallbackEnd2endTest() { GetParam().Log(); } void SetUp() override { ServerBuilder builder; - builder.RegisterService(&service_); + + if (!GetParam().callback_server) { + builder.RegisterService(&service_); + } else { + builder.RegisterService(&callback_service_); + } server_ = builder.BuildAndStart(); is_server_started_ = true; @@ -151,37 +176,38 @@ class ClientCallbackEnd2endTest : public ::testing::Test { std::unique_ptr stub_; std::unique_ptr generic_stub_; TestServiceImpl service_; + CallbackTestServiceImpl callback_service_; std::unique_ptr server_; }; -TEST_F(ClientCallbackEnd2endTest, SimpleRpc) { +TEST_P(ClientCallbackEnd2endTest, SimpleRpc) { ResetStub(); SendRpcs(1, false); } -TEST_F(ClientCallbackEnd2endTest, SequentialRpcs) { +TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) { ResetStub(); SendRpcs(10, false); } -TEST_F(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) { +TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) { ResetStub(); SendRpcs(10, true); } -TEST_F(ClientCallbackEnd2endTest, SequentialGenericRpcs) { +TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) { ResetStub(); SendRpcsGeneric(10, false); } #if GRPC_ALLOW_EXCEPTIONS -TEST_F(ClientCallbackEnd2endTest, ExceptingRpc) { +TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) { ResetStub(); SendRpcsGeneric(10, true); } #endif -TEST_F(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { +TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { ResetStub(); std::vector threads; threads.reserve(10); @@ -193,7 +219,7 @@ TEST_F(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) { } } -TEST_F(ClientCallbackEnd2endTest, MultipleRpcs) { +TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) { ResetStub(); std::vector threads; threads.reserve(10); @@ -205,7 +231,7 @@ TEST_F(ClientCallbackEnd2endTest, MultipleRpcs) { } } -TEST_F(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { +TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { ResetStub(); EchoRequest request; EchoResponse response; @@ -230,6 +256,11 @@ TEST_F(ClientCallbackEnd2endTest, CancelRpcBeforeStart) { } } +TestScenario scenarios[] = {TestScenario{false}, TestScenario{true}}; + +INSTANTIATE_TEST_CASE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest, + ::testing::ValuesIn(scenarios)); + } // namespace } // namespace testing } // namespace grpc diff --git a/test/cpp/end2end/test_service_impl.cc b/test/cpp/end2end/test_service_impl.cc index 3c3a5d9cd4..605356724f 100644 --- a/test/cpp/end2end/test_service_impl.cc +++ b/test/cpp/end2end/test_service_impl.cc @@ -165,6 +165,138 @@ Status TestServiceImpl::Echo(ServerContext* context, const EchoRequest* request, return Status::OK; } +void CallbackTestServiceImpl::Echo( + ServerContext* context, const EchoRequest* request, EchoResponse* response, + experimental::ServerCallbackRpcController* controller) { + // A bit of sleep to make sure that short deadline tests fail + if (request->has_param() && request->param().server_sleep_us() > 0) { + // Set an alarm for that much time + alarm_.experimental().Set( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(request->param().server_sleep_us(), + GPR_TIMESPAN)), + [this, context, request, response, controller](bool) { + EchoNonDelayed(context, request, response, controller); + }); + } else { + EchoNonDelayed(context, request, response, controller); + } +} + +void CallbackTestServiceImpl::EchoNonDelayed( + ServerContext* context, const EchoRequest* request, EchoResponse* response, + experimental::ServerCallbackRpcController* controller) { + if (request->has_param() && request->param().server_die()) { + gpr_log(GPR_ERROR, "The request should not reach application handler."); + GPR_ASSERT(0); + } + if (request->has_param() && request->param().has_expected_error()) { + const auto& error = request->param().expected_error(); + controller->Finish(Status(static_cast(error.code()), + error.error_message(), + error.binary_error_details())); + } + int server_try_cancel = GetIntValueFromMetadata( + kServerTryCancelRequest, context->client_metadata(), DO_NOT_CANCEL); + if (server_try_cancel > DO_NOT_CANCEL) { + // Since this is a unary RPC, by the time this server handler is called, + // the 'request' message is already read from the client. So the scenarios + // in server_try_cancel don't make much sense. Just cancel the RPC as long + // as server_try_cancel is not DO_NOT_CANCEL + EXPECT_FALSE(context->IsCancelled()); + context->TryCancel(); + gpr_log(GPR_INFO, "Server called TryCancel() to cancel the request"); + // Now wait until it's really canceled + + std::function recurrence = [this, context, controller, + &recurrence](bool) { + if (!context->IsCancelled()) { + alarm_.experimental().Set( + gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(1000, GPR_TIMESPAN)), + recurrence); + } else { + controller->Finish(Status::CANCELLED); + } + }; + recurrence(true); + return; + } + + response->set_message(request->message()); + MaybeEchoDeadline(context, request, response); + if (host_) { + response->mutable_param()->set_host(*host_); + } + if (request->has_param() && request->param().client_cancel_after_us()) { + { + std::unique_lock lock(mu_); + signal_client_ = true; + } + std::function recurrence = [this, context, request, controller, + &recurrence](bool) { + if (!context->IsCancelled()) { + alarm_.experimental().Set( + gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(request->param().client_cancel_after_us(), + GPR_TIMESPAN)), + recurrence); + } else { + controller->Finish(Status::CANCELLED); + } + }; + recurrence(true); + return; + } else if (request->has_param() && + request->param().server_cancel_after_us()) { + alarm_.experimental().Set( + gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), + gpr_time_from_micros(request->param().client_cancel_after_us(), + GPR_TIMESPAN)), + [controller](bool) { controller->Finish(Status::CANCELLED); }); + return; + } else if (!request->has_param() || + !request->param().skip_cancelled_check()) { + EXPECT_FALSE(context->IsCancelled()); + } + + if (request->has_param() && request->param().echo_metadata()) { + const std::multimap& client_metadata = + context->client_metadata(); + for (std::multimap::const_iterator + iter = client_metadata.begin(); + iter != client_metadata.end(); ++iter) { + context->AddTrailingMetadata(ToString(iter->first), + ToString(iter->second)); + } + // Terminate rpc with error and debug info in trailer. + if (request->param().debug_info().stack_entries_size() || + !request->param().debug_info().detail().empty()) { + grpc::string serialized_debug_info = + request->param().debug_info().SerializeAsString(); + context->AddTrailingMetadata(kDebugInfoTrailerKey, serialized_debug_info); + controller->Finish(Status::CANCELLED); + } + } + if (request->has_param() && + (request->param().expected_client_identity().length() > 0 || + request->param().check_auth_context())) { + CheckServerAuthContext(context, + request->param().expected_transport_security_type(), + request->param().expected_client_identity()); + } + if (request->has_param() && request->param().response_message_length() > 0) { + response->set_message( + grpc::string(request->param().response_message_length(), '\0')); + } + if (request->has_param() && request->param().echo_peer()) { + response->mutable_param()->set_peer(context->peer()); + } + controller->Finish(Status::OK); +} + // Unimplemented is left unimplemented to test the returned error. Status TestServiceImpl::RequestStream(ServerContext* context, @@ -332,7 +464,8 @@ Status TestServiceImpl::BidiStream( return Status::OK; } -int TestServiceImpl::GetIntValueFromMetadata( +namespace { +int GetIntValueFromMetadataHelper( const char* key, const std::multimap& metadata, int default_value) { @@ -344,6 +477,21 @@ int TestServiceImpl::GetIntValueFromMetadata( return default_value; } +}; // namespace + +int TestServiceImpl::GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} + +int CallbackTestServiceImpl::GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value) { + return GetIntValueFromMetadataHelper(key, metadata, default_value); +} void TestServiceImpl::ServerTryCancel(ServerContext* context) { EXPECT_FALSE(context->IsCancelled()); diff --git a/test/cpp/end2end/test_service_impl.h b/test/cpp/end2end/test_service_impl.h index 052543a03e..ddfe94487e 100644 --- a/test/cpp/end2end/test_service_impl.h +++ b/test/cpp/end2end/test_service_impl.h @@ -22,6 +22,7 @@ #include #include +#include #include #include "src/proto/grpc/testing/echo.grpc.pb.h" @@ -78,7 +79,39 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { void ServerTryCancel(ServerContext* context); + bool signal_client_; + std::mutex mu_; + std::unique_ptr host_; +}; + +class CallbackTestServiceImpl + : public ::grpc::testing::EchoTestService::ExperimentalCallbackService { + public: + CallbackTestServiceImpl() : signal_client_(false), host_() {} + explicit CallbackTestServiceImpl(const grpc::string& host) + : signal_client_(false), host_(new grpc::string(host)) {} + + void Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response, + experimental::ServerCallbackRpcController* controller) override; + + // Unimplemented is left unimplemented to test the returned error. + bool signal_client() { + std::unique_lock lock(mu_); + return signal_client_; + } + private: + void EchoNonDelayed(ServerContext* context, const EchoRequest* request, + EchoResponse* response, + experimental::ServerCallbackRpcController* controller); + + int GetIntValueFromMetadata( + const char* key, + const std::multimap& metadata, + int default_value); + + Alarm alarm_; bool signal_client_; std::mutex mu_; std::unique_ptr host_; diff --git a/tools/doxygen/Doxyfile.c++ b/tools/doxygen/Doxyfile.c++ index f2bb5df7d3..392113c284 100644 --- a/tools/doxygen/Doxyfile.c++ +++ b/tools/doxygen/Doxyfile.c++ @@ -971,6 +971,7 @@ include/grpcpp/impl/codegen/rpc_method.h \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ +include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -1008,6 +1009,7 @@ include/grpcpp/support/client_callback.h \ include/grpcpp/support/config.h \ include/grpcpp/support/proto_buffer_reader.h \ include/grpcpp/support/proto_buffer_writer.h \ +include/grpcpp/support/server_callback.h \ include/grpcpp/support/slice.h \ include/grpcpp/support/status.h \ include/grpcpp/support/status_code_enum.h \ diff --git a/tools/doxygen/Doxyfile.c++.internal b/tools/doxygen/Doxyfile.c++.internal index 81d74cd972..33dd985914 100644 --- a/tools/doxygen/Doxyfile.c++.internal +++ b/tools/doxygen/Doxyfile.c++.internal @@ -973,6 +973,7 @@ include/grpcpp/impl/codegen/rpc_method.h \ include/grpcpp/impl/codegen/rpc_service_method.h \ include/grpcpp/impl/codegen/security/auth_context.h \ include/grpcpp/impl/codegen/serialization_traits.h \ +include/grpcpp/impl/codegen/server_callback.h \ include/grpcpp/impl/codegen/server_context.h \ include/grpcpp/impl/codegen/server_interceptor.h \ include/grpcpp/impl/codegen/server_interface.h \ @@ -1010,6 +1011,7 @@ include/grpcpp/support/client_callback.h \ include/grpcpp/support/config.h \ include/grpcpp/support/proto_buffer_reader.h \ include/grpcpp/support/proto_buffer_writer.h \ +include/grpcpp/support/server_callback.h \ include/grpcpp/support/slice.h \ include/grpcpp/support/status.h \ include/grpcpp/support/status_code_enum.h \ diff --git a/tools/run_tests/generated/sources_and_headers.json b/tools/run_tests/generated/sources_and_headers.json index 042856146d..33ac22e33c 100644 --- a/tools/run_tests/generated/sources_and_headers.json +++ b/tools/run_tests/generated/sources_and_headers.json @@ -11221,6 +11221,7 @@ "include/grpcpp/impl/codegen/rpc_service_method.h", "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", + "include/grpcpp/impl/codegen/server_callback.h", "include/grpcpp/impl/codegen/server_context.h", "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", @@ -11296,6 +11297,7 @@ "include/grpcpp/impl/codegen/rpc_service_method.h", "include/grpcpp/impl/codegen/security/auth_context.h", "include/grpcpp/impl/codegen/serialization_traits.h", + "include/grpcpp/impl/codegen/server_callback.h", "include/grpcpp/impl/codegen/server_context.h", "include/grpcpp/impl/codegen/server_interceptor.h", "include/grpcpp/impl/codegen/server_interface.h", @@ -11445,6 +11447,7 @@ "include/grpcpp/support/config.h", "include/grpcpp/support/proto_buffer_reader.h", "include/grpcpp/support/proto_buffer_writer.h", + "include/grpcpp/support/server_callback.h", "include/grpcpp/support/slice.h", "include/grpcpp/support/status.h", "include/grpcpp/support/status_code_enum.h", @@ -11549,6 +11552,7 @@ "include/grpcpp/support/config.h", "include/grpcpp/support/proto_buffer_reader.h", "include/grpcpp/support/proto_buffer_writer.h", + "include/grpcpp/support/server_callback.h", "include/grpcpp/support/slice.h", "include/grpcpp/support/status.h", "include/grpcpp/support/status_code_enum.h", -- cgit v1.2.3