diff options
author | 2018-07-24 14:10:01 -0700 | |
---|---|---|
committer | 2018-07-24 14:14:04 -0700 | |
commit | e9398c43cf470a7388df7d20baf6dd10a3b42edb (patch) | |
tree | b1f901eaa30d36533f5056de3e6ccdc6805fb66e /tensorflow/core/distributed_runtime | |
parent | 76e8f7b7fdf89b131e0406022129d5dde6b89e40 (diff) |
Push tensors from client to workers.
At times, a server cannot open a reverse connection to the client. Such a
reverse connection is required when using the _Send/_Recv ops, because the
client needs the server to pull tensors from it. Instead, this change adds
a way to push tensors directly from the client to the server.
Currently, pushing tensors always happens in synchronous mode.
PiperOrigin-RevId: 205888825
Diffstat (limited to 'tensorflow/core/distributed_runtime')
10 files changed, 151 insertions, 19 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 9ba8c8d80c..707f3234b9 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -39,6 +39,7 @@ class EagerClient { CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); CLIENT_METHOD(RegisterFunction); + CLIENT_METHOD(SendTensor); #undef CLIENT_METHOD }; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 466e779fab..916c8720f0 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -81,10 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, CreateContextResponse* response) { - //make sure env_ , env_->rendezvous_mgr available + // make sure env_ , env_->rendezvous_mgr available if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { - return tensorflow::errors::Internal("invalid eager env_ or env_->rendezvous_mgr."); - } + return tensorflow::errors::Internal( + "invalid eager env_ or env_->rendezvous_mgr."); + } std::vector<tensorflow::Device*> devices; TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( @@ -266,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction( return context->Context()->AddFunctionDef(request->function_def()); } +Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, + SendTensorResponse* response) { + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); + core::ScopedUnref context_unref(context); + + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; + for (const auto& tensor_proto : request->tensors()) { + Tensor tensor; + if 
(!tensor.FromProto(tensor_proto)) { + return errors::InvalidArgument("Unable to parse tensor proto"); + } + + TensorHandle* tensor_handle = + new TensorHandle(tensor, nullptr, nullptr, nullptr); + + TensorHandle* copied_handle = nullptr; + TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(), + request->device_name().c_str(), + &copied_handle)); + tensors.push_back(copied_handle); + tensor_handle->Unref(); + } + + context->AddOperationOutputs(tensors, request->op_id()); + + return Status::OK(); +} + tensorflow::Status EagerServiceImpl::GetServerContext( uint64 context_id, ServerContext** server_context) { mutex_lock l(contexts_mu_); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index b0e4aa84b9..718b4e2457 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -62,6 +62,9 @@ class EagerServiceImpl { Status RegisterFunction(const RegisterFunctionRequest* request, RegisterFunctionResponse* response); + Status SendTensor(const SendTensorRequest* request, + SendTensorResponse* response); + protected: // This is the server-side execution context. 
All state regarding execution of // a client's ops is held in this server-side context (all generated tensors, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index b98386ba86..d1f2a6da8f 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -84,7 +84,7 @@ class EagerServiceImplTest : public ::testing::Test { std::unique_ptr<DeviceMgr> device_mgr_; }; -void SetTensorProto(AttrValue* val) { +void SetTensorProto(TensorProto* tensor_proto) { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; TF_Tensor* t = TF_AllocateTensor( @@ -92,7 +92,7 @@ void SetTensorProto(AttrValue* val) { memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); tensorflow::Tensor tensor; TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor)); - tensor.AsProtoTensorContent(val->mutable_tensor()); + tensor.AsProtoTensorContent(tensor_proto); TF_DeleteTensor(t); } @@ -175,7 +175,7 @@ TEST_F(EagerServiceImplTest, BasicTest) { val.set_type(tensorflow::DataType::DT_FLOAT); const_attrs.insert({"dtype", val}); val.Clear(); - SetTensorProto(&val); + SetTensorProto(val.mutable_tensor()); const_attrs.insert({"value", val}); AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, @@ -260,7 +260,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { const_attrs.insert({"dtype", val}); val.Clear(); - SetTensorProto(&val); + SetTensorProto(val.mutable_tensor()); const_attrs.insert({"value", val}); AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, @@ -294,6 +294,77 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { &close_context_response)); } +// Test creates a context and attempts to send a tensor (using the RPC), and +// then use the tensor. 
+TEST_F(EagerServiceImplTest, SendTensorTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); + CreateContextResponse response; + + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + uint64 context_id = response.context_id(); + + SendTensorRequest send_tensor_request; + send_tensor_request.set_context_id(context_id); + send_tensor_request.set_op_id(1); + SetTensorProto(send_tensor_request.add_tensors()); + SendTensorResponse send_tensor_response; + + TF_ASSERT_OK(eager_service_impl.SendTensor(&send_tensor_request, + &send_tensor_response)); + + EnqueueRequest remote_enqueue_request; + remote_enqueue_request.set_context_id(context_id); + EnqueueResponse remote_enqueue_response; + + std::unordered_map<string, AttrValue> attrs; + AttrValue val; + val.Clear(); + val.set_type(tensorflow::DataType::DT_FLOAT); + attrs.insert({"T", val}); + val.Clear(); + val.set_b(false); + attrs.insert({"transpose_a", val}); + attrs.insert({"transpose_b", val}); + + AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs, + "/job:localhost/replica:0/task:0/device:CPU:0", + &remote_enqueue_request); + + TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, + &remote_enqueue_response)); + + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + Device* device = nullptr; + TF_ASSERT_OK(tensor_handle->Device(&device)); + EXPECT_NE(device, nullptr); + EXPECT_EQ(device->name(), "/job:localhost/replica:0/task:0/device:CPU:0"); + + auto actual = t->flat<float>(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(7, actual(0)); + EXPECT_EQ(10, actual(1)); + 
EXPECT_EQ(15, actual(2)); + EXPECT_EQ(22, actual(3)); + + CloseContextRequest close_context_request; + close_context_request.set_context_id(context_id); + CloseContextResponse close_context_response; + TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, + &close_context_response)); +} + } // namespace } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 28b68c3b88..0e3a68c4d8 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -29,8 +29,8 @@ namespace eager { class RemoteExecuteNode : public tensorflow::EagerNode { public: RemoteExecuteNode( - tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client, + tensorflow::uint64 id, std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client, const gtl::InlinedVector<TensorHandle*, 4>& inputs, std::function<void(const Status& status, const EnqueueResponse& response)> done_callback) @@ -45,8 +45,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } RemoteExecuteNode(tensorflow::uint64 id, - const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client) + std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client) : tensorflow::EagerNode(id), request_(std::move(request)), eager_client_(eager_client) {} @@ -58,10 +58,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } tensorflow::Status Run() override { - tensorflow::eager::EnqueueResponse response; - tensorflow::Status status; + EnqueueResponse response; + Status status; Notification n; - eager_client_->EnqueueAsync(&request_, &response, + eager_client_->EnqueueAsync(request_.get(), &response, [&n, &status](const tensorflow::Status& s) { status.Update(s); n.Notify(); @@ -76,9 +76,8 @@ 
class RemoteExecuteNode : public tensorflow::EagerNode { } private: - EnqueueRequest request_; - tensorflow::eager::EagerClient* - eager_client_; // Not owned, and must outlive the RemoteExecuteNode. + std::unique_ptr<EnqueueRequest> request_; + EagerClient* eager_client_; // Not owned, and must outlive this node. // This is required to ensure that the tensor handles stay alive across the // execution. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index b23466037f..181422118c 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -49,6 +49,7 @@ class GrpcEagerClient : public EagerClient { CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); CLIENT_METHOD(RegisterFunction); + CLIENT_METHOD(SendTensor); #undef CLIENT_METHOD diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc index 39ab6856c5..ab3aa3fd1d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc @@ -36,6 +36,7 @@ static const char* grpcEagerService_method_names[] = { "/tensorflow.eager.EagerService/KeepAlive", "/tensorflow.eager.EagerService/CloseContext", "/tensorflow.eager.EagerService/RegisterFunction", + "/tensorflow.eager.EagerService/SendTensor", }; std::unique_ptr<EagerService::Stub> EagerService::NewStub( @@ -62,7 +63,9 @@ EagerService::Stub::Stub( ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_RegisterFunction_(grpcEagerService_method_names[5], ::grpc::internal::RpcMethod::NORMAL_RPC, - channel) {} + channel), + rpcmethod_SendTensor_(grpcEagerService_method_names[6], + ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status EagerService::Stub::CreateContext( 
::grpc::ClientContext* context, const CreateContextRequest& request, @@ -106,8 +109,15 @@ EagerService::Stub::Stub( channel_.get(), rpcmethod_RegisterFunction_, context, request, response); } +::grpc::Status EagerService::Stub::SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) { + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_SendTensor_, context, request, response); +} + EagerService::AsyncService::AsyncService() { - for (int i = 0; i < 6; ++i) { + for (int i = 0; i < 7; ++i) { AddMethod(new ::grpc::internal::RpcServiceMethod( grpcEagerService_method_names[i], ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h index 66458186ad..521e0ac4fa 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h @@ -69,6 +69,9 @@ class EagerService final { virtual ::grpc::Status RegisterFunction( ::grpc::ClientContext* context, const RegisterFunctionRequest& request, RegisterFunctionResponse* response) = 0; + virtual ::grpc::Status SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) = 0; }; class Stub final : public StubInterface { public: @@ -91,6 +94,9 @@ class EagerService final { ::grpc::Status RegisterFunction( ::grpc::ClientContext* context, const RegisterFunctionRequest& request, RegisterFunctionResponse* response) override; + ::grpc::Status SendTensor(::grpc::ClientContext* context, + const SendTensorRequest& request, + SendTensorResponse* response) override; private: std::shared_ptr< ::grpc::ChannelInterface> channel_; @@ -100,6 +106,7 @@ class EagerService final { const ::grpc::internal::RpcMethod rpcmethod_KeepAlive_; const ::grpc::internal::RpcMethod rpcmethod_CloseContext_; const 
::grpc::internal::RpcMethod rpcmethod_RegisterFunction_; + const ::grpc::internal::RpcMethod rpcmethod_SendTensor_; }; static std::unique_ptr<Stub> NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, @@ -157,6 +164,14 @@ class EagerService final { ::grpc::Service::RequestAsyncUnary(5, context, request, response, new_call_cq, notification_cq, tag); } + void RequestSendTensor( + ::grpc::ServerContext* context, SendTensorRequest* request, + ::grpc::ServerAsyncResponseWriter<SendTensorResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(6, context, request, response, + new_call_cq, notification_cq, tag); + } }; }; diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index 44e880de04..f511674e1f 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -48,6 +48,7 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(KeepAlive); ENQUEUE_REQUEST(CloseContext); ENQUEUE_REQUEST(RegisterFunction); + ENQUEUE_REQUEST(SendTensor); #undef ENQUEUE_REQUEST void* tag; // Matches the operation started against this cq_. diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index 502f3ef529..537e9043bd 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -62,6 +62,7 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface { HANDLER(KeepAlive); HANDLER(CloseContext); HANDLER(RegisterFunction); + HANDLER(SendTensor); #undef HANDLER const WorkerEnv* const env_; // Not owned. |