author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-07-16 15:36:38 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-07-16 15:39:20 -0700
commit     b027ac978f6ed03b634fde2a5ee3fa20d766921e
tree       e73eeabfbb6d9c0dc7a1eb88a9be6cbde1c2afd3
parent     dee9561680141ff916f3f487e212b3106da23a2f
This CL fixes a bug preventing Eager tapes from working remotely.

Previously, tapes would attempt to access remote TensorHandle tensors directly (which is unsupported remotely) to get their shapes, causing an error. They now read the shape via a new TensorHandle `Shape` method, which unifies local and remote TensorHandle shape access. This CL also adds tests to ensure taping during remote execution works.

PiperOrigin-RevId: 204819435
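For context, the sketch below shows the caller-side pattern the new accessor enables: instead of materializing a local Tensor just to read its shape (which fails for remote handles), tape code asks the handle directly. The includes and the GetTapeShape helper are illustrative assumptions and not part of this CL; only TensorHandle::Shape, TensorHandle::Tensor, and TF_RETURN_IF_ERROR appear in the diff below.

  #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
  #include "tensorflow/core/framework/tensor_shape.h"
  #include "tensorflow/core/lib/core/status.h"

  // Hypothetical helper illustrating the new call pattern on the tape path.
  // Before this CL, the shape had to come from a locally materialized Tensor:
  //   const tensorflow::Tensor* t = nullptr;
  //   TF_RETURN_IF_ERROR(handle->Tensor(&t));  // unsupported for remote handles
  //   *shape = t->shape();
  // After this CL, Shape() covers both cases: it waits on the remote shape node
  // for remote handles and on WaitReady() for local ones.
  tensorflow::Status GetTapeShape(tensorflow::TensorHandle* handle,
                                  tensorflow::TensorShape* shape) {
    return handle->Shape(shape);
  }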
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.cc  | 13
-rw-r--r--  tensorflow/core/common_runtime/eager/tensor_handle.h   |  2
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc              |  8
3 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index f9b9abcc99..85b0b79bce 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -109,6 +109,19 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
   return Status::OK();
 }
 
+Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
+  if (IsRemote()) {
+    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+    CHECK(remote_shape_ != nullptr);
+    *shape = *(remote_shape_.get());
+  } else {
+    TF_RETURN_IF_ERROR(WaitReady());
+    DCHECK(IsReady());
+    *shape = tensor_.shape();
+  }
+  return Status::OK();
+}
+
 Status TensorHandle::NumDims(int* num_dims) {
   if (IsRemote()) {
     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 46bc94f875..5580d37234 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted {
                          tensorflow::Device** device,
                          tensorflow::Device** op_device);
 
+  Status Shape(tensorflow::TensorShape* shape);
+
   Status NumDims(int* num_dims);
   Status Dim(int dim_index, int64* dim);
 
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index ec7e2371e9..4d28e98961 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1173,14 +1173,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
     tensorflow::int64 id = EagerTensor_id(tensor);
-    const tensorflow::Tensor* tensor = nullptr;
-    const tensorflow::Status status = t->handle->Tensor(&tensor);
+    tensorflow::TensorShape tensor_shape;
+    const tensorflow::Status status = t->handle->Shape(&tensor_shape);
+
     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
       return tensorflow::eager::TapeTensor{id, t->handle->dtype,
                                            tensorflow::TensorShape({})};
     } else {
-      return tensorflow::eager::TapeTensor{id, t->handle->dtype,
-                                           tensor->shape()};
+      return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
     }
   }
   tensorflow::int64 id = FastTensorId(tensor);