aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/tensor_handle.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/tensor_handle.h')
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 46bc94f875..1bc9c6531a 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -109,6 +109,8 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device** device,
tensorflow::Device** op_device);
+ Status Shape(tensorflow::TensorShape* shape);
+
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
@@ -138,6 +140,12 @@ class TensorHandle : public core::RefCounted {
remote_shape_ = std::move(remote_shape);
}
+ bool OnHostCPU() {
+ mutex_lock ml(ctx_mutex_);
+ return device_ == nullptr ||
+ (ctx_ == nullptr || ctx_->HostCPU() == device_);
+ }
+
private:
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that compuatation is