diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-22 11:56:05 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-22 13:04:10 -0700 |
commit | cd95f3a7d66833cdb55fa180d5002f2cf686bcef (patch) | |
tree | 19ceedf9585c1a9d5070b5b134ea1871e032a99f /tensorflow/core/distributed_runtime/rpc/grpc_call.h | |
parent | 317b6d0e3a66ea446e793eb65a397786e62f8d85 (diff) |
Reduced some unnecessary abstraction layers in worker call sending path.
Also cleaned up a bunch of boilerplate.
Reduced allocations in Call tag management by inlining Tags into Call.
Reduces CPU usage for an RPC intensive benchmark by ~2%
Change: 130971469
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_call.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/grpc_call.h | 93 |
1 files changed, 54 insertions, 39 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h index 3b6cbf44c2..70627973c7 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "grpc++/grpc++.h" +#include "grpc++/impl/codegen/service_type.h" #include "grpc++/server_builder.h" namespace tensorflow { @@ -86,13 +87,6 @@ class UntypedCall : public core::RefCounted { // otherwise false. virtual void RequestReceived(Service* service, bool ok) = 0; - // This method will be called when the response has been sent by - // `service` and the call is no longer used. - // - // `ok` is true if the response sending completed as a "regular - // event", otherwise it is false. - void ResponseSent(Service* service, bool ok) {} - // This method will be called either (i) when the server is notified // that the request has been cancelled, or (ii) when the request completes // normally. The implementation should distinguish these cases by querying @@ -100,27 +94,31 @@ class UntypedCall : public core::RefCounted { virtual void RequestCancelled(Service* service, bool ok) = 0; // Associates a tag in a `::grpc::CompletionQueue` with a callback - // for an incoming RPC. A Tag owns a reference on the corresponding + // for an incoming RPC. An active Tag owns a reference on the corresponding // Call object. class Tag { public: - using Callback = void (UntypedCall::*)(Service*, bool); - - // Creates a new `Tag` for the given `UntypedCall`. When the - // request associated with this tag is complete, `callback` will - // be called. - Tag(UntypedCall* call, Callback callback) - : call_(call), callback_(callback) { - call_->Ref(); - } + // One enum value per supported callback. + enum Callback { kRequestReceived, kResponseSent, kCancelled }; - ~Tag() { call_->Unref(); } + Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {} // Calls the callback associated with this tag. // // The callback takes ownership of `this->call_`. void OnCompleted(Service* service, bool ok) { - (call_->*callback_)(service, ok); + switch (callback_) { + case kRequestReceived: + call_->RequestReceived(service, ok); + break; + case kResponseSent: + // No special handling needed apart from the Unref below. + break; + case kCancelled: + call_->RequestCancelled(service, ok); + break; + } + call_->Unref(); // Ref acquired when tag handed to grpc. } private: @@ -161,9 +159,8 @@ class Call : public UntypedCall<Service> { } void SendResponse(::grpc::Status status) { - responder_.Finish(response, status, - new typename UntypedCall<Service>::Tag( - this, &UntypedCall<Service>::ResponseSent)); + this->Ref(); // Ref for grpc; released in Tag callback. + responder_.Finish(response, status, &response_sent_tag_); this->Unref(); } @@ -174,9 +171,6 @@ class Call : public UntypedCall<Service> { cancel_callback_(); } } - // NOTE(mrry): This can be called before or after RequestReceived, so we - // release `cancel_tag_` (in order to allow the event loop to free it). - cancel_tag_.release(); } // Registers `callback` as the function that should be called if and when this @@ -208,11 +202,31 @@ class Call : public UntypedCall<Service> { call->RegisterCancellationHandler(); } - (grpc_service->*enqueue_function)( - &call->ctx_, &call->request, &call->responder_, cq, cq, - new typename UntypedCall<Service>::Tag( - call, &UntypedCall<Service>::RequestReceived)); - call->Unref(); + // Initial ref for call handed to grpc; released in Tag callback. + (grpc_service->*enqueue_function)(&call->ctx_, &call->request, + &call->responder_, cq, cq, + &call->request_received_tag_); + } + + // Enqueues a new request for the given service on the given + // completion queue, using the given `method_id`. + // + // The request will be handled with the given + // `handle_request_function`. + static void EnqueueRequestForMethod( + GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq, + int method_id, HandleRequestFunction handle_request_function, + bool supports_cancel) { + auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( + handle_request_function); + if (supports_cancel) { + call->RegisterCancellationHandler(); + } + + // Initial ref for call handed to grpc; released in Tag callback. + grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request, + &call->responder_, cq, cq, + &call->request_received_tag_); } RequestMessage request; @@ -223,22 +237,23 @@ class Call : public UntypedCall<Service> { // NOTE: This method must be called before this call is enqueued on a // completion queue. void RegisterCancellationHandler() { - cancel_tag_.reset(new typename UntypedCall<Service>::Tag( - this, &UntypedCall<Service>::RequestCancelled)); - ctx_.AsyncNotifyWhenDone(cancel_tag_.get()); + this->Ref(); // Ref for grpc; released in Tag callback. + ctx_.AsyncNotifyWhenDone(&cancelled_tag_); } HandleRequestFunction handle_request_function_; ::grpc::ServerContext ctx_; ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; + + // Used as void* completion markers from grpc to indicate different + // events of interest for a Call. + using typename UntypedCall<Service>::Tag; + Tag request_received_tag_{this, Tag::kRequestReceived}; + Tag response_sent_tag_{this, Tag::kResponseSent}; + Tag cancelled_tag_{this, Tag::kCancelled}; + mutex mu_; std::function<void()> cancel_callback_ GUARDED_BY(mu_); - - // This tag is initially owned by `*this` and borrowed by - // `ctx_->AsyncNotifyWhenDone()`. Ownership is transferred to the - // appropriate service's completion queue after - // `this->RequestReceived(..., true)` is called. - std::unique_ptr<typename UntypedCall<Service>::Tag> cancel_tag_; }; } // namespace tensorflow |