diff options
author | Yash Tibrewal <yashkt@google.com> | 2018-10-28 23:36:59 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-28 23:36:59 -0700 |
commit | 01313976e1a44b5c9625d3a349fffa55471beff4 (patch) | |
tree | 805596796cce33154e0d875c4c1ade918ba958f2 /src/cpp/server | |
parent | ffac9d90b18cb076b1c952faa55ce4e049cbc9a6 (diff) | |
parent | 395edbfa24968b8406a0c157874d3cb473076df5 (diff) |
Merge pull request #16842 from yashykt/interceptors
Experimental API for Client and Server Interception
Diffstat (limited to 'src/cpp/server')
-rw-r--r-- | src/cpp/server/server_builder.cc | 3 | ||||
-rw-r--r-- | src/cpp/server/server_cc.cc | 217 | ||||
-rw-r--r-- | src/cpp/server/server_context.cc | 131 |
3 files changed, 270 insertions, 81 deletions
diff --git a/src/cpp/server/server_builder.cc b/src/cpp/server/server_builder.cc index 8417c45e64..fc42b6c886 100644 --- a/src/cpp/server/server_builder.cc +++ b/src/cpp/server/server_builder.cc @@ -263,7 +263,8 @@ std::unique_ptr<Server> ServerBuilder::BuildAndStart() { std::unique_ptr<Server> server(new Server( max_receive_message_size_, &args, sync_server_cqs, sync_server_settings_.min_pollers, sync_server_settings_.max_pollers, - sync_server_settings_.cq_timeout_msec, resource_quota_)); + sync_server_settings_.cq_timeout_msec, resource_quota_, + std::move(interceptor_creators_))); if (has_sync_methods) { // This is a Sync server diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7aeddff643..59a531e272 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -27,7 +27,9 @@ #include <grpcpp/completion_queue.h> #include <grpcpp/generic/async_generic_service.h> #include <grpcpp/impl/codegen/async_unary_call.h> +#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/completion_queue_tag.h> +#include <grpcpp/impl/codegen/server_interceptor.h> #include <grpcpp/impl/grpc_library.h> #include <grpcpp/impl/method_handler_impl.h> #include <grpcpp/impl/rpc_service_method.h> @@ -38,8 +40,10 @@ #include <grpcpp/support/time.h> #include "src/core/ext/transport/inproc/inproc_transport.h" +#include "src/core/lib/iomgr/exec_ctx.h" #include "src/core/lib/profiling/timers.h" #include "src/core/lib/surface/call.h" +#include "src/core/lib/surface/completion_queue.h" #include "src/cpp/client/create_channel_internal.h" #include "src/cpp/server/health/default_health_check_service.h" #include "src/cpp/thread_manager/thread_manager.h" @@ -127,10 +131,13 @@ class Server::UnimplementedAsyncResponse final ~UnimplementedAsyncResponse() { delete request_; } bool FinalizeResult(void** tag, bool* status) override { - internal::CallOpSet< - internal::CallOpSendInitialMetadata, - internal::CallOpServerSendStatus>::FinalizeResult(tag, status); - delete this; + if (internal::CallOpSet< + internal::CallOpSendInitialMetadata, + internal::CallOpServerSendStatus>::FinalizeResult(tag, status)) { + delete this; + } else { + // The tag was swallowed due to interception. We will see it again. + } return false; } @@ -208,13 +215,18 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { public: explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size()), ctx_(mrd->deadline_, &mrd->request_metadata_), has_request_payload_(mrd->has_request_payload_), request_payload_(has_request_payload_ ? mrd->request_payload_ : nullptr), + request_(nullptr), method_(mrd->method_), - server_(server) { + call_(mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), + server->interceptor_creators_)), + server_(server), + global_callbacks_(nullptr), + resources_(false) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -230,33 +242,73 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Run(const std::shared_ptr<GlobalCallbacks>& global_callbacks, bool resources) { - ctx_.BeginCompletionOp(&call_); - global_callbacks->PreSynchronousRequest(&ctx_); - auto* handler = resources ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_payload_)); - global_callbacks->PostSynchronousRequest(&ctx_); - request_payload_ = nullptr; - - cq_.Shutdown(); + global_callbacks_ = global_callbacks; + resources_ = resources; + + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); + + if (has_request_payload_) { + // Set interception point for RECV MESSAGE + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + request_ = handler->Deserialize(request_payload_, &request_status_); + + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); + } - internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); - cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + auto f = std::bind(&CallData::ContinueRunAfterInterception, this); + if (interceptor_methods_.RunInterceptors(f)) { + ContinueRunAfterInterception(); + } else { + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. + } + } - /* Ensure the cq_ is shutdown */ - DummyTag ignored_tag; - GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + void ContinueRunAfterInterception() { + { + ctx_.BeginCompletionOp(&call_); + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( + &call_, &ctx_, request_, request_status_)); + request_ = nullptr; + global_callbacks_->PostSynchronousRequest(&ctx_); + + cq_.Shutdown(); + + internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + /* Ensure the cq_ is shutdown */ + DummyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + } + delete this; } private: CompletionQueue cq_; - internal::Call call_; ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; internal::RpcServiceMethod* const method_; + internal::Call call_; Server* server_; + std::shared_ptr<GlobalCallbacks> global_callbacks_; + bool resources_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; private: @@ -318,8 +370,9 @@ class Server::SyncRequestThreadManager : public ThreadManager { } if (ok) { - // Calldata takes ownership of the completion queue inside sync_req - SyncRequest::CallData cd(server_, sync_req); + // Calldata takes ownership of the completion queue and interceptors + // inside sync_req + auto* cd = new SyncRequest::CallData(server_, sync_req); // Prepare for the next request if (!IsShutdown()) { sync_req->SetupRequest(); // Create new completion queue for sync_req @@ -327,7 +380,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_, resources); + cd->Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -389,7 +442,10 @@ Server::Server( std::shared_ptr<std::vector<std::unique_ptr<ServerCompletionQueue>>> sync_server_cqs, int min_pollers, int max_pollers, int sync_cq_timeout_msec, - grpc_resource_quota* server_rq) + grpc_resource_quota* server_rq, + std::vector< + std::unique_ptr<experimental::ServerInterceptorFactoryInterface>> + interceptor_creators) : max_receive_message_size_(max_receive_message_size), sync_server_cqs_(std::move(sync_server_cqs)), started_(false), @@ -398,7 +454,8 @@ Server::Server( has_generic_service_(false), server_(nullptr), server_initializer_(new ServerInitializer(this)), - health_check_service_disabled_(false) { + health_check_service_disabled_(false), + interceptor_creators_(std::move(interceptor_creators)) { g_gli_initializer.summon(); gpr_once_init(&g_once_init_callbacks, InitGlobalCallbacks); global_callbacks_ = g_callbacks; @@ -681,31 +738,27 @@ void Server::Wait() { void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, internal::Call* call) { - static const size_t MAX_OPS = 8; - size_t nops = 0; - grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); - auto result = - grpc_call_start_batch(call->call(), cops, nops, ops->cq_tag(), nullptr); - if (result != GRPC_CALL_OK) { - gpr_log(GPR_ERROR, "Fatal: grpc_call_start_batch returned %d", result); - grpc_call_log_batch(__FILE__, __LINE__, GPR_LOG_SEVERITY_ERROR, - call->call(), cops, nops, ops); - abort(); - } + ops->FillOps(call); } ServerInterface::BaseAsyncRequest::BaseAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag, bool delete_on_finalize) + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) : server_(server), context_(context), stream_(stream), call_cq_(call_cq), + notification_cq_(notification_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), - call_(nullptr) { + call_(nullptr), + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops } @@ -715,15 +768,45 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (done_intercepting_) { + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } context_->set_call(call_); context_->cq_ = call_cq_; - internal::Call call(call_, server_, call_cq_, - server_->max_receive_message_size()); - if (*status && call_) { - context_->BeginCompletionOp(&call); + if (call_wrapper_.call() == nullptr) { + // Fill it since it is empty. + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); } + // just the pointers inside call are copied here - stream_->BindCall(&call); + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + // Set interception point for RECV INITIAL METADATA + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + auto f = std::bind(&ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception, + this); + if (interceptor_methods_.RunInterceptors(f)) { + // There are no interceptors to run. Continue + } else { + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. + return false; + } + } + if (*status && call_) { + context_->BeginCompletionOp(&call_wrapper_); + } *tag = tag_; if (delete_on_finalize_) { delete this; @@ -731,11 +814,25 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, return true; } +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_); + // Queue a tag which will be returned immediately + grpc_core::ExecCtx exec_ctx; + grpc_cq_begin_op(notification_cq_->cq(), this); + grpc_cq_end_op( + notification_cq_->cq(), this, GRPC_ERROR_NONE, + [](void* arg, grpc_cq_completion* completion) { delete completion; }, + nullptr, new grpc_cq_completion()); +} + ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerCompletionQueue* notification_cq, void* tag, const char* name) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -751,7 +848,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( ServerInterface* server, GenericServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) - : BaseAsyncRequest(server, context, stream, call_cq, tag, + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, delete_on_finalize) { grpc_call_details_init(&call_details_); GPR_ASSERT(notification_cq); @@ -764,6 +861,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + // If we are done intercepting, there is nothing more for us to do + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } // TODO(yangg) remove the copy here. if (*status) { static_cast<GenericServerContext*>(context_)->method_ = @@ -774,16 +875,26 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, } grpc_slice_unref(call_details_.method); grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info( + static_cast<GenericServerContext*>(context_)->method_.c_str(), + *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { - if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) { - new UnimplementedAsyncRequest(server_, cq_); - new UnimplementedAsyncResponse(this); + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + // We either had no interceptors run or we are done intercepting + if (*status) { + new UnimplementedAsyncRequest(server_, cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } } else { - delete this; + // The tag was swallowed due to interception. We will see it again. } return false; } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index bd532a968d..995e787785 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -41,13 +41,22 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { public: // initial refs: one in the server context, one in the cq // must ref the call before calling constructor and after deleting this - CompletionOp(grpc_call* call) - : call_(call), + CompletionOp(internal::Call* call) + : call_(*call), has_tag_(false), tag_(nullptr), refs_(2), finalized_(false), - cancelled_(0) {} + cancelled_(0), + done_intercepting_(false) {} + + ~CompletionOp() { + if (call_.server_rpc_info()) { + call_.server_rpc_info()->Unref(); + } + } + + void FillOps(internal::Call* call) override; // This should always be arena allocated in the call, so override delete. // But this class is not trivially destructible, so must actually call delete @@ -63,7 +72,6 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { // there are no tests catching the compiler warning. static void operator delete(void*, void*) { assert(0); } - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -82,58 +90,121 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { void Unref(); + // This will be called while interceptors are run if the RPC is a hijacked + // RPC. This should set hijacking state for each of the ops. + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_CODEGEN_ASSERT(false); + } + + /* Should be called after interceptors are done running */ + void ContinueFillOpsAfterInterception() override {} + + /* Should be called after interceptors are done running on the finalize result + * path */ + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + /* We don't have a tag to return. */ + std::unique_lock<std::mutex> lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return; + } + /* Start a dummy op so that we can return the tag */ + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, this, nullptr)); + } + private: bool CheckCancelledNoPluck() { std::lock_guard<std::mutex> g(mu_); return finalized_ ? (cancelled_ != 0) : false; } - grpc_call* call_; + internal::Call call_; bool has_tag_; void* tag_; std::mutex mu_; int refs_; bool finalized_; int cancelled_; + bool done_intercepting_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; void ServerContext::CompletionOp::Unref() { std::unique_lock<std::mutex> lock(mu_); if (--refs_ == 0) { lock.unlock(); - // Save aside the call pointer before deleting for later unref - grpc_call* call = call_; + grpc_call* call = call_.call(); delete this; grpc_call_unref(call); } } -void ServerContext::CompletionOp::FillOps(grpc_call* call, grpc_op* ops, - size_t* nops) { - ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; - ops->data.recv_close_on_server.cancelled = &cancelled_; - ops->flags = 0; - ops->reserved = nullptr; - *nops = 1; +void ServerContext::CompletionOp::FillOps(internal::Call* call) { + grpc_op ops; + ops.op = GRPC_OP_RECV_CLOSE_ON_SERVER; + ops.data.recv_close_on_server.cancelled = &cancelled_; + ops.flags = 0; + ops.reserved = nullptr; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); + GPR_ASSERT(GRPC_CALL_OK == + grpc_call_start_batch(call->call(), &ops, 1, this, nullptr)); + /* No interceptors to run here */ } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { - std::unique_lock<std::mutex> lock(mu_); - finalized_ = true; bool ret = false; - if (has_tag_) { - *tag = tag_; - ret = true; + std::unique_lock<std::mutex> lock(mu_); + if (done_intercepting_) { + /* We are done intercepting. */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } + finalized_ = true; + if (!*status) cancelled_ = 1; - if (--refs_ == 0) { - lock.unlock(); - // Save aside the call pointer before deleting for later unref - grpc_call* call = call_; - delete this; - grpc_call_unref(call); + /* Release the lock since we are going to be running through interceptors now + */ + lock.unlock(); + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + /* No interceptors were run */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + lock.lock(); + if (--refs_ == 0) { + lock.unlock(); + grpc_call* call = call_.call(); + delete this; + grpc_call_unref(call); + } + return ret; } - return ret; + /* There are interceptors to be run. Return false for now */ + return false; } // ServerContext body @@ -169,14 +240,20 @@ ServerContext::~ServerContext() { if (completion_op_) { completion_op_->Unref(); } + if (rpc_info_) { + rpc_info_->Unref(); + } } void ServerContext::BeginCompletionOp(internal::Call* call) { GPR_ASSERT(!completion_op_); + if (rpc_info_) { + rpc_info_->Ref(); + } grpc_call_ref(call->call()); completion_op_ = new (grpc_call_arena_alloc(call->call(), sizeof(CompletionOp))) - CompletionOp(call->call()); + CompletionOp(call); if (has_notify_when_done_tag_) { completion_op_->set_tag(async_notify_when_done_tag_); } |