diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/eager')
5 files changed, 121 insertions, 17 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 9ba8c8d80c..707f3234b9 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -39,6 +39,7 @@ class EagerClient { CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); CLIENT_METHOD(RegisterFunction); + CLIENT_METHOD(SendTensor); #undef CLIENT_METHOD }; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 466e779fab..916c8720f0 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -81,10 +81,11 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, CreateContextResponse* response) { - //make sure env_ , env_->rendezvous_mgr available + // make sure env_ , env_->rendezvous_mgr available if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { - return tensorflow::errors::Internal("invalid eager env_ or env_->rendezvous_mgr."); - } + return tensorflow::errors::Internal( + "invalid eager env_ or env_->rendezvous_mgr."); + } std::vector<tensorflow::Device*> devices; TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( @@ -266,6 +267,35 @@ Status EagerServiceImpl::RegisterFunction( return context->Context()->AddFunctionDef(request->function_def()); } +Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, + SendTensorResponse* response) { + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); + core::ScopedUnref context_unref(context); + + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; + for (const auto& tensor_proto : request->tensors()) { + Tensor tensor; + if (!tensor.FromProto(tensor_proto)) { + return errors::InvalidArgument("Unable to parse tensor proto"); + } + + TensorHandle* tensor_handle = + new TensorHandle(tensor, nullptr, nullptr, nullptr); + + TensorHandle* copied_handle = nullptr; + TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(), + request->device_name().c_str(), + &copied_handle)); + tensors.push_back(copied_handle); + tensor_handle->Unref(); + } + + context->AddOperationOutputs(tensors, request->op_id()); + + return Status::OK(); +} + tensorflow::Status EagerServiceImpl::GetServerContext( uint64 context_id, ServerContext** server_context) { mutex_lock l(contexts_mu_); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index b0e4aa84b9..718b4e2457 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -62,6 +62,9 @@ class EagerServiceImpl { Status RegisterFunction(const RegisterFunctionRequest* request, RegisterFunctionResponse* response); + Status SendTensor(const SendTensorRequest* request, + SendTensorResponse* response); + protected: // This is the server-side execution context. All state regarding execution of // a client's ops is held in this server-side context (all generated tensors, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index b98386ba86..d1f2a6da8f 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -84,7 +84,7 @@ class EagerServiceImplTest : public ::testing::Test { std::unique_ptr<DeviceMgr> device_mgr_; }; -void SetTensorProto(AttrValue* val) { +void SetTensorProto(TensorProto* tensor_proto) { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; TF_Tensor* t = TF_AllocateTensor( @@ -92,7 +92,7 @@ void SetTensorProto(AttrValue* val) { memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); tensorflow::Tensor tensor; TF_ASSERT_OK(tensorflow::TF_TensorToTensor(t, &tensor)); - tensor.AsProtoTensorContent(val->mutable_tensor()); + tensor.AsProtoTensorContent(tensor_proto); TF_DeleteTensor(t); } @@ -175,7 +175,7 @@ TEST_F(EagerServiceImplTest, BasicTest) { val.set_type(tensorflow::DataType::DT_FLOAT); const_attrs.insert({"dtype", val}); val.Clear(); - SetTensorProto(&val); + SetTensorProto(val.mutable_tensor()); const_attrs.insert({"value", val}); AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, @@ -260,7 +260,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { const_attrs.insert({"dtype", val}); val.Clear(); - SetTensorProto(&val); + SetTensorProto(val.mutable_tensor()); const_attrs.insert({"value", val}); AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, @@ -294,6 +294,77 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { &close_context_response)); } +// Test creates a context and attempts to send a tensor (using the RPC), and +// then use the tensor. +TEST_F(EagerServiceImplTest, SendTensorTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); + CreateContextResponse response; + + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + uint64 context_id = response.context_id(); + + SendTensorRequest send_tensor_request; + send_tensor_request.set_context_id(context_id); + send_tensor_request.set_op_id(1); + SetTensorProto(send_tensor_request.add_tensors()); + SendTensorResponse send_tensor_response; + + TF_ASSERT_OK(eager_service_impl.SendTensor(&send_tensor_request, + &send_tensor_response)); + + EnqueueRequest remote_enqueue_request; + remote_enqueue_request.set_context_id(context_id); + EnqueueResponse remote_enqueue_response; + + std::unordered_map<string, AttrValue> attrs; + AttrValue val; + val.Clear(); + val.set_type(tensorflow::DataType::DT_FLOAT); + attrs.insert({"T", val}); + val.Clear(); + val.set_b(false); + attrs.insert({"transpose_a", val}); + attrs.insert({"transpose_b", val}); + + AddOperationToEnqueueRequest(2, "MatMul", {{1, 0}, {1, 0}}, attrs, + "/job:localhost/replica:0/task:0/device:CPU:0", + &remote_enqueue_request); + + TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, + &remote_enqueue_response)); + + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + + Device* device = nullptr; + TF_ASSERT_OK(tensor_handle->Device(&device)); + EXPECT_NE(device, nullptr); + EXPECT_EQ(device->name(), "/job:localhost/replica:0/task:0/device:CPU:0"); + + auto actual = t->flat<float>(); + EXPECT_EQ(4, actual.size()); + + EXPECT_EQ(7, actual(0)); + EXPECT_EQ(10, actual(1)); + EXPECT_EQ(15, actual(2)); + EXPECT_EQ(22, actual(3)); + + CloseContextRequest close_context_request; + close_context_request.set_context_id(context_id); + CloseContextResponse close_context_response; + TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, + &close_context_response)); +} + } // namespace } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 28b68c3b88..0e3a68c4d8 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -29,8 +29,8 @@ namespace eager { class RemoteExecuteNode : public tensorflow::EagerNode { public: RemoteExecuteNode( - tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client, + tensorflow::uint64 id, std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client, const gtl::InlinedVector<TensorHandle*, 4>& inputs, std::function<void(const Status& status, const EnqueueResponse& response)> done_callback) @@ -45,8 +45,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } RemoteExecuteNode(tensorflow::uint64 id, - const tensorflow::eager::EnqueueRequest& request, - tensorflow::eager::EagerClient* eager_client) + std::unique_ptr<EnqueueRequest> request, + EagerClient* eager_client) : tensorflow::EagerNode(id), request_(std::move(request)), eager_client_(eager_client) {} @@ -58,10 +58,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } tensorflow::Status Run() override { - tensorflow::eager::EnqueueResponse response; - tensorflow::Status status; + EnqueueResponse response; + Status status; Notification n; - eager_client_->EnqueueAsync(&request_, &response, + eager_client_->EnqueueAsync(request_.get(), &response, [&n, &status](const tensorflow::Status& s) { status.Update(s); n.Notify(); @@ -76,9 +76,8 @@ class RemoteExecuteNode : public tensorflow::EagerNode { } private: - EnqueueRequest request_; - tensorflow::eager::EagerClient* - eager_client_; // Not owned, and must outlive the RemoteExecuteNode. + std::unique_ptr<EnqueueRequest> request_; + EagerClient* eager_client_; // Not owned, and must outlive this node. // This is required to ensure that the tensor handles stay alive across the // execution. |