aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/rpc/grpc_call.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-22 11:56:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-22 13:04:10 -0700
commitcd95f3a7d66833cdb55fa180d5002f2cf686bcef (patch)
tree19ceedf9585c1a9d5070b5b134ea1871e032a99f /tensorflow/core/distributed_runtime/rpc/grpc_call.h
parent317b6d0e3a66ea446e793eb65a397786e62f8d85 (diff)
Reduced some unnecessary abstraction layers in worker call sending path.
Also cleaned up a bunch of boilerplate. Reduced allocations in Call tag management by inlining Tags into Call. Reduces CPU usage for an RPC intensive benchmark by ~2% Change: 130971469
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc/grpc_call.h')
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_call.h93
1 files changed, 54 insertions, 39 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
index 3b6cbf44c2..70627973c7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "grpc++/grpc++.h"
+#include "grpc++/impl/codegen/service_type.h"
#include "grpc++/server_builder.h"
namespace tensorflow {
@@ -86,13 +87,6 @@ class UntypedCall : public core::RefCounted {
// otherwise false.
virtual void RequestReceived(Service* service, bool ok) = 0;
- // This method will be called when the response has been sent by
- // `service` and the call is no longer used.
- //
- // `ok` is true if the response sending completed as a "regular
- // event", otherwise it is false.
- void ResponseSent(Service* service, bool ok) {}
-
// This method will be called either (i) when the server is notified
// that the request has been cancelled, or (ii) when the request completes
// normally. The implementation should distinguish these cases by querying
@@ -100,27 +94,31 @@ class UntypedCall : public core::RefCounted {
virtual void RequestCancelled(Service* service, bool ok) = 0;
// Associates a tag in a `::grpc::CompletionQueue` with a callback
- // for an incoming RPC. A Tag owns a reference on the corresponding
+ // for an incoming RPC. An active Tag owns a reference on the corresponding
// Call object.
class Tag {
public:
- using Callback = void (UntypedCall::*)(Service*, bool);
-
- // Creates a new `Tag` for the given `UntypedCall`. When the
- // request associated with this tag is complete, `callback` will
- // be called.
- Tag(UntypedCall* call, Callback callback)
- : call_(call), callback_(callback) {
- call_->Ref();
- }
+ // One enum value per supported callback.
+ enum Callback { kRequestReceived, kResponseSent, kCancelled };
- ~Tag() { call_->Unref(); }
+ Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {}
// Calls the callback associated with this tag.
//
// The callback takes ownership of `this->call_`.
void OnCompleted(Service* service, bool ok) {
- (call_->*callback_)(service, ok);
+ switch (callback_) {
+ case kRequestReceived:
+ call_->RequestReceived(service, ok);
+ break;
+ case kResponseSent:
+ // No special handling needed apart from the Unref below.
+ break;
+ case kCancelled:
+ call_->RequestCancelled(service, ok);
+ break;
+ }
+ call_->Unref(); // Ref acquired when tag handed to grpc.
}
private:
@@ -161,9 +159,8 @@ class Call : public UntypedCall<Service> {
}
void SendResponse(::grpc::Status status) {
- responder_.Finish(response, status,
- new typename UntypedCall<Service>::Tag(
- this, &UntypedCall<Service>::ResponseSent));
+ this->Ref(); // Ref for grpc; released in Tag callback.
+ responder_.Finish(response, status, &response_sent_tag_);
this->Unref();
}
@@ -174,9 +171,6 @@ class Call : public UntypedCall<Service> {
cancel_callback_();
}
}
- // NOTE(mrry): This can be called before or after RequestReceived, so we
- // release `cancel_tag_` (in order to allow the event loop to free it).
- cancel_tag_.release();
}
// Registers `callback` as the function that should be called if and when this
@@ -208,11 +202,31 @@ class Call : public UntypedCall<Service> {
call->RegisterCancellationHandler();
}
- (grpc_service->*enqueue_function)(
- &call->ctx_, &call->request, &call->responder_, cq, cq,
- new typename UntypedCall<Service>::Tag(
- call, &UntypedCall<Service>::RequestReceived));
- call->Unref();
+ // Initial ref for call handed to grpc; released in Tag callback.
+ (grpc_service->*enqueue_function)(&call->ctx_, &call->request,
+ &call->responder_, cq, cq,
+ &call->request_received_tag_);
+ }
+
+ // Enqueues a new request for the given service on the given
+ // completion queue, using the given `method_id`.
+ //
+ // The request will be handled with the given
+ // `handle_request_function`.
+ static void EnqueueRequestForMethod(
+ GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq,
+ int method_id, HandleRequestFunction handle_request_function,
+ bool supports_cancel) {
+ auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
+ handle_request_function);
+ if (supports_cancel) {
+ call->RegisterCancellationHandler();
+ }
+
+ // Initial ref for call handed to grpc; released in Tag callback.
+ grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request,
+ &call->responder_, cq, cq,
+ &call->request_received_tag_);
}
RequestMessage request;
@@ -223,22 +237,23 @@ class Call : public UntypedCall<Service> {
// NOTE: This method must be called before this call is enqueued on a
// completion queue.
void RegisterCancellationHandler() {
- cancel_tag_.reset(new typename UntypedCall<Service>::Tag(
- this, &UntypedCall<Service>::RequestCancelled));
- ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
+ this->Ref(); // Ref for grpc; released in Tag callback.
+ ctx_.AsyncNotifyWhenDone(&cancelled_tag_);
}
HandleRequestFunction handle_request_function_;
::grpc::ServerContext ctx_;
::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
+
+ // Used as void* completion markers from grpc to indicate different
+ // events of interest for a Call.
+ using typename UntypedCall<Service>::Tag;
+ Tag request_received_tag_{this, Tag::kRequestReceived};
+ Tag response_sent_tag_{this, Tag::kResponseSent};
+ Tag cancelled_tag_{this, Tag::kCancelled};
+
mutex mu_;
std::function<void()> cancel_callback_ GUARDED_BY(mu_);
-
- // This tag is initially owned by `*this` and borrowed by
- // `ctx_->AsyncNotifyWhenDone()`. Ownership is transferred to the
- // appropriate service's completion queue after
- // `this->RequestReceived(..., true)` is called.
- std::unique_ptr<typename UntypedCall<Service>::Tag> cancel_tag_;
};
} // namespace tensorflow