aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r--tensorflow/python/framework/tensor_util.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index ca63efbc84..8c9dfce7cc 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -935,8 +935,10 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
def is_tensor(x): # pylint: disable=invalid-name
"""Check whether `x` is of tensor type.
- Check whether an object is a tensor. Equivalent to
- `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])`.
+ Check whether an object is a tensor. This check is equivalent to calling
+ `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])` and also checks
+ if all the component variables of a MirroredVariable or a TowerLocalVariable
+ are tensors.
Args:
x: A python object to check.
@@ -944,4 +946,5 @@ def is_tensor(x): # pylint: disable=invalid-name
Returns:
`True` if `x` is a tensor, `False` if not.
"""
- return isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) # pylint: disable=protected-access
+ return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access
+ (hasattr(x, "is_tensor_like") and x.is_tensor_like))