diff options
author | 2018-07-16 15:36:38 -0700 | |
---|---|---|
committer | 2018-07-16 15:39:20 -0700 | |
commit | b027ac978f6ed03b634fde2a5ee3fa20d766921e (patch) | |
tree | e73eeabfbb6d9c0dc7a1eb88a9be6cbde1c2afd3 | |
parent | dee9561680141ff916f3f487e212b3106da23a2f (diff) |
This CL fixes a bug preventing Eager tapes from working remotely.
Previously, tapes would attempt to access remote TensorHandle tensors directly (unsupported remotely) to get their shapes, causing an error.
They now access the remote shape via a new TensorHandle `Shape` method, which can unify local and remote TensorHandle shape accesses.
This CL also adds some tests to ensure taping during remote execution works.
PiperOrigin-RevId: 204819435
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.h | 2 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 8 |
3 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index f9b9abcc99..85b0b79bce 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -109,6 +109,19 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor, return Status::OK(); } +Status TensorHandle::Shape(tensorflow::TensorShape* shape) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + CHECK(remote_shape_ != nullptr); + *shape = *(remote_shape_.get()); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *shape = tensor_.shape(); + } + return Status::OK(); +} + Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 46bc94f875..5580d37234 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted { tensorflow::Device** device, tensorflow::Device** op_device); + Status Shape(tensorflow::TensorShape* shape); + Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index ec7e2371e9..4d28e98961 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1173,14 +1173,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); - const tensorflow::Tensor* tensor = nullptr; - const tensorflow::Status status = t->handle->Tensor(&tensor); + tensorflow::TensorShape tensor_shape; + const tensorflow::Status status = t->handle->Shape(&tensor_shape); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensorflow::TensorShape({})}; } else { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, - tensor->shape()}; + return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape}; } } tensorflow::int64 id = FastTensorId(tensor); |