diff options
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index ca63efbc84..8c9dfce7cc 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -935,8 +935,10 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name def is_tensor(x): # pylint: disable=invalid-name """Check whether `x` is of tensor type. - Check whether an object is a tensor. Equivalent to - `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])`. + Check whether an object is a tensor. This check is equivalent to calling + `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])` and also checks + if all the component variables of a MirroredVariable or a TowerLocalVariable + are tensors. Args: x: A python object to check. @@ -944,4 +946,5 @@ def is_tensor(x): # pylint: disable=invalid-name Returns: `True` if `x` is a tensor, `False` if not. """ - return isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) # pylint: disable=protected-access + return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access + (hasattr(x, "is_tensor_like") and x.is_tensor_like)) |