diff options
author | 2016-08-22 11:56:05 -0800 | |
---|---|---|
committer | 2016-08-22 13:04:10 -0700 | |
commit | cd95f3a7d66833cdb55fa180d5002f2cf686bcef (patch) | |
tree | 19ceedf9585c1a9d5070b5b134ea1871e032a99f | |
parent | 317b6d0e3a66ea446e793eb65a397786e62f8d85 (diff) |
Reduced some unnecessary abstraction layers in worker call sending path.
Also cleaned up a bunch of boilerplate.
Reduced allocations in Call tag management by inlining Tags into Call.
Reduces CPU usage for an RPC intensive benchmark by ~2%
Change: 130971469
8 files changed, 205 insertions, 591 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h index 3b6cbf44c2..70627973c7 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "grpc++/grpc++.h" +#include "grpc++/impl/codegen/service_type.h" #include "grpc++/server_builder.h" namespace tensorflow { @@ -86,13 +87,6 @@ class UntypedCall : public core::RefCounted { // otherwise false. virtual void RequestReceived(Service* service, bool ok) = 0; - // This method will be called when the response has been sent by - // `service` and the call is no longer used. - // - // `ok` is true if the response sending completed as a "regular - // event", otherwise it is false. - void ResponseSent(Service* service, bool ok) {} - // This method will be called either (i) when the server is notified // that the request has been cancelled, or (ii) when the request completes // normally. The implementation should distinguish these cases by querying @@ -100,27 +94,31 @@ class UntypedCall : public core::RefCounted { virtual void RequestCancelled(Service* service, bool ok) = 0; // Associates a tag in a `::grpc::CompletionQueue` with a callback - // for an incoming RPC. A Tag owns a reference on the corresponding + // for an incoming RPC. An active Tag owns a reference on the corresponding // Call object. class Tag { public: - using Callback = void (UntypedCall::*)(Service*, bool); - - // Creates a new `Tag` for the given `UntypedCall`. When the - // request associated with this tag is complete, `callback` will - // be called. - Tag(UntypedCall* call, Callback callback) - : call_(call), callback_(callback) { - call_->Ref(); - } + // One enum value per supported callback. + enum Callback { kRequestReceived, kResponseSent, kCancelled }; - ~Tag() { call_->Unref(); } + Tag(UntypedCall* call, Callback cb) : call_(call), callback_(cb) {} // Calls the callback associated with this tag. // // The callback takes ownership of `this->call_`. void OnCompleted(Service* service, bool ok) { - (call_->*callback_)(service, ok); + switch (callback_) { + case kRequestReceived: + call_->RequestReceived(service, ok); + break; + case kResponseSent: + // No special handling needed apart from the Unref below. + break; + case kCancelled: + call_->RequestCancelled(service, ok); + break; + } + call_->Unref(); // Ref acquired when tag handed to grpc. } private: @@ -161,9 +159,8 @@ class Call : public UntypedCall<Service> { } void SendResponse(::grpc::Status status) { - responder_.Finish(response, status, - new typename UntypedCall<Service>::Tag( - this, &UntypedCall<Service>::ResponseSent)); + this->Ref(); // Ref for grpc; released in Tag callback. + responder_.Finish(response, status, &response_sent_tag_); this->Unref(); } @@ -174,9 +171,6 @@ class Call : public UntypedCall<Service> { cancel_callback_(); } } - // NOTE(mrry): This can be called before or after RequestReceived, so we - // release `cancel_tag_` (in order to allow the event loop to free it). - cancel_tag_.release(); } // Registers `callback` as the function that should be called if and when this @@ -208,11 +202,31 @@ class Call : public UntypedCall<Service> { call->RegisterCancellationHandler(); } - (grpc_service->*enqueue_function)( - &call->ctx_, &call->request, &call->responder_, cq, cq, - new typename UntypedCall<Service>::Tag( - call, &UntypedCall<Service>::RequestReceived)); - call->Unref(); + // Initial ref for call handed to grpc; released in Tag callback. + (grpc_service->*enqueue_function)(&call->ctx_, &call->request, + &call->responder_, cq, cq, + &call->request_received_tag_); + } + + // Enqueues a new request for the given service on the given + // completion queue, using the given `method_id`. + // + // The request will be handled with the given + // `handle_request_function`. + static void EnqueueRequestForMethod( + GrpcService* grpc_service, ::grpc::ServerCompletionQueue* cq, + int method_id, HandleRequestFunction handle_request_function, + bool supports_cancel) { + auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( + handle_request_function); + if (supports_cancel) { + call->RegisterCancellationHandler(); + } + + // Initial ref for call handed to grpc; released in Tag callback. + grpc_service->RequestAsyncUnary(method_id, &call->ctx_, &call->request, + &call->responder_, cq, cq, + &call->request_received_tag_); } RequestMessage request; @@ -223,22 +237,23 @@ class Call : public UntypedCall<Service> { // NOTE: This method must be called before this call is enqueued on a // completion queue. void RegisterCancellationHandler() { - cancel_tag_.reset(new typename UntypedCall<Service>::Tag( - this, &UntypedCall<Service>::RequestCancelled)); - ctx_.AsyncNotifyWhenDone(cancel_tag_.get()); + this->Ref(); // Ref for grpc; released in Tag callback. + ctx_.AsyncNotifyWhenDone(&cancelled_tag_); } HandleRequestFunction handle_request_function_; ::grpc::ServerContext ctx_; ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; + + // Used as void* completion markers from grpc to indicate different + // events of interest for a Call. + using typename UntypedCall<Service>::Tag; + Tag request_received_tag_{this, Tag::kRequestReceived}; + Tag response_sent_tag_{this, Tag::kResponseSent}; + Tag cancelled_tag_{this, Tag::kCancelled}; + mutex mu_; std::function<void()> cancel_callback_ GUARDED_BY(mu_); - - // This tag is initially owned by `*this` and borrowed by - // `ctx_->AsyncNotifyWhenDone()`. Ownership is transferred to the - // appropriate service's completion queue after - // `this->RequestReceived(..., true)` is called. - std::unique_ptr<typename UntypedCall<Service>::Tag> cancel_tag_; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h index b305ab44fe..95c2c935f0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h @@ -28,26 +28,14 @@ namespace tensorflow { // stored in a `grpc::CompletionQueue`. class GrpcClientCQTag { public: - GrpcClientCQTag(::grpc::ClientContext* context, StatusCallback cb) - : context_(context), cb_(std::move(cb)) {} - ~GrpcClientCQTag() { delete context_; } + GrpcClientCQTag() {} + virtual ~GrpcClientCQTag() {} - void OnCompleted(bool ok) { - if (!ok) { - VLOG(2) << "Call returned with non-ok status: " - << status_.error_message(); - } - cb_(FromGrpcStatus(status_)); - } - - ::grpc::ClientContext* context() { return context_; } - ::grpc::Status* status() { return &status_; } + // OnCompleted is invoked when the RPC has finished. + // Implementations of OnCompleted must delete *this. + virtual void OnCompleted(bool ok) = 0; private: - ::grpc::ClientContext* context_; - ::grpc::Status status_; - StatusCallback cb_; - TF_DISALLOW_COPY_AND_ASSIGN(GrpcClientCQTag); }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 9823980e83..c8a0892842 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -118,7 +118,6 @@ class GrpcMasterService : public AsyncServiceInterface { static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag); if (callback_tag) { callback_tag->OnCompleted(this, ok); - delete callback_tag; } else { // NOTE(mrry): A null `callback_tag` indicates that this is // the shutdown alarm. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 14ad1e0355..79d3b3e2f6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -37,8 +37,17 @@ class GrpcRemoteWorker : public WorkerInterface { explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) - : stub_(grpc::WorkerService::NewStub(channel)), + : channel_(channel), cq_(completion_queue), + getstatus_(Method(GrpcWorkerMethod::kGetStatus)), + registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)), + deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)), + rungraph_(Method(GrpcWorkerMethod::kRunGraph)), + cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)), + cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)), + recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)), + logging_(Method(GrpcWorkerMethod::kLogging)), + tracing_(Method(GrpcWorkerMethod::kTracing)), logger_(logger) {} ~GrpcRemoteWorker() override {} @@ -46,45 +55,36 @@ class GrpcRemoteWorker : public WorkerInterface { void GetStatusAsync(const GetStatusRequest* request, GetStatusResponse* response, StatusCallback done) override { - IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncGetStatus, - std::move(done)); + IssueRequest(request, response, getstatus_, std::move(done)); } void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override { - IssueRequest(request, response, - &grpc::WorkerService::Stub::AsyncRegisterGraph, - std::move(done)); + IssueRequest(request, response, registergraph_, std::move(done)); } void DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) override { - IssueRequest(request, response, - &grpc::WorkerService::Stub::AsyncDeregisterGraph, - std::move(done)); + IssueRequest(request, response, deregistergraph_, std::move(done)); } void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request, RunGraphResponse* response, StatusCallback done) override { - IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph, - std::move(done), call_opts); + IssueRequest(request, response, rungraph_, std::move(done), call_opts); } void CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) override { - IssueRequest(request, response, - &grpc::WorkerService::Stub::AsyncCleanupGraph, - std::move(done)); + IssueRequest(request, response, cleanupgraph_, std::move(done)); } void CleanupAllAsync(const CleanupAllRequest* request, CleanupAllResponse* response, StatusCallback done) override { - IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncCleanupAll, - std::move(done)); + IssueRequest(request, response, cleanupall_, std::move(done)); } void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, @@ -156,59 +156,98 @@ class GrpcRemoteWorker : public WorkerInterface { cb_to_use = &wrapper_done; } - IssueRequest(req_copy ? req_copy : request, response, - &grpc::WorkerService::Stub::AsyncRecvTensor, + IssueRequest(req_copy ? req_copy : request, response, recvtensor_, std::move(*cb_to_use), call_opts); } void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done) override { - IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncLogging, - done); + IssueRequest(request, response, logging_, done); } void TracingAsync(const TracingRequest* request, TracingResponse* response, StatusCallback done) override { - IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncTracing, - done); + IssueRequest(request, response, tracing_, done); } private: + // Object allocated per active RPC. template <class RequestMessage, class ResponseMessage> - using AsyncMethod = - std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseMessage>> ( - grpc::WorkerService::Stub::*)(::grpc::ClientContext*, - const RequestMessage&, - ::grpc::CompletionQueue*); + class RPCState final : public GrpcClientCQTag { + public: + RPCState(::grpc::ChannelInterface* channel, ::grpc::CompletionQueue* cq, + const ::grpc::RpcMethod& method, const RequestMessage& request, + StatusCallback done, CallOptions* call_opts) + : call_opts_(call_opts), + reader_(channel, cq, method, InitContext(call_opts), request), + done_(std::move(done)) {} + + ~RPCState() override {} + + void StartRPC(ResponseMessage* response) { + reader_.Finish(response, &status_, this); + } + + void OnCompleted(bool ok) override { + if (!ok) { + VLOG(2) << "Call returned with non-ok status: " + << status_.error_message(); + } + if (call_opts_) { + call_opts_->ClearCancelCallback(); + } + done_(FromGrpcStatus(status_)); + delete this; + } + + private: + CallOptions* call_opts_; + ::grpc::ClientContext context_; + ::grpc::ClientAsyncResponseReader<ResponseMessage> reader_; + ::grpc::Status status_; + StatusCallback done_; + + ::grpc::ClientContext* InitContext(CallOptions* call_opts) { + // The initialization and recovery protocols rely on blocking + // until we get a response. + context_.set_fail_fast(false); + if (call_opts) { + call_opts->SetCancelCallback([this]() { context_.TryCancel(); }); + } + return &context_; + } + }; // Utility method for issuing a generic asynchronous request. The // given callback, `done`, will be called when the RPC completes. template <class RequestMessage, class ResponseMessage> void IssueRequest(const RequestMessage* request, ResponseMessage* response, - AsyncMethod<RequestMessage, ResponseMessage> async_method, - StatusCallback done, CallOptions* call_opts = nullptr) { - ::grpc::ClientContext* context = new ::grpc::ClientContext; - // The initialization and recovery protocols rely on blocking - // until we get a response. - context->set_fail_fast(false); - if (call_opts) { - call_opts->SetCancelCallback([context]() { context->TryCancel(); }); - } - auto rpc = (stub_.get()->*async_method)(context, *request, cq_).release(); - GrpcClientCQTag* tag = - new GrpcClientCQTag(context, [rpc, done, call_opts](Status s) { - if (call_opts) { - call_opts->ClearCancelCallback(); - } - delete rpc; - done(s); - }); - rpc->Finish(response, tag->status(), tag); + const ::grpc::RpcMethod& method, StatusCallback done, + CallOptions* call_opts = nullptr) { + auto state = new RPCState<RequestMessage, ResponseMessage>( + channel_.get(), cq_, method, *request, std::move(done), call_opts); + state->StartRPC(response); } - std::unique_ptr<grpc::WorkerService::Stub> stub_; + // Helper function for initializing the RpcMethod objects below. + ::grpc::RpcMethod Method(GrpcWorkerMethod id) { + return ::grpc::RpcMethod(GrpcWorkerMethodName(id), + ::grpc::RpcMethod::NORMAL_RPC, channel_); + } + + SharedGrpcChannelPtr channel_; ::grpc::CompletionQueue* cq_; + const ::grpc::RpcMethod getstatus_; + const ::grpc::RpcMethod registergraph_; + const ::grpc::RpcMethod deregistergraph_; + const ::grpc::RpcMethod rungraph_; + const ::grpc::RpcMethod cleanupgraph_; + const ::grpc::RpcMethod cleanupall_; + const ::grpc::RpcMethod recvtensor_; + const ::grpc::RpcMethod logging_; + const ::grpc::RpcMethod tracing_; + // Support for logging. WorkerCacheLogger* logger_; bool retry_unavailable_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 485ed14c0f..0c0c80117a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -38,7 +38,6 @@ class GrpcWorkerCache : public WorkerCachePartial { while (completion_queue_.Next(&tag, &ok)) { GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag); callback_tag->OnCompleted(ok); - delete callback_tag; } }); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 9f7d009a36..9a87bbda1b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -85,7 +85,7 @@ class GrpcWorkerService : public AsyncServiceInterface { } // This macro creates a new request for the given RPC method name -// (e.g., `ENQUEUE_REQUEST(GetStatus);`), and enqueues it on +// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on // `this->cq_`. // // This macro is invoked one or more times for each RPC method to @@ -95,17 +95,17 @@ class GrpcWorkerService : public AsyncServiceInterface { // The implementation of the request handler for each RPC method // must ensure that it calls ENQUEUE_REQUEST() for that RPC method, // to keep accepting new requests. -#define ENQUEUE_REQUEST(method, supports_cancel) \ - do { \ - mutex_lock l(shutdown_mu_); \ - if (!is_shutdown_) { \ - Call<GrpcWorkerService, grpc::WorkerService::AsyncService, \ - method##Request, method##Response>:: \ - EnqueueRequest(&worker_service_, cq_, \ - &grpc::WorkerService::AsyncService::Request##method, \ - &GrpcWorkerService::method##Handler, \ - (supports_cancel)); \ - } \ +#define ENQUEUE_REQUEST(method, supports_cancel) \ + do { \ + mutex_lock l(shutdown_mu_); \ + if (!is_shutdown_) { \ + Call<GrpcWorkerService, grpc::WorkerService::AsyncService, \ + method##Request, method##Response>:: \ + EnqueueRequestForMethod( \ + &worker_service_, cq_, \ + static_cast<int>(GrpcWorkerMethod::k##method), \ + &GrpcWorkerService::method##Handler, (supports_cancel)); \ + } \ } while (0) // This method blocks forever handling requests from the completion queue. @@ -145,7 +145,6 @@ class GrpcWorkerService : public AsyncServiceInterface { static_cast<UntypedCall<GrpcWorkerService>::Tag*>(tag); if (callback_tag) { callback_tag->OnCompleted(this, ok); - delete callback_tag; } else { // NOTE(mrry): A null `callback_tag` indicates that this is // the shutdown alarm. @@ -267,9 +266,9 @@ class GrpcWorkerService : public AsyncServiceInterface { if (!is_shutdown_) { Call<GrpcWorkerService, grpc::WorkerService::AsyncService, RecvTensorRequest, ::grpc::ByteBuffer>:: - EnqueueRequest( + EnqueueRequestForMethod( &worker_service_, cq_, - &grpc::WorkerService::AsyncService::RequestRecvTensorRaw, + static_cast<int>(GrpcWorkerMethod::kRecvTensor), &GrpcWorkerService::RecvTensorHandlerRaw, true /* supports cancel*/); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 9f5e50f90e..a3ba22a95d 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -26,127 +26,36 @@ limitations under the License. namespace tensorflow { -namespace grpc { - -static const char* grpcWorkerService_method_names[] = { - "/tensorflow.WorkerService/GetStatus", - "/tensorflow.WorkerService/RegisterGraph", - "/tensorflow.WorkerService/DeregisterGraph", - "/tensorflow.WorkerService/RunGraph", - "/tensorflow.WorkerService/CleanupGraph", - "/tensorflow.WorkerService/CleanupAll", - "/tensorflow.WorkerService/RecvTensor", - "/tensorflow.WorkerService/Logging", - "/tensorflow.WorkerService/Tracing", -}; - -std::unique_ptr<WorkerService::Stub> WorkerService::NewStub( - const std::shared_ptr< ::grpc::ChannelInterface>& channel, - const ::grpc::StubOptions& options) { - std::unique_ptr<WorkerService::Stub> stub(new WorkerService::Stub(channel)); - return stub; -} - -WorkerService::Stub::Stub( - const std::shared_ptr< ::grpc::ChannelInterface>& channel) - : channel_(channel), - rpcmethod_GetStatus_(grpcWorkerService_method_names[0], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_RegisterGraph_(grpcWorkerService_method_names[1], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_DeregisterGraph_(grpcWorkerService_method_names[2], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_RunGraph_(grpcWorkerService_method_names[3], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_CleanupGraph_(grpcWorkerService_method_names[4], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_CleanupAll_(grpcWorkerService_method_names[5], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_RecvTensor_(grpcWorkerService_method_names[6], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_Logging_(grpcWorkerService_method_names[7], - ::grpc::RpcMethod::NORMAL_RPC, channel), - rpcmethod_Tracing_(grpcWorkerService_method_names[8], - ::grpc::RpcMethod::NORMAL_RPC, channel) {} - -::grpc::ClientAsyncResponseReader<GetStatusResponse>* -WorkerService::Stub::AsyncGetStatusRaw(::grpc::ClientContext* context, - const GetStatusRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<GetStatusResponse>( - channel_.get(), cq, rpcmethod_GetStatus_, context, request); -} - -::grpc::ClientAsyncResponseReader<RegisterGraphResponse>* -WorkerService::Stub::AsyncRegisterGraphRaw(::grpc::ClientContext* context, - const RegisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<RegisterGraphResponse>( - channel_.get(), cq, rpcmethod_RegisterGraph_, context, request); -} - -::grpc::ClientAsyncResponseReader<DeregisterGraphResponse>* -WorkerService::Stub::AsyncDeregisterGraphRaw( - ::grpc::ClientContext* context, const DeregisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<DeregisterGraphResponse>( - channel_.get(), cq, rpcmethod_DeregisterGraph_, context, request); -} - -::grpc::ClientAsyncResponseReader<RunGraphResponse>* -WorkerService::Stub::AsyncRunGraphRaw(::grpc::ClientContext* context, - const RunGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<RunGraphResponse>( - channel_.get(), cq, rpcmethod_RunGraph_, context, request); -} - -::grpc::ClientAsyncResponseReader<CleanupGraphResponse>* -WorkerService::Stub::AsyncCleanupGraphRaw(::grpc::ClientContext* context, - const CleanupGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<CleanupGraphResponse>( - channel_.get(), cq, rpcmethod_CleanupGraph_, context, request); -} - -::grpc::ClientAsyncResponseReader<CleanupAllResponse>* -WorkerService::Stub::AsyncCleanupAllRaw(::grpc::ClientContext* context, - const CleanupAllRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<CleanupAllResponse>( - channel_.get(), cq, rpcmethod_CleanupAll_, context, request); -} - -::grpc::ClientAsyncResponseReader<TensorResponse>* -WorkerService::Stub::AsyncRecvTensorRaw(::grpc::ClientContext* context, - const RecvTensorRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<TensorResponse>( - channel_.get(), cq, rpcmethod_RecvTensor_, context, request); -} - -::grpc::ClientAsyncResponseReader<LoggingResponse>* -WorkerService::Stub::AsyncLoggingRaw(::grpc::ClientContext* context, - const LoggingRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<LoggingResponse>( - channel_.get(), cq, rpcmethod_Logging_, context, request); +const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { + switch (id) { + case GrpcWorkerMethod::kGetStatus: + return "/tensorflow.WorkerService/GetStatus"; + case GrpcWorkerMethod::kRegisterGraph: + return "/tensorflow.WorkerService/RegisterGraph"; + case GrpcWorkerMethod::kDeregisterGraph: + return "/tensorflow.WorkerService/DeregisterGraph"; + case GrpcWorkerMethod::kRunGraph: + return "/tensorflow.WorkerService/RunGraph"; + case GrpcWorkerMethod::kCleanupGraph: + return "/tensorflow.WorkerService/CleanupGraph"; + case GrpcWorkerMethod::kCleanupAll: + return "/tensorflow.WorkerService/CleanupAll"; + case GrpcWorkerMethod::kRecvTensor: + return "/tensorflow.WorkerService/RecvTensor"; + case GrpcWorkerMethod::kLogging: + return "/tensorflow.WorkerService/Logging"; + case GrpcWorkerMethod::kTracing: + return "/tensorflow.WorkerService/Tracing"; + } } -::grpc::ClientAsyncResponseReader<TracingResponse>* -WorkerService::Stub::AsyncTracingRaw(::grpc::ClientContext* context, - const TracingRequest& request, - ::grpc::CompletionQueue* cq) { - return new ::grpc::ClientAsyncResponseReader<TracingResponse>( - channel_.get(), cq, rpcmethod_Tracing_, context, request); -} +namespace grpc { WorkerService::AsyncService::AsyncService() { - (void)grpcWorkerService_method_names; - for (int i = 0; i < TF_ARRAYSIZE(grpcWorkerService_method_names); ++i) { - AddMethod(new ::grpc::RpcServiceMethod(grpcWorkerService_method_names[i], - ::grpc::RpcMethod::NORMAL_RPC, - nullptr)); + for (int i = 0; i < kGrpcNumWorkerMethods; ++i) { + AddMethod(new ::grpc::RpcServiceMethod( + GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)), + ::grpc::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index 0513bc6894..f3aac795ca 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -106,6 +106,23 @@ class SerializationTraits<tensorflow::TensorResponse> namespace tensorflow { +// Names of worker methods. +enum class GrpcWorkerMethod { + kGetStatus, + kRegisterGraph, + kDeregisterGraph, + kRunGraph, + kCleanupGraph, + kCleanupAll, + kRecvTensor, + kLogging, + kTracing, +}; +static const int kGrpcNumWorkerMethods = + static_cast<int>(GrpcWorkerMethod::kTracing) + 1; + +const char* GrpcWorkerMethodName(GrpcWorkerMethod id); + namespace grpc { // Implementation of `tensorflow.WorkerService`, based on the @@ -114,364 +131,13 @@ namespace grpc { // See the proto file for the definition of methods and messages. class WorkerService GRPC_FINAL { public: - class StubInterface { - public: - virtual ~StubInterface() {} - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::GetStatusResponse>> - AsyncGetStatus(::grpc::ClientContext* context, - const ::tensorflow::GetStatusRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::GetStatusResponse>>( - AsyncGetStatusRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RegisterGraphResponse>> - AsyncRegisterGraph(::grpc::ClientContext* context, - const ::tensorflow::RegisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RegisterGraphResponse>>( - AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::DeregisterGraphResponse>> - AsyncDeregisterGraph(::grpc::ClientContext* context, - const ::tensorflow::DeregisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::DeregisterGraphResponse>>( - AsyncDeregisterGraphRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RunGraphResponse>> - AsyncRunGraph(::grpc::ClientContext* context, - const ::tensorflow::RunGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RunGraphResponse>>( - AsyncRunGraphRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupGraphResponse>> - AsyncCleanupGraph(::grpc::ClientContext* context, - const ::tensorflow::CleanupGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupGraphResponse>>( - AsyncCleanupGraphRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupAllResponse>> - AsyncCleanupAll(::grpc::ClientContext* context, - const ::tensorflow::CleanupAllRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupAllResponse>>( - AsyncCleanupAllRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TensorResponse>> - AsyncRecvTensor(::grpc::ClientContext* context, - const ::tensorflow::RecvTensorRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TensorResponse>>( - AsyncRecvTensorRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::LoggingResponse>> - AsyncLogging(::grpc::ClientContext* context, - const ::tensorflow::LoggingRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::LoggingResponse>>( - AsyncLoggingRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TracingResponse>> - AsyncTracing(::grpc::ClientContext* context, - const ::tensorflow::TracingRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TracingResponse>>( - AsyncTracingRaw(context, request, cq)); - } - - private: - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::GetStatusResponse>* - AsyncGetStatusRaw(::grpc::ClientContext* context, - const ::tensorflow::GetStatusRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RegisterGraphResponse>* - AsyncRegisterGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::RegisterGraphRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::DeregisterGraphResponse>* - AsyncDeregisterGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::DeregisterGraphRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::RunGraphResponse>* - AsyncRunGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::RunGraphRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupGraphResponse>* - AsyncCleanupGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::CleanupGraphRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::CleanupAllResponse>* - AsyncCleanupAllRaw(::grpc::ClientContext* context, - const ::tensorflow::CleanupAllRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TensorResponse>* - AsyncRecvTensorRaw(::grpc::ClientContext* context, - const ::tensorflow::RecvTensorRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::LoggingResponse>* - AsyncLoggingRaw(::grpc::ClientContext* context, - const ::tensorflow::LoggingRequest& request, - ::grpc::CompletionQueue* cq) = 0; - virtual ::grpc::ClientAsyncResponseReaderInterface< - ::tensorflow::TracingResponse>* - AsyncTracingRaw(::grpc::ClientContext* context, - const ::tensorflow::TracingRequest& request, - ::grpc::CompletionQueue* cq) = 0; - }; - class Stub GRPC_FINAL : public StubInterface { - public: - Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::GetStatusResponse>> - AsyncGetStatus(::grpc::ClientContext* context, - const ::tensorflow::GetStatusRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::GetStatusResponse>>( - AsyncGetStatusRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::RegisterGraphResponse>> - AsyncRegisterGraph(::grpc::ClientContext* context, - const ::tensorflow::RegisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReader< - ::tensorflow::RegisterGraphResponse>>( - AsyncRegisterGraphRaw(context, request, cq)); - } - std::unique_ptr<::grpc::ClientAsyncResponseReader< - ::tensorflow::DeregisterGraphResponse>> - AsyncDeregisterGraph(::grpc::ClientContext* context, - const ::tensorflow::DeregisterGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReader< - ::tensorflow::DeregisterGraphResponse>>( - AsyncDeregisterGraphRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::RunGraphResponse>> - AsyncRunGraph(::grpc::ClientContext* context, - const ::tensorflow::RunGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::RunGraphResponse>>( - AsyncRunGraphRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::CleanupGraphResponse>> - AsyncCleanupGraph(::grpc::ClientContext* context, - const ::tensorflow::CleanupGraphRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr<::grpc::ClientAsyncResponseReader< - ::tensorflow::CleanupGraphResponse>>( - AsyncCleanupGraphRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::CleanupAllResponse>> - AsyncCleanupAll(::grpc::ClientContext* context, - const ::tensorflow::CleanupAllRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::CleanupAllResponse>>( - AsyncCleanupAllRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::TensorResponse>> - AsyncRecvTensor(::grpc::ClientContext* context, - const ::tensorflow::RecvTensorRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::TensorResponse>>( - AsyncRecvTensorRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::LoggingResponse>> - AsyncLogging(::grpc::ClientContext* context, - const ::tensorflow::LoggingRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::LoggingResponse>>( - AsyncLoggingRaw(context, request, cq)); - } - std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::TracingResponse>> - AsyncTracing(::grpc::ClientContext* context, - const ::tensorflow::TracingRequest& request, - ::grpc::CompletionQueue* cq) { - return std::unique_ptr< - ::grpc::ClientAsyncResponseReader<::tensorflow::TracingResponse>>( - AsyncTracingRaw(context, request, cq)); - } - - private: - std::shared_ptr<::grpc::ChannelInterface> channel_; - ::grpc::ClientAsyncResponseReader<::tensorflow::GetStatusResponse>* - AsyncGetStatusRaw(::grpc::ClientContext* context, - const ::tensorflow::GetStatusRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::RegisterGraphResponse>* - AsyncRegisterGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::RegisterGraphRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::DeregisterGraphResponse>* - AsyncDeregisterGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::DeregisterGraphRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::RunGraphResponse>* - AsyncRunGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::RunGraphRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::CleanupGraphResponse>* - AsyncCleanupGraphRaw(::grpc::ClientContext* context, - const ::tensorflow::CleanupGraphRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::CleanupAllResponse>* - AsyncCleanupAllRaw(::grpc::ClientContext* context, - const ::tensorflow::CleanupAllRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::TensorResponse>* - AsyncRecvTensorRaw(::grpc::ClientContext* context, - const ::tensorflow::RecvTensorRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::LoggingResponse>* - AsyncLoggingRaw(::grpc::ClientContext* context, - const ::tensorflow::LoggingRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - ::grpc::ClientAsyncResponseReader<::tensorflow::TracingResponse>* - AsyncTracingRaw(::grpc::ClientContext* context, - const ::tensorflow::TracingRequest& request, - ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; - const ::grpc::RpcMethod rpcmethod_GetStatus_; - const ::grpc::RpcMethod rpcmethod_RegisterGraph_; - const ::grpc::RpcMethod rpcmethod_DeregisterGraph_; - const ::grpc::RpcMethod rpcmethod_RunGraph_; - const ::grpc::RpcMethod rpcmethod_CleanupGraph_; - const ::grpc::RpcMethod rpcmethod_CleanupAll_; - const ::grpc::RpcMethod rpcmethod_RecvTensor_; - const ::grpc::RpcMethod rpcmethod_Logging_; - const ::grpc::RpcMethod rpcmethod_Tracing_; - }; - static std::unique_ptr<Stub> NewStub( - const std::shared_ptr<::grpc::ChannelInterface>& channel, - const ::grpc::StubOptions& options = ::grpc::StubOptions()); - class AsyncService : public ::grpc::Service { public: AsyncService(); virtual ~AsyncService(); - void RequestGetStatus( - ::grpc::ServerContext* context, ::tensorflow::GetStatusRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::GetStatusResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(0, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestRegisterGraph( - ::grpc::ServerContext* context, - ::tensorflow::RegisterGraphRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::RegisterGraphResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(1, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestDeregisterGraph( - ::grpc::ServerContext* context, - ::tensorflow::DeregisterGraphRequest* request, - ::grpc::ServerAsyncResponseWriter< - ::tensorflow::DeregisterGraphResponse>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(2, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestRunGraph( - ::grpc::ServerContext* context, ::tensorflow::RunGraphRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::RunGraphResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(3, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestCleanupGraph( - ::grpc::ServerContext* context, - ::tensorflow::CleanupGraphRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::CleanupGraphResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(4, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestCleanupAll( - ::grpc::ServerContext* context, - ::tensorflow::CleanupAllRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::CleanupAllResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(5, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestRecvTensorRaw( - ::grpc::ServerContext* context, - ::tensorflow::RecvTensorRequest* request, - ::grpc::ServerAsyncResponseWriter<::grpc::ByteBuffer>* response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(6, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestLogging( - ::grpc::ServerContext* context, ::tensorflow::LoggingRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::LoggingResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(7, context, request, response, - new_call_cq, notification_cq, tag); - } - void RequestTracing( - ::grpc::ServerContext* context, ::tensorflow::TracingRequest* request, - ::grpc::ServerAsyncResponseWriter<::tensorflow::TracingResponse>* - response, - ::grpc::CompletionQueue* new_call_cq, - ::grpc::ServerCompletionQueue* notification_cq, void* tag) { - ::grpc::Service::RequestAsyncUnary(8, context, request, response, - new_call_cq, notification_cq, tag); - } + + // Make RequestAsyncUnary public for grpc_call.h + using ::grpc::Service::RequestAsyncUnary; }; }; |