aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/tensor_handle.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/tensor_handle.cc')
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d37b..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(