diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-06-26 15:36:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 15:43:10 -0700 |
commit | 79c828ea6ddbcfccd43a2be176fc1dcad4daf34e (patch) | |
tree | add42271d2d23096b03af3bf8fb89384cffc9a54 /tensorflow/core/distributed_runtime | |
parent | ec34de06981eed74c2c2a47c8a6372735e9d3622 (diff) |
Support shapes for remote eager tensor handles.
Since we respond with the shape, all RPCs will happen synchronously (note
that we may still hide the Python overhead, since the op is still scheduled for
execution via the eager executor).
PiperOrigin-RevId: 202207324
Diffstat (limited to 'tensorflow/core/distributed_runtime')
5 files changed, 61 insertions, 3 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 710bd8d021..22d0902af2 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2fa234c810..5a26d5bf48 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -128,8 +128,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, return Status::OK(); } +Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { + const tensorflow::Tensor* t = nullptr; + + // TODO(nareshmodi): This call makes async calls sync calls. Fix this. 
+ TF_RETURN_IF_ERROR(handle->Tensor(&t)); + + t->shape().AsProto(proto); + + return Status::OK(); +} + Status EagerServiceImpl::ExecuteOp(const Operation& operation, - ServerContext* server_context) { + ServerContext* server_context, + QueueResponse* queue_response) { std::unique_ptr<tensorflow::EagerOperation> op; const char* name = operation.name().c_str(); // Shorthand const tensorflow::AttrTypeMap* types; @@ -172,6 +184,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, server_context->AddOperationOutputs(retvals, operation.id()); + for (auto* handle : retvals) { + TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape())); + } + return Status::OK(); } @@ -182,8 +198,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request, core::ScopedUnref context_unref(context); for (const auto& item : request->queue()) { + auto* queue_response = response->add_queue_response(); if (item.has_operation()) { - TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context)); + TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response)); } else { TF_RETURN_IF_ERROR(context->DeleteTensorHandle( RemoteTensorHandleInternal(item.handle_to_decref()))); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index ebd5269a57..b0e4aa84b9 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -135,7 +135,8 @@ class EagerServiceImpl { tensorflow::Status GetServerContext(uint64, ServerContext**); private: - Status ExecuteOp(const Operation& operation, ServerContext* server_context); + Status ExecuteOp(const Operation& operation, ServerContext* server_context, + QueueResponse* queue_response); const WorkerEnv* const env_; // Not owned. 
mutex contexts_mu_; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 91b58698a4..b98386ba86 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -198,6 +198,11 @@ TEST_F(EagerServiceImplTest, BasicTest) { TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, &remote_enqueue_response)); + auto& matmul_result_shape = + remote_enqueue_response.queue_response(1).shape(0); + EXPECT_EQ(matmul_result_shape.dim(0).size(), 2); + EXPECT_EQ(matmul_result_shape.dim(1).size(), 2); + tensorflow::TensorHandle* tensor_handle; TF_ASSERT_OK(eager_service_impl.GetTensorHandle( response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index c4bd67aaed..28b68c3b88 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ #include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/protobuf/eager_service.pb.h" @@ -27,6 +28,22 @@ namespace eager { // via RPC in a remote EagerService. 
class RemoteExecuteNode : public tensorflow::EagerNode { public: + RemoteExecuteNode( + tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, + tensorflow::eager::EagerClient* eager_client, + const gtl::InlinedVector<TensorHandle*, 4>& inputs, + std::function<void(const Status& status, const EnqueueResponse& response)> + done_callback) + : tensorflow::EagerNode(id), + request_(std::move(request)), + eager_client_(eager_client), + inputs_(inputs), + done_callback_(std::move(done_callback)) { + for (auto* handle : inputs_) { + handle->Ref(); + } + } + RemoteExecuteNode(tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, tensorflow::eager::EagerClient* eager_client) @@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode { request_(std::move(request)), eager_client_(eager_client) {} + ~RemoteExecuteNode() { + for (auto* handle : inputs_) { + handle->Unref(); + } + } + tensorflow::Status Run() override { tensorflow::eager::EnqueueResponse response; tensorflow::Status status; @@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode { }); n.WaitForNotification(); + if (done_callback_) { + done_callback_(status, response); + } + return status; } @@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode { EnqueueRequest request_; tensorflow::eager::EagerClient* eager_client_; // Not owned, and must outlive the RemoteExecuteNode. + + // This is required to ensure that the tensor handles stay alive across the + // execution. + gtl::InlinedVector<TensorHandle*, 4> inputs_; + + std::function<void(const Status& status, const EnqueueResponse& response)> + done_callback_; }; } // namespace eager |