aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/c_api.cc26
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc73
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc41
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h27
-rw-r--r--tensorflow/core/distributed_runtime/eager/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc21
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.h3
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc5
-rw-r--r--tensorflow/core/distributed_runtime/eager/remote_execute_node.h34
-rw-r--r--tensorflow/core/protobuf/eager_service.proto7
10 files changed, 193 insertions, 45 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 00b474fe86..82ca2be2cf 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -156,12 +156,14 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
// message.
-#define LOG_AND_RETURN_IF_ERROR(...) \
- do { \
- const ::tensorflow::Status _status = (__VA_ARGS__); \
- LOG(ERROR) << _status.error_message(); \
- if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
- } while (0)
+// No trailing semicolon after `while (0)`: the caller supplies it, so the macro
+// expands to exactly one statement and stays safe inside unbraced if/else.
+#define LOG_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ LOG(ERROR) << _status.error_message(); \
+ return _status; \
+ } \
+ } while (0)
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
@@ -346,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dims();
+ // Default to 0 so that a failed NumDims() lookup returns a defined value (the
+ // error itself is reported through `status`), as the replaced code did.
+ int result = 0;
+ status->status = h->handle->NumDims(&result);
+ return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dim_size(dim_index);
+ // Default to 0 so that a failed Dim() lookup returns a defined value (the
+ // error itself is reported through `status`), as the replaced code did.
+ tensorflow::int64 result = 0;
+ status->status = h->handle->Dim(dim_index, &result);
+ return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index bf60d05e96..60dd848b20 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -623,22 +624,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
request.set_context_id(context_id);
- if (op->EagerContext()->Async()) {
- tensorflow::uint64 id = op->EagerContext()->NextId();
- auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
- op->EagerContext()->ExecutorAdd(node);
- } else {
- Notification n;
- Status status;
- eager_client->EnqueueAsync(&request, &response,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- if (!status.ok()) return status;
- }
-
DataTypeVector output_dtypes;
TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
@@ -649,6 +634,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
tensorflow::Device* op_device = op->Device();
+ bool is_async = op->EagerContext()->Async();
+ uint64 remote_node_id = 0;
+
+ if (is_async) {
+ remote_node_id = op->EagerContext()->NextId();
+ }
+
const tensorflow::uint64 id = remote_op->id();
for (int i = 0; i < *num_retvals; i++) {
// TODO(nareshmodi): Change the callback to instead add the decref to a list
@@ -676,9 +668,52 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return tensorflow::Status::OK();
};
- retvals[i] = new TensorHandle(remote_op->id(), i, output_dtypes[i],
- std::move(callback), op_device, op_device,
- op->EagerContext());
+
+ retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
+ output_dtypes[i], std::move(callback),
+ op_device, op_device, op->EagerContext());
+ }
+
+ if (is_async) {
+ // Copy the output handles, since the container for them might get
+ // destroyed.
+ gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals_copy.push_back(retvals[i]);
+ retvals_copy[i]->Ref();
+ }
+ // Unable to capture via std::move, so bind instead.
+ auto* node = new eager::RemoteExecuteNode(
+ remote_node_id, request, eager_client, op->Inputs(),
+ std::bind(
+ [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
+ const Status& status, const eager::EnqueueResponse& response) {
+ for (int i = 0; i < retvals.size(); i++) {
+ // Only read shapes out of the response when the RPC succeeded.
+ if (status.ok()) {
+ retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
+ response.queue_response(0).shape(i)));
+ }
+ // Always release the reference taken on retvals_copy above, even
+ // on failure; otherwise that reference leaks.
+ retvals[i]->Unref();
+ }
+ },
+ std::move(retvals_copy), std::placeholders::_1,
+ std::placeholders::_2));
+ op->EagerContext()->ExecutorAdd(node);
+ } else {
+ Notification n;
+ Status status;
+ eager_client->EnqueueAsync(&request, &response,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+
+ // Check the RPC result before touching `response`: the shapes are only
+ // read from it once the call is known to have succeeded.
+ if (!status.ok()) return status;
+
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals[i]->SetRemoteShape(
+ MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
+ }
}
return Status::OK();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 431e8299dc..5d64c8c5b9 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -45,7 +45,7 @@ limitations under the License.
namespace tensorflow {
bool TensorHandle::IsReady() {
- if (node_id == 0) return true;
+ if (node_id_ == 0) return true;
mutex_lock l(ctx_mutex_);
return is_ready_;
}
@@ -54,17 +54,19 @@ bool TensorHandle::IsRemote() {
return remote_op_id_ >= 0 && remote_output_num_ >= 0;
}
-Status TensorHandle::WaitReady() {
+Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) {
if (node_id == 0) return Status::OK();
EagerExecutor* executor = nullptr;
{
mutex_lock l(ctx_mutex_);
- if (is_ready_) return Status::OK();
+ if (return_if_is_ready && is_ready_) return Status::OK();
executor = ctx_->Executor();
}
return executor->WaitFor(node_id);
}
+Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); }
+
Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
if (IsRemote()) {
return errors::Unavailable(
@@ -107,6 +109,37 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
return Status::OK();
}
+// Writes the rank of the tensor behind this handle to `num_dims`.
+// Remote handles wait on remote_shape_node_id_ and answer from remote_shape_;
+// local handles wait until ready and answer from tensor_.
+Status TensorHandle::NumDims(int* num_dims) {
+ // Validate the out-param up front so both branches are covered before the
+ // pointer is written through.
+ DCHECK(num_dims != nullptr);
+
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ CHECK(remote_shape_ != nullptr);
+ *num_dims = remote_shape_->dims();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+
+ *num_dims = tensor_.dims();
+ }
+
+ return Status::OK();
+}
+
+// Writes the size of dimension `dim_index` of the tensor behind this handle to
+// `dim`. Remote handles wait on remote_shape_node_id_ and answer from
+// remote_shape_; local handles wait until ready and answer from tensor_.
+Status TensorHandle::Dim(int dim_index, int64* dim) {
+ // Validate the out-param up front so both branches are covered before the
+ // pointer is written through.
+ DCHECK(dim != nullptr);
+
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ // Same guard as NumDims(): fail loudly rather than dereference a shape
+ // that was never set.
+ CHECK(remote_shape_ != nullptr);
+ *dim = remote_shape_->dim_size(dim_index);
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+
+ *dim = tensor_.dim_size(dim_index);
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
@@ -122,7 +155,7 @@ void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
tensorflow::Device* device,
tensorflow::Device* op_device) {
mutex_lock l(ctx_mutex_);
- DCHECK(node_id > 0 && !is_ready_)
+ DCHECK(node_id_ > 0 && !is_ready_)
<< "SetTensorAndDevice should be only called "
<< "on non-ready handles.";
is_ready_ = true;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 4314b6bd4e..46bc94f875 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -51,38 +51,41 @@ class TensorHandle : public core::RefCounted {
public:
TensorHandle(const Tensor& t, Device* d, Device* op_device, EagerContext* ctx)
: dtype(t.dtype()),
- node_id(0),
+ node_id_(0),
tensor_(t),
device_(d),
op_device_(op_device),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(true) {}
TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
: dtype(dtype),
- node_id(node_id),
+ node_id_(node_id),
tensor_(dtype),
device_(nullptr),
op_device_(nullptr),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(ctx == nullptr) {
- DCHECK_GT(node_id, 0);
+ DCHECK_GT(node_id_, 0);
}
// Remote tensor handle constructor.
- TensorHandle(int64 op_id, int32 output_num, DataType dtype,
- std::function<void()> call_on_destroy, Device* d,
+ TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id,
+ DataType dtype, std::function<void()> call_on_destroy, Device* d,
Device* op_device, EagerContext* ctx)
: dtype(dtype),
- node_id(0),
+ node_id_(0),
device_(d),
op_device_(op_device),
remote_op_id_(op_id),
remote_output_num_(output_num),
+ remote_shape_node_id_(remote_shape_node_id),
call_on_destroy_(std::move(call_on_destroy)),
ctx_(ctx),
is_ready_(true) {
@@ -106,6 +109,9 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device** device,
tensorflow::Device** op_device);
+ Status NumDims(int* num_dims);
+ Status Dim(int dim_index, int64* dim);
+
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
@@ -128,11 +134,16 @@ class TensorHandle : public core::RefCounted {
// ready.
const DataType dtype;
+ // Takes ownership of `remote_shape` and stores it in remote_shape_, which
+ // NumDims() and Dim() read for remote handles.
+ // NOTE(review): no mutex is taken here; presumably callers rely on waiting
+ // for the producing node before reading the shape -- confirm.
+ void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) {
+ remote_shape_ = std::move(remote_shape);
+ }
+
private:
// If the contents of the Tensor pointed to by this handle are yet to be
// computed by an EagerNode, this function will block till that computation is
// done and the handle is "ready".
Status WaitReady();
+ Status WaitForNode(uint64 node_id, bool return_if_is_ready);
bool IsReady();
@@ -140,7 +151,7 @@ class TensorHandle : public core::RefCounted {
// Id for the EagerNode that will compute the value pointed to by this handle.
// If the value is 0, the handle is already ready, but not vice-versa.
- const uint64 node_id;
+ const uint64 node_id_;
tensorflow::Tensor tensor_;
@@ -161,6 +172,8 @@ class TensorHandle : public core::RefCounted {
// IDs required when this class is representing a remote tensor handle.
const int64 remote_op_id_;
const int32 remote_output_num_;
+ std::unique_ptr<TensorShape> remote_shape_;
+ const uint64 remote_shape_node_id_;
// A callback that is executed when the class is destroyed.
//
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 710bd8d021..22d0902af2 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 2fa234c810..5a26d5bf48 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -128,8 +128,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return Status::OK();
}
+// Serializes the shape of the tensor behind `handle` into `proto`.
+// Returns the error from handle->Tensor() if the tensor cannot be obtained;
+// `proto` is left untouched in that case.
+Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
+ const tensorflow::Tensor* t = nullptr;
+
+ // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
+ TF_RETURN_IF_ERROR(handle->Tensor(&t));
+
+ t->shape().AsProto(proto);
+
+ return Status::OK();
+}
+
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
- ServerContext* server_context) {
+ ServerContext* server_context,
+ QueueResponse* queue_response) {
std::unique_ptr<tensorflow::EagerOperation> op;
const char* name = operation.name().c_str(); // Shorthand
const tensorflow::AttrTypeMap* types;
@@ -172,6 +184,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
server_context->AddOperationOutputs(retvals, operation.id());
+ for (auto* handle : retvals) {
+ TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
+ }
+
return Status::OK();
}
@@ -182,8 +198,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
core::ScopedUnref context_unref(context);
for (const auto& item : request->queue()) {
+ auto* queue_response = response->add_queue_response();
if (item.has_operation()) {
- TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context));
+ TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
} else {
TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
RemoteTensorHandleInternal(item.handle_to_decref())));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index ebd5269a57..b0e4aa84b9 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -135,7 +135,8 @@ class EagerServiceImpl {
tensorflow::Status GetServerContext(uint64, ServerContext**);
private:
- Status ExecuteOp(const Operation& operation, ServerContext* server_context);
+ Status ExecuteOp(const Operation& operation, ServerContext* server_context,
+ QueueResponse* queue_response);
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 91b58698a4..b98386ba86 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -198,6 +198,11 @@ TEST_F(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
+ auto& matmul_result_shape =
+ remote_enqueue_response.queue_response(1).shape(0);
+ EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
+ EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);
+
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index c4bd67aaed..28b68c3b88 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -27,6 +28,22 @@ namespace eager {
// via RPC in a remote EagerService.
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
+ // Builds a node that sends `request` through `eager_client` when run.
+ // Takes a reference on every handle in `inputs` (released in the destructor)
+ // and stores `done_callback`, which Run() invokes, if set, with the RPC
+ // status and response.
+ RemoteExecuteNode(
+ tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
+ tensorflow::eager::EagerClient* eager_client,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback)
+ : tensorflow::EagerNode(id),
+ // `request` is a const reference, so this is a copy; std::move on it
+ // would silently copy as well.
+ request_(request),
+ eager_client_(eager_client),
+ inputs_(inputs),
+ done_callback_(std::move(done_callback)) {
+ for (auto* handle : inputs_) {
+ handle->Ref();
+ }
+ }
+
RemoteExecuteNode(tensorflow::uint64 id,
const tensorflow::eager::EnqueueRequest& request,
tensorflow::eager::EagerClient* eager_client)
@@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
request_(std::move(request)),
eager_client_(eager_client) {}
+ // Drops one reference on every handle held in inputs_.
+ ~RemoteExecuteNode() {
+ for (auto* handle : inputs_) {
+ handle->Unref();
+ }
+ }
+
tensorflow::Status Run() override {
tensorflow::eager::EnqueueResponse response;
tensorflow::Status status;
@@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
});
n.WaitForNotification();
+ if (done_callback_) {
+ done_callback_(status, response);
+ }
+
return status;
}
@@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
EnqueueRequest request_;
tensorflow::eager::EagerClient*
eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+
+ // This is required to ensure that the tensor handles stay alive across the
+ // execution.
+ gtl::InlinedVector<TensorHandle*, 4> inputs_;
+
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback_;
};
} // namespace eager
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 50294b8a42..5b05a1b3ee 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -7,6 +7,7 @@ import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";
import "tensorflow/core/protobuf/tensorflow_server.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
message RemoteTensorHandle {
// The ID of the operation that produced this tensor.
@@ -45,6 +46,10 @@ message QueueItem {
}
}
+message QueueResponse {
+ // Shapes of the tensors produced by an executed operation, one entry per
+ // output, in output order.
+ repeated TensorShapeProto shape = 1;
+}
+
message CreateContextRequest {
// Identifies the full cluster, and this particular worker's position within.
ServerDef server_def = 1;
@@ -84,6 +89,8 @@ message EnqueueRequest {
}
message EnqueueResponse {
+ // A single operation response for every item in the request.
+ repeated QueueResponse queue_response = 1;
}
message WaitQueueDoneRequest {