aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_shape.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/tensor_shape.py')
-rw-r--r--tensorflow/python/framework/tensor_shape.py743
1 files changed, 743 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
new file mode 100644
index 0000000000..d4f27696d4
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -0,0 +1,743 @@
+"""Helper classes for tensor shape inference."""
+import tensorflow.python.platform
+
+
+class Dimension(object):
+ """Represents the value of one dimension in a TensorShape."""
+
+ def __init__(self, value):
+ """Creates a new Dimension with the given value."""
+ if value is None:
+ self._value = None
+ else:
+ self._value = int(value)
+
+ def __repr__(self):
+ return "Dimension(%s)" % repr(self._value)
+
+ def __eq__(self, other):
+ """Returns true if `other` has the same known value as this Dimension."""
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ return self._value == other.value
+
+ def __ne__(self, other):
+ """Returns true if `other` has a different known value from `self`."""
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ return self._value != other.value
+
+ def __int__(self):
+ return self._value
+
+ @property
+ def value(self):
+ """The value of this dimension, or None if it is unknown."""
+ return self._value
+
+ def is_compatible_with(self, other):
+ """Returns true if `other` is compatible with this Dimension.
+
+ Two known Dimensions are compatible if they have the same value.
+ An unknown Dimension is compatible with all other Dimensions.
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ True if this Dimension and `other` are compatible.
+ """
+ other = as_dimension(other)
+ return (self._value is None
+ or other.value is None
+ or self._value == other.value)
+
+ def assert_is_compatible_with(self, other):
+ """Raises an exception if `other` is not compatible with this Dimension.
+
+ Args:
+ other: Another Dimension.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible (see
+ is_compatible_with).
+ """
+ if not self.is_compatible_with(other):
+ raise ValueError("Dimensions %s and %s are not compatible"
+ % (self, other))
+
+ def merge_with(self, other):
+ """Returns a Dimension that combines the information in `self` and `other`.
+
+ Dimensions are combined as follows:
+
+ Dimension(n) .merge_with(Dimension(n)) == Dimension(n)
+ Dimension(n) .merge_with(Dimension(None)) == Dimension(n)
+ Dimension(None).merge_with(Dimension(n)) == Dimension(n)
+ Dimension(None).merge_with(Dimension(None)) == Dimension(None)
+ Dimension(n) .merge_with(Dimension(m)) raises ValueError for n != m
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension containing the combined information of `self` and
+ `other`.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible (see
+ is_compatible_with).
+ """
+ other = as_dimension(other)
+ self.assert_is_compatible_with(other)
+ if self._value is None:
+ return Dimension(other.value)
+ else:
+ return Dimension(self._value)
+
+ def __add__(self, other):
+ """Returns the sum of `self` and `other`.
+
+ Dimensions are summed as follows:
+
+ Dimension(m) + Dimension(n) == Dimension(m + n)
+ Dimension(m) + Dimension(None) == Dimension(None)
+ Dimension(None) + Dimension(n) == Dimension(None)
+ Dimension(None) + Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is the sum of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value + other.value)
+
+ def __sub__(self, other):
+ """Returns the subtraction of `other` from `self`.
+
+ Dimensions are subtracted as follows:
+
+ Dimension(m) - Dimension(n) == Dimension(m - n)
+ Dimension(m) - Dimension(None) == Dimension(None)
+ Dimension(None) - Dimension(n) == Dimension(None)
+ Dimension(None) - Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is the subtraction of sum of `other` from `self`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value - other.value)
+
+ def __mul__(self, other):
+ """Returns the product of `self` and `other`.
+
+ Dimensions are summed as follows:
+
+ Dimension(m) * Dimension(n) == Dimension(m * n)
+ Dimension(m) * Dimension(None) == Dimension(None)
+ Dimension(None) * Dimension(n) == Dimension(None)
+ Dimension(None) * Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is the sum of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value * other.value)
+
+ def __div__(self, other):
+ """Returns the quotient of `self` and `other`.
+
+ Dimensions are summed as follows:
+
+ Dimension(m) / Dimension(n) == Dimension(m / n)
+ Dimension(m) / Dimension(None) == Dimension(None)
+ Dimension(None) / Dimension(n) == Dimension(None)
+ Dimension(None) / Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is the sum of `self` and `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value / other.value)
+
+ def __mod__(self, other):
+ """Returns `self` modulo `other.
+
+ Dimension moduli are computed as follows:
+
+ Dimension(m) % Dimension(n) == Dimension(m % n)
+ Dimension(m) % Dimension(None) == Dimension(None)
+ Dimension(None) % Dimension(n) == Dimension(None)
+ Dimension(None) % Dimension(None) == Dimension(None)
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ A Dimension whose value is `self` modulo `other`.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return Dimension(None)
+ else:
+ return Dimension(self._value % other.value)
+
+ def __lt__(self, other):
+ """Returns True if `self` is known to be less than `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) < Dimension(n) == m < n
+ Dimension(m) < Dimension(None) == None
+ Dimension(None) < Dimension(n) == None
+ Dimension(None) < Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value < other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value < other.value
+
+ def __le__(self, other):
+ """Returns True if `self` is known to be less than or equal to `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) <= Dimension(n) == m <= n
+ Dimension(m) <= Dimension(None) == None
+ Dimension(None) <= Dimension(n) == None
+ Dimension(None) <= Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value <= other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value <= other.value
+
+ def __gt__(self, other):
+ """Returns True if `self` is known to be greater than `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) > Dimension(n) == m > n
+ Dimension(m) > Dimension(None) == None
+ Dimension(None) > Dimension(n) == None
+ Dimension(None) > Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value > other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value > other.value
+
+ def __ge__(self, other):
+ """Returns True if `self` is known to be greater than or equal to `other`.
+
+ Dimensions are compared as follows:
+
+ Dimension(m) >= Dimension(n) == m >= n
+ Dimension(m) >= Dimension(None) == None
+ Dimension(None) >= Dimension(n) == None
+ Dimension(None) >= Dimension(None) == None
+
+ Args:
+ other: Another Dimension.
+
+ Returns:
+ The value of `self.value >= other.value` if both are known, otherwise
+ None.
+ """
+ other = as_dimension(other)
+ if self._value is None or other.value is None:
+ return None
+ else:
+ return self._value >= other.value
+
+
+def as_dimension(value):
+ """Converts the given value to a Dimension.
+
+ A Dimenson input will be returned unmodified.
+ An input of `None` will be converted to an unknown Dimension.
+ An integer input will be converted to a Dimension with that value.
+
+ Args:
+ value: The value to be converted.
+
+ Returns:
+ A Dimension corresponding to the given value.
+ """
+ if isinstance(value, Dimension):
+ return value
+ else:
+ return Dimension(value)
+
+
+class TensorShape(object):
+ """Represents the shape of a `Tensor`.
+
+ A `TensorShape` represents a possibly-partial shape specification for a
+ `Tensor`. It may be one of the following:
+
+ * *Fully-known shape:* has a known number of dimensions and a known size
+ for each dimension.
+ * *Partially-known shape:* has a known number of dimensions, and an unknown
+ size for one or more dimension.
+ * *Unknown shape:* has an unknown number of dimensions, and an unknown
+ size in all dimensions.
+
+ If a tensor is produced by an operation of type `"Foo"`, its shape
+ may be inferred if there is a registered shape function for
+ `"Foo"`. See [`tf.RegisterShape()`](framework.md#RegisterShape)
+ for details of shape
+ functions and how to register them. Alternatively, the shape may be set
+ explicitly using [`Tensor.set_shape()`](framework.md#Tensor.set_shape).
+
+ @@merge_with
+ @@concatenate
+
+ @@ndims
+ @@dims
+ @@as_list
+ @@is_compatible_with
+ @@is_fully_defined
+
+ @@with_rank
+ @@with_rank_at_least
+ @@with_rank_at_most
+
+ @@assert_has_rank
+ @@assert_same_rank
+ @@assert_is_compatible_with
+ @@assert_is_fully_defined
+ """
+
+ def __init__(self, dims):
+ """Creates a new TensorShape with the given dimensions.
+
+ Args:
+ dims: A list of Dimensions, or None if the shape is unspecified.
+ DEPRECATED: A single integer is treated as a singleton list.
+ """
+ # TODO(irving): Eliminate the single integer special case.
+ if dims is None:
+ self._dims = None
+ else:
+ try:
+ dims_iter = iter(dims)
+ except TypeError:
+ # Treat as a singleton dimension
+ self._dims = [as_dimension(dims)]
+ else:
+ # Got a list of dimensions
+ self._dims = map(as_dimension, dims_iter)
+
+ def __repr__(self):
+ return "TensorShape(%s)" % str(self._dims)
+
+ @property
+ def dims(self):
+ """Returns a list of Dimensions, or None if the shape is unspecified."""
+ return self._dims
+
+ @property
+ def ndims(self):
+ """Returns the rank of this shape, or None if it is unspecified."""
+ if self._dims is None:
+ return None
+ else:
+ return len(self._dims)
+
+ def __len__(self):
+ """Returns the rank of this shape, or raises ValueError if unspecified."""
+ if self._dims is None:
+ raise ValueError("Cannot take the length of Shape with unknown rank.")
+ return len(self._dims)
+
+ def __nonzero__(self):
+ """Returns True if this shape contains non-zero information."""
+ return self._dims is not None
+
+ def __getitem__(self, key):
+ """Returns the value of a dimension or a shape, depending on the key.
+
+ Args:
+ key: If `key` is an integer, returns the dimension at that index;
+ otherwise if `key` is a slice, returns a TensorShape whose
+ dimensions are those selected by the slice from `self`.
+
+ Returns:
+ A dimension if `key` is an integer, or a `TensorShape` if `key` is a
+ slice.
+
+ Raises:
+ ValueError: If `key` is a slice, and any of its elements are negative, or
+ if `self` is completely unknown and the step is set.
+ """
+ if self._dims is not None:
+ if isinstance(key, slice):
+ return TensorShape(self._dims[key])
+ else:
+ return self._dims[key]
+ else:
+ if isinstance(key, slice):
+ start = key.start if key.start is not None else 0
+ stop = key.stop
+
+ if key.step is not None:
+ # TODO(mrry): Handle these maybe.
+ raise ValueError("Steps are not yet handled")
+ if stop is None:
+ # NOTE(mrry): This implies that TensorShape(None) is compatible with
+ # TensorShape(None)[1:], which is obviously not true. It would be
+ # possible to track the number of dimensions symbolically,
+ # and perhaps we should do that.
+ return unknown_shape()
+ elif start < 0 or stop < 0:
+ # TODO(mrry): Handle this better, as it will be useful for handling
+ # suffixes of otherwise unknown shapes.
+ return unknown_shape()
+ else:
+ return unknown_shape(ndims=stop-start)
+ else:
+ return Dimension(None)
+
+ def num_elements(self):
+ """Returns the total number of elements, or none for incomplete shapes."""
+ if self.is_fully_defined():
+ size = 1
+ for dim in self._dims:
+ size *= dim.value
+ return size
+ else:
+ return None
+
+ def merge_with(self, other):
+ """Returns a `TensorShape` combining the information in `self` and `other`.
+
+ The dimensions in `self` and `other` are merged elementwise,
+ according to the rules defined for `Dimension.merge_with()`.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Returns:
+ A `TensorShape` containing the combined information of `self` and
+ `other`.
+
+ Raises:
+ ValueError: If `self` and `other` are not compatible.
+ """
+ other = as_shape(other)
+ if self._dims is None:
+ return other
+ else:
+ self.assert_same_rank(other)
+ new_dims = []
+ for i, dim in enumerate(self._dims):
+ new_dims.append(dim.merge_with(other[i]))
+ return TensorShape(new_dims)
+
+ def concatenate(self, other):
+ """Returns the concatenation of the dimension in `self` and `other`.
+
+ *N.B.* If either `self` or `other` is completely unknown,
+ concatenation will discard information about the other shape. In
+ future, we might support concatenation that preserves this
+ information for use with slicing.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Returns:
+ A `TensorShape` whose dimensions are the concatenation of the
+ dimensions in `self` and `other`.
+ """
+ # TODO(mrry): Handle the case where we concatenate a known shape with a
+ # completely unknown shape, so that we can use the partial information.
+ other = as_shape(other)
+ if self._dims is None or other.dims is None:
+ return unknown_shape()
+ else:
+ return TensorShape(self._dims + other.dims)
+
+ def assert_same_rank(self, other):
+ """Raises an exception if `self` and `other` do not have compatible ranks.
+
+ Args:
+ other: Another `TensorShape`.
+
+ Raises:
+ ValueError: If `self` and `other` do not represent shapes with the
+ same rank.
+ """
+ other = as_shape(other)
+ if self.ndims is not None and other.ndims is not None:
+ if self.ndims != other.ndims:
+ raise ValueError(
+ "Shapes %s and %s must have the same rank" % (self, other))
+
+ def assert_has_rank(self, rank):
+ """Raises an exception if `self` is not compatible with the given `rank`.
+
+ Args:
+ rank: An integer.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with the given `rank`.
+ """
+ if self.ndims not in (None, rank):
+ raise ValueError("Shape %s must have rank %d" % (self, rank))
+
+ def with_rank(self, rank):
+ """Returns a shape based on `self` with the given rank.
+
+ This method promotes a completely unknown shape to one with a
+ known rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with the given rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with the given `rank`.
+ """
+ return self.merge_with(unknown_shape(ndims=rank))
+
+ def with_rank_at_least(self, rank):
+ """Returns a shape based on `self` with at least the given rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with at least the given
+ rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with at least the given
+ `rank`.
+ """
+ if self.ndims is not None and self.ndims < rank:
+ raise ValueError("Shape %s must have rank at least %d" % (self, rank))
+ else:
+ return self
+
+ def with_rank_at_most(self, rank):
+ """Returns a shape based on `self` with at most the given rank.
+
+ Args:
+ rank: An integer.
+
+ Returns:
+ A shape that is at least as specific as `self` with at most the given
+ rank.
+
+ Raises:
+ ValueError: If `self` does not represent a shape with at most the given
+ `rank`.
+ """
+ if self.ndims is not None and self.ndims > rank:
+ raise ValueError("Shape %s must have rank at most %d" % (self, rank))
+ else:
+ return self
+
+ def is_compatible_with(self, other):
+ """Returns True iff `self` is compatible with `other`.
+
+ Two possibly-partially-defined shapes are compatible if there
+ exists a fully-defined shape that both shapes can represent. Thus,
+ compatibility allows the shape inference code to reason about
+ partially-defined shapes. For example:
+
+ * TensorShape(None) is compatible with all shapes.
+
+ * TensorShape([None, None]) is compatible with all two-dimensional
+ shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
+ not compatible with, for example, TensorShape([None]) or
+ TensorShape([None, None, None]).
+
+ * TensorShape([32, None]) is compatible with all two-dimensional shapes
+ with size 32 in the 0th dimension, and also TensorShape([None, None])
+ and TensorShape(None). It is not compatible with, for example,
+ TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
+
+ * TensorShape([32, 784]) is compatible with itself, and also
+ TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
+ None]) and TensorShape(None). It is not compatible with, for example,
+ TensorShape([32, 1, 784]) or TensorShape([None]).
+
+ The compatibility relation is reflexive and symmetric, but not
+ transitive. For example, TensorShape([32, 784]) is compatible with
+ TensorShape(None), and TensorShape(None) is compatible with
+ TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
+ TensorShape([4, 4]).
+
+ Args:
+ other: Another TensorShape.
+
+ Returns:
+ True iff `self` is compatible with `other`.
+
+ """
+ other = as_shape(other)
+ if self._dims is not None and other.dims is not None:
+ if self.ndims != other.ndims:
+ return False
+ for x_dim, y_dim in zip(self._dims, other.dims):
+ if not x_dim.is_compatible_with(y_dim):
+ return False
+ return True
+
+ def assert_is_compatible_with(self, other):
+ """Raises exception if `self` and `other` do not represent the same shape.
+
+ This method can be used to assert that there exists a shape that both
+ `self` and `other` represent.
+
+ Args:
+ other: Another TensorShape.
+
+ Raises:
+ ValueError: If `self` and `other` do not represent the same shape.
+ """
+ if not self.is_compatible_with(other):
+ raise ValueError("Shapes %s and %s are incompatible" % (self, other))
+
+ def is_fully_defined(self):
+ """Returns True iff `self` is fully defined in every dimension."""
+ return (self._dims is not None
+ and all(dim.value is not None for dim in self._dims))
+
+ def assert_is_fully_defined(self):
+ """Raises an exception if `self` is not fully defined in every dimension.
+
+ Raises:
+ ValueError: If `self` does not have a known value for every dimension.
+ """
+ if not self.is_fully_defined():
+ raise ValueError("Shape %s is not fully defined" % self)
+
+ def as_dimension_list(self):
+ """DEPRECATED: use as_list()."""
+ self.assert_is_fully_defined()
+ return self.as_list()
+
+ def as_list(self):
+ """Returns a list of integers or None for each dimension."""
+ return [dim.value for dim in self._dims]
+
+ def __eq__(self, other):
+ """Returns True if `self` is equivalent to `other`."""
+ other = as_shape(other)
+ return self._dims == other.dims
+
+ def __ne__(self, other):
+ """Returns True if `self` is known to be different from `other`."""
+ other = as_shape(other)
+ if self.ndims is None or other.ndims is None:
+ raise ValueError("The inequality of unknown TensorShapes is undefined.")
+ if self.ndims != other.ndims:
+ return True
+ return self._dims != other.dims
+
+
+def as_shape(shape):
+ """Converts the given object to a TensorShape."""
+ if isinstance(shape, TensorShape):
+ return shape
+ else:
+ return TensorShape(shape)
+
+
+def unknown_shape(ndims=None):
+ """Returns an unknown TensorShape, optionally with a known rank.
+
+ Args:
+ ndims: (Optional) If specified, the number of dimensions in the shape.
+
+ Returns:
+ An unknown TensorShape.
+ """
+ if ndims is None:
+ return TensorShape(None)
+ else:
+ return TensorShape([Dimension(None) for _ in range(ndims)])
+
+
+def scalar():
+ """Returns a shape representing a scalar."""
+ return TensorShape([])
+
+
+def vector(length):
+ """Returns a shape representing a vector.
+
+ Args:
+ length: The length of the vector, which may be None if unknown.
+
+ Returns:
+ A TensorShape representing a vector of the given length.
+ """
+ return TensorShape([length])
+
+
+def matrix(rows, cols):
+ """Returns a shape representing a matrix.
+
+ Args:
+ rows: The number of rows in the matrix, which may be None if unknown.
+ cols: The number of columns in the matrix, which may be None if unknown.
+
+ Returns:
+ A TensorShape representing a matrix of the given size.
+ """
+ return TensorShape([rows, cols])