diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-09-17 18:42:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 18:46:51 -0700 |
commit | 3b7ca4b86416f6b6153de90bc1df6e6e5b41934c (patch) | |
tree | e1eb70719e1b8b4edc572d02e18bcedbd8c00d9b /tensorflow/core/common_runtime | |
parent | 71fab28dc4741dedf13fea732f6b134608719bc7 (diff) |
Num elements fastpath for eager tensors.
PiperOrigin-RevId: 213377426
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.h | 1 |
2 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index b912f7d37b..d58724cbfa 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); - CHECK(remote_shape_ != nullptr); *num_dims = remote_shape_->dims(); } else { TF_RETURN_IF_ERROR(WaitReady()); @@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) { return Status::OK(); } +Status TensorHandle::NumElements(int64* num_elements) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *num_elements = remote_shape_->num_elements(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_elements != nullptr); + + *num_elements = tensor_.NumElements(); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 1bc9c6531a..e55f1a0338 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted { Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); + Status NumElements(int64* num_elements); // Return the op_id and output num if the handle refers to a remote tensor. Status RemoteAddress(int64* op_id, int32* output_num); |