diff options
Diffstat (limited to 'tensorflow/python/framework/common_shapes.py')
-rw-r--r-- | tensorflow/python/framework/common_shapes.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 3c5aebbce8..40788e24c4 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -28,6 +28,18 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +def has_fully_defined_shape(tensor): + """Returns true if tensor has a fully defined shape.""" + return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined() + + +def rank(tensor): + """Return a rank if it is a tensor, else return None.""" + if isinstance(tensor, ops.Tensor): + return tensor._rank() # pylint: disable=protected-access + return None + + def scalar_shape(unused_op): """Shape function for ops that output a scalar value.""" return [tensor_shape.scalar()] |