aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/common_shapes.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/common_shapes.py')
-rw-r--r--tensorflow/python/framework/common_shapes.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 3c5aebbce8..40788e24c4 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -28,6 +28,18 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+def has_fully_defined_shape(tensor):
+ """Returns true if tensor has a fully defined shape."""
+ return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
+
+
+def rank(tensor):
+ """Return a rank if it is a tensor, else return None."""
+ if isinstance(tensor, ops.Tensor):
+ return tensor._rank() # pylint: disable=protected-access
+ return None
+
+
def scalar_shape(unused_op):
"""Shape function for ops that output a scalar value."""
return [tensor_shape.scalar()]