diff options
10 files changed, 193 insertions, 45 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 00b474fe86..82ca2be2cf 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -156,12 +156,14 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const ::tensorflow::Status _status = (__VA_ARGS__); \ - LOG(ERROR) << _status.error_message(); \ - if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ - } while (0) +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), @@ -346,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dims(); + int result; + status->status = h->handle->NumDims(&result); + return result; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dim_size(dim_index); + tensorflow::int64 result; + status->status = h->handle->Dim(dim_index, &result); + return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index bf60d05e96..60dd848b20 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -623,22 +624,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, request.set_context_id(context_id); - if (op->EagerContext()->Async()) { - tensorflow::uint64 id = op->EagerContext()->NextId(); - auto* node = new eager::RemoteExecuteNode(id, request, eager_client); - op->EagerContext()->ExecutorAdd(node); - } else { - Notification n; - Status status; - eager_client->EnqueueAsync(&request, &response, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - if (!status.ok()) return status; - } - DataTypeVector output_dtypes; TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes)); @@ -649,6 +634,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, tensorflow::Device* op_device = op->Device(); + bool is_async = op->EagerContext()->Async(); + uint64 remote_node_id = 0; + + if (is_async) { + remote_node_id = op->EagerContext()->NextId(); + } + const tensorflow::uint64 id = remote_op->id(); for (int i = 0; i < *num_retvals; i++) { // TODO(nareshmodi): Change the callback to instead add the decref to a list @@ -676,9 +668,52 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, return tensorflow::Status::OK(); }; - retvals[i] = new TensorHandle(remote_op->id(), i, output_dtypes[i], - std::move(callback), op_device, op_device, - op->EagerContext()); + + retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id, + output_dtypes[i], std::move(callback), + op_device, op_device, op->EagerContext()); + } + + if (is_async) { + // Copy the output handles, since the container for them might get + // destroyed. + gtl::InlinedVector<TensorHandle*, 2> retvals_copy; + for (int i = 0; i < *num_retvals; i++) { + retvals_copy.push_back(retvals[i]); + retvals_copy[i]->Ref(); + } + // Unable to capture via std::move, so bind instead. + auto* node = new eager::RemoteExecuteNode( + remote_node_id, request, eager_client, op->Inputs(), + std::bind( + [](const gtl::InlinedVector<TensorHandle*, 2>& retvals, + const Status& status, const eager::EnqueueResponse& response) { + if (!status.ok()) return; + for (int i = 0; i < retvals.size(); i++) { + retvals[i]->SetRemoteShape(MakeUnique<TensorShape>( + response.queue_response(0).shape(i))); + retvals[i]->Unref(); + } + }, + std::move(retvals_copy), std::placeholders::_1, + std::placeholders::_2)); + op->EagerContext()->ExecutorAdd(node); + } else { + Notification n; + Status status; + eager_client->EnqueueAsync(&request, &response, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + + for (int i = 0; i < *num_retvals; i++) { + retvals[i]->SetRemoteShape( + MakeUnique<TensorShape>(response.queue_response(0).shape(i))); + } + + if (!status.ok()) return status; } return Status::OK(); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 431e8299dc..5d64c8c5b9 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -45,7 +45,7 @@ limitations under the License. namespace tensorflow { bool TensorHandle::IsReady() { - if (node_id == 0) return true; + if (node_id_ == 0) return true; mutex_lock l(ctx_mutex_); return is_ready_; } @@ -54,17 +54,19 @@ bool TensorHandle::IsRemote() { return remote_op_id_ >= 0 && remote_output_num_ >= 0; } -Status TensorHandle::WaitReady() { +Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) { if (node_id == 0) return Status::OK(); EagerExecutor* executor = nullptr; { mutex_lock l(ctx_mutex_); - if (is_ready_) return Status::OK(); + if (return_if_is_ready && is_ready_) return Status::OK(); executor = ctx_->Executor(); } return executor->WaitFor(node_id); } +Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); } + Status TensorHandle::Tensor(const tensorflow::Tensor** t) { if (IsRemote()) { return errors::Unavailable( @@ -107,6 +109,37 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor, return Status::OK(); } +Status TensorHandle::NumDims(int* num_dims) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + CHECK(remote_shape_ != nullptr); + *num_dims = remote_shape_->dims(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_dims != nullptr); + + *num_dims = tensor_.dims(); + } + + return Status::OK(); +} + +Status TensorHandle::Dim(int dim_index, int64* dim) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *dim = remote_shape_->dim_size(dim_index); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(dim != nullptr); + + *dim = tensor_.dim_size(dim_index); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( @@ -122,7 +155,7 @@ void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, tensorflow::Device* device, tensorflow::Device* op_device) { mutex_lock l(ctx_mutex_); - DCHECK(node_id > 0 && !is_ready_) + DCHECK(node_id_ > 0 && !is_ready_) << "SetTensorAndDevice should be only called " << "on non-ready handles."; is_ready_ = true; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 4314b6bd4e..46bc94f875 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -51,38 +51,41 @@ class TensorHandle : public core::RefCounted { public: TensorHandle(const Tensor& t, Device* d, Device* op_device, EagerContext* ctx) : dtype(t.dtype()), - node_id(0), + node_id_(0), tensor_(t), device_(d), op_device_(op_device), remote_op_id_(-1), remote_output_num_(-1), + remote_shape_node_id_(-1), ctx_(ctx), is_ready_(true) {} TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx) : dtype(dtype), - node_id(node_id), + node_id_(node_id), tensor_(dtype), device_(nullptr), op_device_(nullptr), remote_op_id_(-1), remote_output_num_(-1), + remote_shape_node_id_(-1), ctx_(ctx), is_ready_(ctx == nullptr) { - DCHECK_GT(node_id, 0); + DCHECK_GT(node_id_, 0); } // Remote tensor handle constructor. - TensorHandle(int64 op_id, int32 output_num, DataType dtype, - std::function<void()> call_on_destroy, Device* d, + TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id, + DataType dtype, std::function<void()> call_on_destroy, Device* d, Device* op_device, EagerContext* ctx) : dtype(dtype), - node_id(0), + node_id_(0), device_(d), op_device_(op_device), remote_op_id_(op_id), remote_output_num_(output_num), + remote_shape_node_id_(remote_shape_node_id), call_on_destroy_(std::move(call_on_destroy)), ctx_(ctx), is_ready_(true) { @@ -106,6 +109,9 @@ class TensorHandle : public core::RefCounted { tensorflow::Device** device, tensorflow::Device** op_device); + Status NumDims(int* num_dims); + Status Dim(int dim_index, int64* dim); + // Return the op_id and output num if the handle refers to a remote tensor. Status RemoteAddress(int64* op_id, int32* output_num); @@ -128,11 +134,16 @@ class TensorHandle : public core::RefCounted { // ready. const DataType dtype; + void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) { + remote_shape_ = std::move(remote_shape); + } + private: // If the contents of the Tensor pointed to by this handle is yet to be // computed by a EagerNode, this function will block till that compuatation is // done and the handle is "ready". Status WaitReady(); + Status WaitForNode(uint64 node_id, bool return_if_is_ready); bool IsReady(); @@ -140,7 +151,7 @@ class TensorHandle : public core::RefCounted { // Id for the EagerNode that will compute the value pointed to by this handle. // If the value is 0, the handle is already ready, but not vice-versa. - const uint64 node_id; + const uint64 node_id_; tensorflow::Tensor tensor_; @@ -161,6 +172,8 @@ class TensorHandle : public core::RefCounted { // IDs required when this class is representing a remote tensor handle. const int64 remote_op_id_; const int32 remote_output_num_; + std::unique_ptr<TensorShape> remote_shape_; + const uint64 remote_shape_node_id_; // A callback that is executed when the class is destroyed. // diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 710bd8d021..22d0902af2 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 2fa234c810..5a26d5bf48 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -128,8 +128,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, return Status::OK(); } +Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { + const tensorflow::Tensor* t = nullptr; + + // TODO(nareshmodi): This call makes async calls sync calls. Fix this. + TF_RETURN_IF_ERROR(handle->Tensor(&t)); + + t->shape().AsProto(proto); + + return Status::OK(); +} + Status EagerServiceImpl::ExecuteOp(const Operation& operation, - ServerContext* server_context) { + ServerContext* server_context, + QueueResponse* queue_response) { std::unique_ptr<tensorflow::EagerOperation> op; const char* name = operation.name().c_str(); // Shorthand const tensorflow::AttrTypeMap* types; @@ -172,6 +184,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation, server_context->AddOperationOutputs(retvals, operation.id()); + for (auto* handle : retvals) { + TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape())); + } + return Status::OK(); } @@ -182,8 +198,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request, core::ScopedUnref context_unref(context); for (const auto& item : request->queue()) { + auto* queue_response = response->add_queue_response(); if (item.has_operation()) { - TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context)); + TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response)); } else { TF_RETURN_IF_ERROR(context->DeleteTensorHandle( RemoteTensorHandleInternal(item.handle_to_decref()))); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index ebd5269a57..b0e4aa84b9 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -135,7 +135,8 @@ class EagerServiceImpl { tensorflow::Status GetServerContext(uint64, ServerContext**); private: - Status ExecuteOp(const Operation& operation, ServerContext* server_context); + Status ExecuteOp(const Operation& operation, ServerContext* server_context, + QueueResponse* queue_response); const WorkerEnv* const env_; // Not owned. mutex contexts_mu_; diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 91b58698a4..b98386ba86 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -198,6 +198,11 @@ TEST_F(EagerServiceImplTest, BasicTest) { TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, &remote_enqueue_response)); + auto& matmul_result_shape = + remote_enqueue_response.queue_response(1).shape(0); + EXPECT_EQ(matmul_result_shape.dim(0).size(), 2); + EXPECT_EQ(matmul_result_shape.dim(1).size(), 2); + tensorflow::TensorHandle* tensor_handle; TF_ASSERT_OK(eager_service_impl.GetTensorHandle( response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index c4bd67aaed..28b68c3b88 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ #include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/protobuf/eager_service.pb.h" @@ -27,6 +28,22 @@ namespace eager { // via RPC in a remote EagerService. class RemoteExecuteNode : public tensorflow::EagerNode { public: + RemoteExecuteNode( + tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, + tensorflow::eager::EagerClient* eager_client, + const gtl::InlinedVector<TensorHandle*, 4>& inputs, + std::function<void(const Status& status, const EnqueueResponse& response)> + done_callback) + : tensorflow::EagerNode(id), + request_(std::move(request)), + eager_client_(eager_client), + inputs_(inputs), + done_callback_(std::move(done_callback)) { + for (auto* handle : inputs_) { + handle->Ref(); + } + } + RemoteExecuteNode(tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request, tensorflow::eager::EagerClient* eager_client) @@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode { request_(std::move(request)), eager_client_(eager_client) {} + ~RemoteExecuteNode() { + for (auto* handle : inputs_) { + handle->Unref(); + } + } + tensorflow::Status Run() override { tensorflow::eager::EnqueueResponse response; tensorflow::Status status; @@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode { }); n.WaitForNotification(); + if (done_callback_) { + done_callback_(status, response); + } + return status; } @@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode { EnqueueRequest request_; tensorflow::eager::EagerClient* eager_client_; // Not owned, and must outlive the RemoteExecuteNode. + + // This is required to ensure that the tensor handles stay alive across the + // execution. + gtl::InlinedVector<TensorHandle*, 4> inputs_; + + std::function<void(const Status& status, const EnqueueResponse& response)> + done_callback_; }; } // namespace eager diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 50294b8a42..5b05a1b3ee 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -7,6 +7,7 @@ import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/function.proto"; import "tensorflow/core/framework/versions.proto"; import "tensorflow/core/protobuf/tensorflow_server.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; message RemoteTensorHandle { // The ID of the operation that produced this tensor. @@ -45,6 +46,10 @@ message QueueItem { } } +message QueueResponse { + repeated TensorShapeProto shape = 1; +} + message CreateContextRequest { // Identifies the full cluster, and this particular worker's position within. ServerDef server_def = 1; @@ -84,6 +89,8 @@ message EnqueueRequest { } message EnqueueResponse { + // A single operation response for every item in the request. + repeated QueueResponse queue_response = 1; } message WaitQueueDoneRequest { |