diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/tensor_handle.h')
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.h | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 46bc94f875..1bc9c6531a 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted { tensorflow::Device** device, tensorflow::Device** op_device); + Status Shape(tensorflow::TensorShape* shape); + Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); @@ -138,6 +140,12 @@ class TensorHandle : public core::RefCounted { remote_shape_ = std::move(remote_shape); } + bool OnHostCPU() { + mutex_lock ml(ctx_mutex_); + return device_ == nullptr || + (ctx_ == nullptr || ctx_->HostCPU() == device_); + } + private: // If the contents of the Tensor pointed to by this handle is yet to be // computed by a EagerNode, this function will block till that compuatation is |