diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/tensor_handle.cc')
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.cc | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index b912f7d37b..d58724cbfa 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); - CHECK(remote_shape_ != nullptr); *num_dims = remote_shape_->dims(); } else { TF_RETURN_IF_ERROR(WaitReady()); @@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) { return Status::OK(); } +Status TensorHandle::NumElements(int64* num_elements) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *num_elements = remote_shape_->num_elements(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_elements != nullptr); + + *num_elements = tensor_.NumElements(); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( |