diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc')
5 files changed, 30 insertions, 2 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index b23466037f..181422118c 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -49,6 +49,7 @@ class GrpcEagerClient : public EagerClient { CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); CLIENT_METHOD(RegisterFunction); + CLIENT_METHOD(SendTensor); #undef CLIENT_METHOD diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc index 39ab6856c5..ab3aa3fd1d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc @@ -36,6 +36,7 @@ static const char* grpcEagerService_method_names[] = { "/tensorflow.eager.EagerService/KeepAlive", "/tensorflow.eager.EagerService/CloseContext", "/tensorflow.eager.EagerService/RegisterFunction", + "/tensorflow.eager.EagerService/SendTensor", }; std::unique_ptr<EagerService::Stub> EagerService::NewStub( @@ -62,7 +63,9 @@ EagerService::Stub::Stub( ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_RegisterFunction_(grpcEagerService_method_names[5], ::grpc::internal::RpcMethod::NORMAL_RPC, - channel) {} + channel), + rpcmethod_SendTensor_(grpcEagerService_method_names[6], + ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status EagerService::Stub::CreateContext( ::grpc::ClientContext* context, const CreateContextRequest& request, @@ -106,8 +109,15 @@ EagerService::Stub::Stub( channel_.get(), rpcmethod_RegisterFunction_, context, request, response); } +::grpc::Status EagerService::Stub::SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) { + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_SendTensor_, context, request, response); +} + EagerService::AsyncService::AsyncService() { - for (int i = 0; i < 6; ++i) { + for (int i = 0; i < 7; ++i) { AddMethod(new ::grpc::internal::RpcServiceMethod( grpcEagerService_method_names[i], ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h index 66458186ad..521e0ac4fa 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h @@ -69,6 +69,9 @@ class EagerService final { virtual ::grpc::Status RegisterFunction( ::grpc::ClientContext* context, const RegisterFunctionRequest& request, RegisterFunctionResponse* response) = 0; + virtual ::grpc::Status SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) = 0; }; class Stub final : public StubInterface { public: @@ -91,6 +94,9 @@ class EagerService final { ::grpc::Status RegisterFunction( ::grpc::ClientContext* context, const RegisterFunctionRequest& request, RegisterFunctionResponse* response) override; + ::grpc::Status SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) override; private: std::shared_ptr< ::grpc::ChannelInterface> channel_; @@ -100,6 +106,7 @@ class EagerService final { const ::grpc::internal::RpcMethod rpcmethod_KeepAlive_; const ::grpc::internal::RpcMethod rpcmethod_CloseContext_; const ::grpc::internal::RpcMethod rpcmethod_RegisterFunction_; + const ::grpc::internal::RpcMethod rpcmethod_SendTensor_; }; static std::unique_ptr<Stub> NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, @@ -157,6 +164,14 @@ class EagerService final { ::grpc::Service::RequestAsyncUnary(5, context, request, response, new_call_cq, notification_cq, tag); } + void RequestSendTensor( + ::grpc::ServerContext* context, SendTensorRequest* request, + ::grpc::ServerAsyncResponseWriter<SendTensorResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(6, context, request, response, + new_call_cq, notification_cq, tag); + } }; }; diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index 44e880de04..f511674e1f 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -48,6 +48,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(KeepAlive); ENQUEUE_REQUEST(CloseContext); ENQUEUE_REQUEST(RegisterFunction); + ENQUEUE_REQUEST(SendTensor); #undef ENQUEUE_REQUEST void* tag; // Matches the operation started against this cq_. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index 502f3ef529..537e9043bd 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -62,6 +62,7 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface { HANDLER(KeepAlive); HANDLER(CloseContext); HANDLER(RegisterFunction); + HANDLER(SendTensor); #undef HANDLER const WorkerEnv* const env_; // Not owned. |