diff options
Diffstat (limited to 'tensorflow/python/framework/tensor_shape.py')
-rw-r--r-- | tensorflow/python/framework/tensor_shape.py | 743 |
1 files changed, 743 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py new file mode 100644 index 0000000000..d4f27696d4 --- /dev/null +++ b/tensorflow/python/framework/tensor_shape.py @@ -0,0 +1,743 @@ +"""Helper classes for tensor shape inference.""" +import tensorflow.python.platform + + +class Dimension(object): + """Represents the value of one dimension in a TensorShape.""" + + def __init__(self, value): + """Creates a new Dimension with the given value.""" + if value is None: + self._value = None + else: + self._value = int(value) + + def __repr__(self): + return "Dimension(%s)" % repr(self._value) + + def __eq__(self, other): + """Returns true if `other` has the same known value as this Dimension.""" + other = as_dimension(other) + if self._value is None or other.value is None: + return None + return self._value == other.value + + def __ne__(self, other): + """Returns true if `other` has a different known value from `self`.""" + other = as_dimension(other) + if self._value is None or other.value is None: + return None + return self._value != other.value + + def __int__(self): + return self._value + + @property + def value(self): + """The value of this dimension, or None if it is unknown.""" + return self._value + + def is_compatible_with(self, other): + """Returns true if `other` is compatible with this Dimension. + + Two known Dimensions are compatible if they have the same value. + An unknown Dimension is compatible with all other Dimensions. + + Args: + other: Another Dimension. + + Returns: + True if this Dimension and `other` are compatible. + """ + other = as_dimension(other) + return (self._value is None + or other.value is None + or self._value == other.value) + + def assert_is_compatible_with(self, other): + """Raises an exception if `other` is not compatible with this Dimension. + + Args: + other: Another Dimension. + + Raises: + ValueError: If `self` and `other` are not compatible (see + is_compatible_with). + """ + if not self.is_compatible_with(other): + raise ValueError("Dimensions %s and %s are not compatible" + % (self, other)) + + def merge_with(self, other): + """Returns a Dimension that combines the information in `self` and `other`. + + Dimensions are combined as follows: + + Dimension(n) .merge_with(Dimension(n)) == Dimension(n) + Dimension(n) .merge_with(Dimension(None)) == Dimension(n) + Dimension(None).merge_with(Dimension(n)) == Dimension(n) + Dimension(None).merge_with(Dimension(None)) == Dimension(None) + Dimension(n) .merge_with(Dimension(m)) raises ValueError for n != m + + Args: + other: Another Dimension. + + Returns: + A Dimension containing the combined information of `self` and + `other`. + + Raises: + ValueError: If `self` and `other` are not compatible (see + is_compatible_with). + """ + other = as_dimension(other) + self.assert_is_compatible_with(other) + if self._value is None: + return Dimension(other.value) + else: + return Dimension(self._value) + + def __add__(self, other): + """Returns the sum of `self` and `other`. + + Dimensions are summed as follows: + + Dimension(m) + Dimension(n) == Dimension(m + n) + Dimension(m) + Dimension(None) == Dimension(None) + Dimension(None) + Dimension(n) == Dimension(None) + Dimension(None) + Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is the sum of `self` and `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value + other.value) + + def __sub__(self, other): + """Returns the subtraction of `other` from `self`. + + Dimensions are subtracted as follows: + + Dimension(m) - Dimension(n) == Dimension(m - n) + Dimension(m) - Dimension(None) == Dimension(None) + Dimension(None) - Dimension(n) == Dimension(None) + Dimension(None) - Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is the subtraction of sum of `other` from `self`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value - other.value) + + def __mul__(self, other): + """Returns the product of `self` and `other`. + + Dimensions are summed as follows: + + Dimension(m) * Dimension(n) == Dimension(m * n) + Dimension(m) * Dimension(None) == Dimension(None) + Dimension(None) * Dimension(n) == Dimension(None) + Dimension(None) * Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is the sum of `self` and `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value * other.value) + + def __div__(self, other): + """Returns the quotient of `self` and `other`. + + Dimensions are summed as follows: + + Dimension(m) / Dimension(n) == Dimension(m / n) + Dimension(m) / Dimension(None) == Dimension(None) + Dimension(None) / Dimension(n) == Dimension(None) + Dimension(None) / Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is the sum of `self` and `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value / other.value) + + def __mod__(self, other): + """Returns `self` modulo `other. + + Dimension moduli are computed as follows: + + Dimension(m) % Dimension(n) == Dimension(m % n) + Dimension(m) % Dimension(None) == Dimension(None) + Dimension(None) % Dimension(n) == Dimension(None) + Dimension(None) % Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is `self` modulo `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value % other.value) + + def __lt__(self, other): + """Returns True if `self` is known to be less than `other`. + + Dimensions are compared as follows: + + Dimension(m) < Dimension(n) == m < n + Dimension(m) < Dimension(None) == None + Dimension(None) < Dimension(n) == None + Dimension(None) < Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value < other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value < other.value + + def __le__(self, other): + """Returns True if `self` is known to be less than or equal to `other`. + + Dimensions are compared as follows: + + Dimension(m) <= Dimension(n) == m <= n + Dimension(m) <= Dimension(None) == None + Dimension(None) <= Dimension(n) == None + Dimension(None) <= Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value <= other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value <= other.value + + def __gt__(self, other): + """Returns True if `self` is known to be greater than `other`. + + Dimensions are compared as follows: + + Dimension(m) > Dimension(n) == m > n + Dimension(m) > Dimension(None) == None + Dimension(None) > Dimension(n) == None + Dimension(None) > Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value > other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value > other.value + + def __ge__(self, other): + """Returns True if `self` is known to be greater than or equal to `other`. + + Dimensions are compared as follows: + + Dimension(m) >= Dimension(n) == m >= n + Dimension(m) >= Dimension(None) == None + Dimension(None) >= Dimension(n) == None + Dimension(None) >= Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value >= other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value >= other.value + + +def as_dimension(value): + """Converts the given value to a Dimension. + + A Dimenson input will be returned unmodified. + An input of `None` will be converted to an unknown Dimension. + An integer input will be converted to a Dimension with that value. + + Args: + value: The value to be converted. + + Returns: + A Dimension corresponding to the given value. + """ + if isinstance(value, Dimension): + return value + else: + return Dimension(value) + + +class TensorShape(object): + """Represents the shape of a `Tensor`. + + A `TensorShape` represents a possibly-partial shape specification for a + `Tensor`. It may be one of the following: + + * *Fully-known shape:* has a known number of dimensions and a known size + for each dimension. + * *Partially-known shape:* has a known number of dimensions, and an unknown + size for one or more dimension. + * *Unknown shape:* has an unknown number of dimensions, and an unknown + size in all dimensions. + + If a tensor is produced by an operation of type `"Foo"`, its shape + may be inferred if there is a registered shape function for + `"Foo"`. See [`tf.RegisterShape()`](framework.md#RegisterShape) + for details of shape + functions and how to register them. Alternatively, the shape may be set + explicitly using [`Tensor.set_shape()`](framework.md#Tensor.set_shape). + + @@merge_with + @@concatenate + + @@ndims + @@dims + @@as_list + @@is_compatible_with + @@is_fully_defined + + @@with_rank + @@with_rank_at_least + @@with_rank_at_most + + @@assert_has_rank + @@assert_same_rank + @@assert_is_compatible_with + @@assert_is_fully_defined + """ + + def __init__(self, dims): + """Creates a new TensorShape with the given dimensions. + + Args: + dims: A list of Dimensions, or None if the shape is unspecified. + DEPRECATED: A single integer is treated as a singleton list. + """ + # TODO(irving): Eliminate the single integer special case. + if dims is None: + self._dims = None + else: + try: + dims_iter = iter(dims) + except TypeError: + # Treat as a singleton dimension + self._dims = [as_dimension(dims)] + else: + # Got a list of dimensions + self._dims = map(as_dimension, dims_iter) + + def __repr__(self): + return "TensorShape(%s)" % str(self._dims) + + @property + def dims(self): + """Returns a list of Dimensions, or None if the shape is unspecified.""" + return self._dims + + @property + def ndims(self): + """Returns the rank of this shape, or None if it is unspecified.""" + if self._dims is None: + return None + else: + return len(self._dims) + + def __len__(self): + """Returns the rank of this shape, or raises ValueError if unspecified.""" + if self._dims is None: + raise ValueError("Cannot take the length of Shape with unknown rank.") + return len(self._dims) + + def __nonzero__(self): + """Returns True if this shape contains non-zero information.""" + return self._dims is not None + + def __getitem__(self, key): + """Returns the value of a dimension or a shape, depending on the key. + + Args: + key: If `key` is an integer, returns the dimension at that index; + otherwise if `key` is a slice, returns a TensorShape whose + dimensions are those selected by the slice from `self`. + + Returns: + A dimension if `key` is an integer, or a `TensorShape` if `key` is a + slice. + + Raises: + ValueError: If `key` is a slice, and any of its elements are negative, or + if `self` is completely unknown and the step is set. + """ + if self._dims is not None: + if isinstance(key, slice): + return TensorShape(self._dims[key]) + else: + return self._dims[key] + else: + if isinstance(key, slice): + start = key.start if key.start is not None else 0 + stop = key.stop + + if key.step is not None: + # TODO(mrry): Handle these maybe. + raise ValueError("Steps are not yet handled") + if stop is None: + # NOTE(mrry): This implies that TensorShape(None) is compatible with + # TensorShape(None)[1:], which is obviously not true. It would be + # possible to track the number of dimensions symbolically, + # and perhaps we should do that. + return unknown_shape() + elif start < 0 or stop < 0: + # TODO(mrry): Handle this better, as it will be useful for handling + # suffixes of otherwise unknown shapes. + return unknown_shape() + else: + return unknown_shape(ndims=stop-start) + else: + return Dimension(None) + + def num_elements(self): + """Returns the total number of elements, or none for incomplete shapes.""" + if self.is_fully_defined(): + size = 1 + for dim in self._dims: + size *= dim.value + return size + else: + return None + + def merge_with(self, other): + """Returns a `TensorShape` combining the information in `self` and `other`. + + The dimensions in `self` and `other` are merged elementwise, + according to the rules defined for `Dimension.merge_with()`. + + Args: + other: Another `TensorShape`. + + Returns: + A `TensorShape` containing the combined information of `self` and + `other`. + + Raises: + ValueError: If `self` and `other` are not compatible. + """ + other = as_shape(other) + if self._dims is None: + return other + else: + self.assert_same_rank(other) + new_dims = [] + for i, dim in enumerate(self._dims): + new_dims.append(dim.merge_with(other[i])) + return TensorShape(new_dims) + + def concatenate(self, other): + """Returns the concatenation of the dimension in `self` and `other`. + + *N.B.* If either `self` or `other` is completely unknown, + concatenation will discard information about the other shape. In + future, we might support concatenation that preserves this + information for use with slicing. + + Args: + other: Another `TensorShape`. + + Returns: + A `TensorShape` whose dimensions are the concatenation of the + dimensions in `self` and `other`. + """ + # TODO(mrry): Handle the case where we concatenate a known shape with a + # completely unknown shape, so that we can use the partial information. + other = as_shape(other) + if self._dims is None or other.dims is None: + return unknown_shape() + else: + return TensorShape(self._dims + other.dims) + + def assert_same_rank(self, other): + """Raises an exception if `self` and `other` do not have compatible ranks. + + Args: + other: Another `TensorShape`. + + Raises: + ValueError: If `self` and `other` do not represent shapes with the + same rank. + """ + other = as_shape(other) + if self.ndims is not None and other.ndims is not None: + if self.ndims != other.ndims: + raise ValueError( + "Shapes %s and %s must have the same rank" % (self, other)) + + def assert_has_rank(self, rank): + """Raises an exception if `self` is not compatible with the given `rank`. + + Args: + rank: An integer. + + Raises: + ValueError: If `self` does not represent a shape with the given `rank`. + """ + if self.ndims not in (None, rank): + raise ValueError("Shape %s must have rank %d" % (self, rank)) + + def with_rank(self, rank): + """Returns a shape based on `self` with the given rank. + + This method promotes a completely unknown shape to one with a + known rank. + + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with the given rank. + + Raises: + ValueError: If `self` does not represent a shape with the given `rank`. + """ + return self.merge_with(unknown_shape(ndims=rank)) + + def with_rank_at_least(self, rank): + """Returns a shape based on `self` with at least the given rank. + + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with at least the given + rank. + + Raises: + ValueError: If `self` does not represent a shape with at least the given + `rank`. + """ + if self.ndims is not None and self.ndims < rank: + raise ValueError("Shape %s must have rank at least %d" % (self, rank)) + else: + return self + + def with_rank_at_most(self, rank): + """Returns a shape based on `self` with at most the given rank. + + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with at most the given + rank. + + Raises: + ValueError: If `self` does not represent a shape with at most the given + `rank`. + """ + if self.ndims is not None and self.ndims > rank: + raise ValueError("Shape %s must have rank at most %d" % (self, rank)) + else: + return self + + def is_compatible_with(self, other): + """Returns True iff `self` is compatible with `other`. + + Two possibly-partially-defined shapes are compatible if there + exists a fully-defined shape that both shapes can represent. Thus, + compatibility allows the shape inference code to reason about + partially-defined shapes. For example: + + * TensorShape(None) is compatible with all shapes. + + * TensorShape([None, None]) is compatible with all two-dimensional + shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is + not compatible with, for example, TensorShape([None]) or + TensorShape([None, None, None]). + + * TensorShape([32, None]) is compatible with all two-dimensional shapes + with size 32 in the 0th dimension, and also TensorShape([None, None]) + and TensorShape(None). It is not compatible with, for example, + TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). + + * TensorShape([32, 784]) is compatible with itself, and also + TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, + None]) and TensorShape(None). It is not compatible with, for example, + TensorShape([32, 1, 784]) or TensorShape([None]). + + The compatibility relation is reflexive and symmetric, but not + transitive. For example, TensorShape([32, 784]) is compatible with + TensorShape(None), and TensorShape(None) is compatible with + TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with + TensorShape([4, 4]). + + Args: + other: Another TensorShape. + + Returns: + True iff `self` is compatible with `other`. + + """ + other = as_shape(other) + if self._dims is not None and other.dims is not None: + if self.ndims != other.ndims: + return False + for x_dim, y_dim in zip(self._dims, other.dims): + if not x_dim.is_compatible_with(y_dim): + return False + return True + + def assert_is_compatible_with(self, other): + """Raises exception if `self` and `other` do not represent the same shape. + + This method can be used to assert that there exists a shape that both + `self` and `other` represent. + + Args: + other: Another TensorShape. + + Raises: + ValueError: If `self` and `other` do not represent the same shape. + """ + if not self.is_compatible_with(other): + raise ValueError("Shapes %s and %s are incompatible" % (self, other)) + + def is_fully_defined(self): + """Returns True iff `self` is fully defined in every dimension.""" + return (self._dims is not None + and all(dim.value is not None for dim in self._dims)) + + def assert_is_fully_defined(self): + """Raises an exception if `self` is not fully defined in every dimension. + + Raises: + ValueError: If `self` does not have a known value for every dimension. + """ + if not self.is_fully_defined(): + raise ValueError("Shape %s is not fully defined" % self) + + def as_dimension_list(self): + """DEPRECATED: use as_list().""" + self.assert_is_fully_defined() + return self.as_list() + + def as_list(self): + """Returns a list of integers or None for each dimension.""" + return [dim.value for dim in self._dims] + + def __eq__(self, other): + """Returns True if `self` is equivalent to `other`.""" + other = as_shape(other) + return self._dims == other.dims + + def __ne__(self, other): + """Returns True if `self` is known to be different from `other`.""" + other = as_shape(other) + if self.ndims is None or other.ndims is None: + raise ValueError("The inequality of unknown TensorShapes is undefined.") + if self.ndims != other.ndims: + return True + return self._dims != other.dims + + +def as_shape(shape): + """Converts the given object to a TensorShape.""" + if isinstance(shape, TensorShape): + return shape + else: + return TensorShape(shape) + + +def unknown_shape(ndims=None): + """Returns an unknown TensorShape, optionally with a known rank. + + Args: + ndims: (Optional) If specified, the number of dimensions in the shape. + + Returns: + An unknown TensorShape. + """ + if ndims is None: + return TensorShape(None) + else: + return TensorShape([Dimension(None) for _ in range(ndims)]) + + +def scalar(): + """Returns a shape representing a scalar.""" + return TensorShape([]) + + +def vector(length): + """Returns a shape representing a vector. + + Args: + length: The length of the vector, which may be None if unknown. + + Returns: + A TensorShape representing a vector of the given length. + """ + return TensorShape([length]) + + +def matrix(rows, cols): + """Returns a shape representing a matrix. + + Args: + rows: The number of rows in the matrix, which may be None if unknown. + cols: The number of columns in the matrix, which may be None if unknown. + + Returns: + A TensorShape representing a matrix of the given size. + """ + return TensorShape([rows, cols]) |