aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-26 15:36:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 15:43:10 -0700
commit79c828ea6ddbcfccd43a2be176fc1dcad4daf34e (patch)
treeadd42271d2d23096b03af3bf8fb89384cffc9a54
parentec34de06981eed74c2c2a47c8a6372735e9d3622 (diff)
Support shapes for remote eager tensor handles.
Since we respond with the shape, all RPCs will happen sync (note that we may still hide the python overhead, since the op is still scheduled for execution via the eager executor). PiperOrigin-RevId: 202207324
-rw-r--r--tensorflow/c/eager/c_api.cc26
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc73
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc41
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h27
-rw-r--r--tensorflow/core/distributed_runtime/eager/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc21
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.h3
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc5
-rw-r--r--tensorflow/core/distributed_runtime/eager/remote_execute_node.h34
-rw-r--r--tensorflow/core/protobuf/eager_service.proto7
10 files changed, 193 insertions, 45 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 00b474fe86..82ca2be2cf 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -156,12 +156,14 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
// message.
-#define LOG_AND_RETURN_IF_ERROR(...) \
- do { \
- const ::tensorflow::Status _status = (__VA_ARGS__); \
- LOG(ERROR) << _status.error_message(); \
- if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
- } while (0)
+#define LOG_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ LOG(ERROR) << _status.error_message(); \
+ return _status; \
+ } \
+ } while (0);
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
@@ -346,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dims();
+ int result;
+ status->status = h->handle->NumDims(&result);
+ return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dim_size(dim_index);
+ tensorflow::int64 result;
+ status->status = h->handle->Dim(dim_index, &result);
+ return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index bf60d05e96..60dd848b20 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -623,22 +624,6 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
request.set_context_id(context_id);
- if (op->EagerContext()->Async()) {
- tensorflow::uint64 id = op->EagerContext()->NextId();
- auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
- op->EagerContext()->ExecutorAdd(node);
- } else {
- Notification n;
- Status status;
- eager_client->EnqueueAsync(&request, &response,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- if (!status.ok()) return status;
- }
-
DataTypeVector output_dtypes;
TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
@@ -649,6 +634,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
tensorflow::Device* op_device = op->Device();
+ bool is_async = op->EagerContext()->Async();
+ uint64 remote_node_id = 0;
+
+ if (is_async) {
+ remote_node_id = op->EagerContext()->NextId();
+ }
+
const tensorflow::uint64 id = remote_op->id();
for (int i = 0; i < *num_retvals; i++) {
// TODO(nareshmodi): Change the callback to instead add the decref to a list
@@ -676,9 +668,52 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return tensorflow::Status::OK();
};
- retvals[i] = new TensorHandle(remote_op->id(), i, output_dtypes[i],
- std::move(callback), op_device, op_device,
- op->EagerContext());
+
+ retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
+ output_dtypes[i], std::move(callback),
+ op_device, op_device, op->EagerContext());
+ }
+
+ if (is_async) {
+ // Copy the output handles, since the container for them might get
+ // destroyed.
+ gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals_copy.push_back(retvals[i]);
+ retvals_copy[i]->Ref();
+ }
+ // Unable to capture via std::move, so bind instead.
+ auto* node = new eager::RemoteExecuteNode(
+ remote_node_id, request, eager_client, op->Inputs(),
+ std::bind(
+ [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
+ const Status& status, const eager::EnqueueResponse& response) {
+ if (!status.ok()) return;
+ for (int i = 0; i < retvals.size(); i++) {
+ retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
+ response.queue_response(0).shape(i)));
+ retvals[i]->Unref();
+ }
+ },
+ std::move(retvals_copy), std::placeholders::_1,
+ std::placeholders::_2));
+ op->EagerContext()->ExecutorAdd(node);
+ } else {
+ Notification n;
+ Status status;
+ eager_client->EnqueueAsync(&request, &response,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals[i]->SetRemoteShape(
+ MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
+ }
+
+ if (!status.ok()) return status;
}
return Status::OK();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 431e8299dc..5d64c8c5b9 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -45,7 +45,7 @@ limitations under the License.
namespace tensorflow {
bool TensorHandle::IsReady() {
- if (node_id == 0) return true;
+ if (node_id_ == 0) return true;
mutex_lock l(ctx_mutex_);
return is_ready_;
}
@@ -54,17 +54,19 @@ bool TensorHandle::IsRemote() {
return remote_op_id_ >= 0 && remote_output_num_ >= 0;
}
-Status TensorHandle::WaitReady() {
+Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) {
if (node_id == 0) return Status::OK();
EagerExecutor* executor = nullptr;
{
mutex_lock l(ctx_mutex_);
- if (is_ready_) return Status::OK();
+ if (return_if_is_ready && is_ready_) return Status::OK();
executor = ctx_->Executor();
}
return executor->WaitFor(node_id);
}
+Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); }
+
Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
if (IsRemote()) {
return errors::Unavailable(
@@ -107,6 +109,37 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
return Status::OK();
}
+Status TensorHandle::NumDims(int* num_dims) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ CHECK(remote_shape_ != nullptr);
+ *num_dims = remote_shape_->dims();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_dims != nullptr);
+
+ *num_dims = tensor_.dims();
+ }
+
+ return Status::OK();
+}
+
+Status TensorHandle::Dim(int dim_index, int64* dim) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *dim = remote_shape_->dim_size(dim_index);
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(dim != nullptr);
+
+ *dim = tensor_.dim_size(dim_index);
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
@@ -122,7 +155,7 @@ void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
tensorflow::Device* device,
tensorflow::Device* op_device) {
mutex_lock l(ctx_mutex_);
- DCHECK(node_id > 0 && !is_ready_)
+ DCHECK(node_id_ > 0 && !is_ready_)
<< "SetTensorAndDevice should be only called "
<< "on non-ready handles.";
is_ready_ = true;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 4314b6bd4e..46bc94f875 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -51,38 +51,41 @@ class TensorHandle : public core::RefCounted {
public:
TensorHandle(const Tensor& t, Device* d, Device* op_device, EagerContext* ctx)
: dtype(t.dtype()),
- node_id(0),
+ node_id_(0),
tensor_(t),
device_(d),
op_device_(op_device),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(true) {}
TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
: dtype(dtype),
- node_id(node_id),
+ node_id_(node_id),
tensor_(dtype),
device_(nullptr),
op_device_(nullptr),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(ctx == nullptr) {
- DCHECK_GT(node_id, 0);
+ DCHECK_GT(node_id_, 0);
}
// Remote tensor handle constructor.
- TensorHandle(int64 op_id, int32 output_num, DataType dtype,
- std::function<void()> call_on_destroy, Device* d,
+ TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id,
+ DataType dtype, std::function<void()> call_on_destroy, Device* d,
Device* op_device, EagerContext* ctx)
: dtype(dtype),
- node_id(0),
+ node_id_(0),
device_(d),
op_device_(op_device),
remote_op_id_(op_id),
remote_output_num_(output_num),
+ remote_shape_node_id_(remote_shape_node_id),
call_on_destroy_(std::move(call_on_destroy)),
ctx_(ctx),
is_ready_(true) {
@@ -106,6 +109,9 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device** device,
tensorflow::Device** op_device);
+ Status NumDims(int* num_dims);
+ Status Dim(int dim_index, int64* dim);
+
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
@@ -128,11 +134,16 @@ class TensorHandle : public core::RefCounted {
// ready.
const DataType dtype;
+ void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) {
+ remote_shape_ = std::move(remote_shape);
+ }
+
private:
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that compuatation is
// done and the handle is "ready".
Status WaitReady();
+ Status WaitForNode(uint64 node_id, bool return_if_is_ready);
bool IsReady();
@@ -140,7 +151,7 @@ class TensorHandle : public core::RefCounted {
// Id for the EagerNode that will compute the value pointed to by this handle.
// If the value is 0, the handle is already ready, but not vice-versa.
- const uint64 node_id;
+ const uint64 node_id_;
tensorflow::Tensor tensor_;
@@ -161,6 +172,8 @@ class TensorHandle : public core::RefCounted {
// IDs required when this class is representing a remote tensor handle.
const int64 remote_op_id_;
const int32 remote_output_num_;
+ std::unique_ptr<TensorShape> remote_shape_;
+ const uint64 remote_shape_node_id_;
// A callback that is executed when the class is destroyed.
//
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 710bd8d021..22d0902af2 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 2fa234c810..5a26d5bf48 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -128,8 +128,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return Status::OK();
}
+Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
+ const tensorflow::Tensor* t = nullptr;
+
+ // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
+ TF_RETURN_IF_ERROR(handle->Tensor(&t));
+
+ t->shape().AsProto(proto);
+
+ return Status::OK();
+}
+
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
- ServerContext* server_context) {
+ ServerContext* server_context,
+ QueueResponse* queue_response) {
std::unique_ptr<tensorflow::EagerOperation> op;
const char* name = operation.name().c_str(); // Shorthand
const tensorflow::AttrTypeMap* types;
@@ -172,6 +184,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
server_context->AddOperationOutputs(retvals, operation.id());
+ for (auto* handle : retvals) {
+ TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
+ }
+
return Status::OK();
}
@@ -182,8 +198,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
core::ScopedUnref context_unref(context);
for (const auto& item : request->queue()) {
+ auto* queue_response = response->add_queue_response();
if (item.has_operation()) {
- TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context));
+ TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
} else {
TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
RemoteTensorHandleInternal(item.handle_to_decref())));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index ebd5269a57..b0e4aa84b9 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -135,7 +135,8 @@ class EagerServiceImpl {
tensorflow::Status GetServerContext(uint64, ServerContext**);
private:
- Status ExecuteOp(const Operation& operation, ServerContext* server_context);
+ Status ExecuteOp(const Operation& operation, ServerContext* server_context,
+ QueueResponse* queue_response);
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 91b58698a4..b98386ba86 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -198,6 +198,11 @@ TEST_F(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
+ auto& matmul_result_shape =
+ remote_enqueue_response.queue_response(1).shape(0);
+ EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
+ EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);
+
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index c4bd67aaed..28b68c3b88 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -27,6 +28,22 @@ namespace eager {
// via RPC in a remote EagerService.
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
+ RemoteExecuteNode(
+ tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
+ tensorflow::eager::EagerClient* eager_client,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback)
+ : tensorflow::EagerNode(id),
+ request_(std::move(request)),
+ eager_client_(eager_client),
+ inputs_(inputs),
+ done_callback_(std::move(done_callback)) {
+ for (auto* handle : inputs_) {
+ handle->Ref();
+ }
+ }
+
RemoteExecuteNode(tensorflow::uint64 id,
const tensorflow::eager::EnqueueRequest& request,
tensorflow::eager::EagerClient* eager_client)
@@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
request_(std::move(request)),
eager_client_(eager_client) {}
+ ~RemoteExecuteNode() {
+ for (auto* handle : inputs_) {
+ handle->Unref();
+ }
+ }
+
tensorflow::Status Run() override {
tensorflow::eager::EnqueueResponse response;
tensorflow::Status status;
@@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
});
n.WaitForNotification();
+ if (done_callback_) {
+ done_callback_(status, response);
+ }
+
return status;
}
@@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
EnqueueRequest request_;
tensorflow::eager::EagerClient*
eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+
+ // This is required to ensure that the tensor handles stay alive across the
+ // execution.
+ gtl::InlinedVector<TensorHandle*, 4> inputs_;
+
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback_;
};
} // namespace eager
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 50294b8a42..5b05a1b3ee 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -7,6 +7,7 @@ import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";
import "tensorflow/core/protobuf/tensorflow_server.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
message RemoteTensorHandle {
// The ID of the operation that produced this tensor.
@@ -45,6 +46,10 @@ message QueueItem {
}
}
+message QueueResponse {
+ repeated TensorShapeProto shape = 1;
+}
+
message CreateContextRequest {
// Identifies the full cluster, and this particular worker's position within.
ServerDef server_def = 1;
@@ -84,6 +89,8 @@ message EnqueueRequest {
}
message EnqueueResponse {
+ // A single operation response for every item in the request.
+ repeated QueueResponse queue_response = 1;
}
message WaitQueueDoneRequest {