diff options
author | Yash Tibrewal <yashkt@google.com> | 2018-10-18 22:43:49 -0700 |
---|---|---|
committer | Yash Tibrewal <yashkt@google.com> | 2018-10-18 22:43:49 -0700 |
commit | 456231b26d56e13b9a56b93baabede4dd8fc2519 (patch) | |
tree | ac29ea953b7e8209c3e77b3b768991c69f463199 /include | |
parent | adca91f6cfe57cbd4af1e5a8cc8bfe3b506445c5 (diff) |
Server side interception for CompletionOp and AsyncRequest
Diffstat (limited to 'include')
-rw-r--r-- | include/grpcpp/impl/codegen/call.h | 12 | ||||
-rw-r--r-- | include/grpcpp/impl/codegen/server_interface.h | 85 | ||||
-rw-r--r-- | include/grpcpp/server.h | 4 |
3 files changed, 77 insertions, 24 deletions
diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index d591144798..167db8d4dd 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -1004,9 +1004,17 @@ class InterceptorBatchMethodsImpl /* Returns true if no interceptors are run. Returns false otherwise if there are interceptors registered. After the interceptors are done running \a f will be invoked. This is to be used only by BaseAsyncRequest and SyncRequest. */ - bool RunInterceptors(std::function<void(internal::CompletionQueueTag*)> f) { + bool RunInterceptors(std::function<void(void)> f) { GPR_CODEGEN_ASSERT(reverse_ == true); - return true; + GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + callback_ = std::move(f); + RunServerInterceptors(); + return false; } private: diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 9310ccb39f..c83d4aa194 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -20,12 +20,14 @@ #define GRPCPP_IMPL_CODEGEN_SERVER_INTERFACE_H #include <grpc/impl/codegen/grpc_types.h> +//#include <grpcpp/alarm.h> #include <grpcpp/impl/codegen/byte_buffer.h> #include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/call_hook.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> #include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/rpc_service_method.h> +#include <grpcpp/impl/codegen/server_context.h> namespace grpc { @@ -149,45 +151,69 @@ class ServerInterface : public internal::CallHook { public: BaseAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize); virtual ~BaseAsyncRequest(); bool FinalizeResult(void** tag, bool* status) override; + private: + void ContinueFinalizeResultAfterInterception(); + protected: ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; const bool delete_on_finalize_; grpc_call* call_; - internal::InterceptorBatchMethodsImpl interceptor_methods; + internal::Call call_wrapper_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; + bool done_intercepting_; + void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on + alarm.h here */ }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag); - - // uses BaseAsyncRequest::FinalizeResult + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, + const char* name); + + virtual bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(experimental::ServerRpcInfo( + context_, name_, *server_->interceptor_creators()))); + return BaseAsyncRequest::FinalizeResult(tag, status); + } protected: void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); + const char* name_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { - IssueRequest(registered_method, nullptr, notification_cq); + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()) { + IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } // uses RegisteredAsyncRequest::FinalizeResult @@ -196,13 +222,15 @@ class ServerInterface : public internal::CallHook { template <class Message> class PayloadAsyncRequest final : public RegisteredAsyncRequest { public: - PayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + PayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag), + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()), registered_method_(registered_method), server_(server), context_(context), @@ -211,7 +239,8 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { - IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq); + IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), + notification_cq); } ~PayloadAsyncRequest() { @@ -219,6 +248,10 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return RegisteredAsyncRequest::FinalizeResult(tag, status); + } if (*status) { if (!payload_.Valid() || !SerializationTraits<Message>::Deserialize( payload_.bbuf_ptr(), request_) @@ -237,15 +270,24 @@ class ServerInterface : public internal::CallHook { return false; } } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(experimental::ServerRpcInfo( + context_, name_, *server_->interceptor_creators()))); + /* Set interception point for recv message */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); return RegisteredAsyncRequest::FinalizeResult(tag, status); } private: - void* const registered_method_; + internal::RpcServiceMethod* const registered_method_; ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; Message* const request_; @@ -274,9 +316,8 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); - new PayloadAsyncRequest<Message>(method->server_tag(), this, context, - stream, call_cq, notification_cq, tag, - message); + new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq, + notification_cq, tag, message); } void RequestAsyncCall(internal::RpcServiceMethod* method, @@ -285,8 +326,8 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); - new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, - call_cq, notification_cq, tag); + new NoPayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag); } void RequestAsyncGenericCall(GenericServerContext* context, @@ -298,8 +339,10 @@ class ServerInterface : public internal::CallHook { tag, true); } -private: - virtual const std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* interceptor_creators() { + private: + virtual const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() { return nullptr; } }; diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index a593b60550..2b89ffd317 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -191,7 +191,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: - const std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* interceptor_creators() override { + const std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* + interceptor_creators() override { return &interceptor_creators_; } |