Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--  tensorflow/python/ops/array_ops.py  1207
1 file changed, 1207 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
new file mode 100644
index 0000000000..ed780db625
--- /dev/null
+++ b/tensorflow/python/ops/array_ops.py
@@ -0,0 +1,1207 @@
+"""## Casting
+
+TensorFlow provides several operations that you can use to cast tensor data
+types in your graph.
+
+@@string_to_number
+@@to_double
+@@to_float
+@@to_bfloat16
+@@to_int32
+@@to_int64
+@@cast
+
+## Shapes and Shaping
+
+TensorFlow provides several operations that you can use to determine the shape
+of a tensor and change the shape of a tensor.
+
+@@shape
+@@size
+@@rank
+@@reshape
+@@squeeze
+@@expand_dims
+
+## Slicing and Joining
+
+TensorFlow provides several operations to slice or extract parts of a tensor,
+or join multiple tensors together.
+
+@@slice
+@@split
+@@tile
+@@pad
+@@concat
+@@pack
+@@unpack
+@@reverse_sequence
+@@reverse
+@@transpose
+@@gather
+@@dynamic_partition
+@@dynamic_stitch
+"""
+import sys
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+# pylint: disable=wildcard-import
+# 'Constant' gets imported in the module 'array_ops'.
+from tensorflow.python.ops.constant_op import constant
+from tensorflow.python.ops.gen_array_ops import *
+
+
+# This module defines a 'slice' function (for the "slice" op) that shadows
+# the Python builtin, so we keep a reference to the builtin for later use.
+_baseslice = slice
+
+
+# Aliases for some automatically-generated names.
+listdiff = gen_array_ops.list_diff
+
+
+# pylint: disable=undefined-variable,protected-access
+def _SliceHelper(tensor, slice_spec):
+ """Overload for Tensor.__getitem__.
+
+  Currently the slice indices must be Python constants (not tensors). If the
+  "stop" of a slice is omitted, all remaining elements in that dimension are
+  included in the slice.
+
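+  For example (an illustrative sketch; `t` is a hypothetical rank-2 `Tensor`):
+
+      # Equivalent to tf.slice(t, [1, 2], [1, 2]), followed by squeezing
+      # dimension 0 (because of the scalar index).
+      row = t[1, 2:4]
+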
+ TODO(mrry): Support slices where the sizes are not specified.
+ TODO(mrry): Support negative indices in slices with numpy/Python semantics.
+
+ Args:
+ tensor: An ops.Tensor object.
+ slice_spec: The arguments to Tensor.__getitem__.
+
+ Returns:
+ The appropriate slice of "tensor", based on "slice_spec".
+
+ Raises:
+    ValueError: If a slice range has negative size.
+ TypeError: If the slice indices aren't int, slice, or Ellipsis.
+ """
+ if not isinstance(slice_spec, (list, tuple)):
+ slice_spec = [slice_spec]
+ indices = []
+ sizes = []
+ squeeze_dims = []
+ for dim, s in enumerate(slice_spec):
+ if isinstance(s, int):
+ if s < 0:
+ raise NotImplementedError("Negative indices are currently unsupported")
+ indices.append(s)
+ sizes.append(1)
+ squeeze_dims.append(dim)
+ elif isinstance(s, _baseslice):
+ if s.step not in (None, 1):
+ raise NotImplementedError(
+ "Steps other than 1 are not currently supported")
+ start = s.start if s.start is not None else 0
+ if start < 0:
+ raise NotImplementedError(
+ "Negative start indices are not currently supported")
+ indices.append(start)
+ if s.stop is not None and s.stop < 0:
+ raise NotImplementedError(
+ "Negative stop indices are not currently supported")
+ # NOTE(mrry): If the stop is not specified, Python substitutes
+ # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice currently
+ # supports signed DT_INT32 arguments, we use -1 to specify that all
+ # elements should be captured.
+ if s.stop is None or s.stop == sys.maxsize:
+ sizes.append(-1)
+ else:
+ if start > s.stop:
+ raise ValueError("Stop must be at least start")
+ sizes.append(s.stop - start)
+ elif s is Ellipsis:
+ raise NotImplementedError("Ellipsis is not currently supported")
+ else:
+ raise TypeError("Bad slice index %s of type %s" % (s, type(s)))
+ sliced = slice(tensor, indices, sizes)
+ if squeeze_dims:
+ return squeeze(sliced, squeeze_dims=squeeze_dims)
+ else:
+ return sliced
+
+
+def slice(input_, begin, size, name=None):
+ """Extracts a slice from a tensor.
+
+ This operation extracts a slice of size `size` from a tensor `input` starting
+ at the location specified by `begin`. The slice `size` is represented as a
+ tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
+ of `input` that you want to slice. The starting location (`begin`) for the
+ slice is represented as an offset in each dimension of `input`. In other
+ words, `begin[i]` is the offset into the 'i'th dimension of `input` that you
+ want to slice from.
+
+ `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
+ all remaining elements in dimension i are included in the
+ slice. In other words, this is equivalent to setting:
+
+ `size[i] = input.dim_size(i) - begin[i]`
+
+ This operation requires that:
+
+  `0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)`
+
+ For example:
+
+ ```
+ # 'input' is [[[1, 1, 1], [2, 2, 2]],
+ # [[3, 3, 3], [4, 4, 4]],
+ # [[5, 5, 5], [6, 6, 6]]]
+ tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
+ tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
+ [4, 4, 4]]]
+ tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
+ [[5, 5, 5]]]
+ ```
+
+ Args:
+ input_: A `Tensor`.
+ begin: An `int32` or `int64` `Tensor`.
+ size: An `int32` or `int64` `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A `Tensor` of the same type as `input`.
+ """
+ return gen_array_ops._slice(input_, begin, size, name=name)
+
+
+ops.Tensor._override_operator("__getitem__", _SliceHelper)
+
+
+def pack(values, name="pack"):
+ """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
+
+ Packs tensors in `values` into a tensor with rank one higher than each tensor
+ in `values` and shape `[len(values)] + values[0].shape`. The output satisfies
+ `output[i, ...] = values[i][...]`.
+
+ This is the opposite of unpack. The numpy equivalent is
+
+ tf.pack([x, y, z]) = np.asarray([x, y, z])
+
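+  For example (an illustrative sketch; assume `x`, `y`, `z` are `Tensor`s of
+  shape `[2]`):
+
+  ```python
+  # Pack three rank-1 tensors into one rank-2 tensor of shape [3, 2].
+  tf.pack([x, y, z])  # ==> shape [3, 2]
+  ```
+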
+ Args:
+ values: A list of `Tensor` objects with the same shape and type.
+ name: A name for this operation (optional).
+
+ Returns:
+ output: A packed `Tensor` with the same type as `values`.
+ """
+ return gen_array_ops._pack(values, name=name)
+
+
+def unpack(value, num=None, name="unpack"):
+ """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+
+ Unpacks `num` tensors from `value` along the first dimension.
+ If `num` is not specified (the default), it is inferred from `value`'s shape.
+ If `value.shape[0]` is not known, `ValueError` is raised.
+
+ The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
+ `output` has shape `value.shape[1:]`.
+
+ This is the opposite of pack. The numpy equivalent is
+
+ tf.unpack(x, n) = list(x)
+
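+  For example (an illustrative sketch; assume `x` is a `Tensor` of shape
+  `[3, 2]`):
+
+  ```python
+  # Unpack along the first dimension into three tensors of shape [2].
+  a, b, c = tf.unpack(x)
+  ```
+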
+ Args:
+ value: A rank `R > 0` `Tensor` to be unpacked.
+ num: An `int`. The first dimension of value. Automatically inferred if
+ `None` (the default).
+ name: A name for the operation (optional).
+
+ Returns:
+ The list of `Tensor` objects unpacked from `value`.
+
+ Raises:
+ ValueError: If `num` is unspecified and cannot be inferred.
+ """
+ if num is None:
+ value = ops.convert_to_tensor(value)
+ shape = value.get_shape()
+ num = shape[0].value
+ if num is None:
+ raise ValueError("Cannot infer num from shape %s" % shape)
+ return gen_array_ops._unpack(value, num=num, name=name)
+
+
+def concat(concat_dim, values, name="concat"):
+ """Concatenates tensors along one dimension.
+
+ Concatenates the list of tensors `values` along dimension `concat_dim`. If
+ `values[i].shape = [D0, D1, ... Dconcat_dim(i), ...Dn]`, the concatenated
+ result has shape
+
+ [D0, D1, ... Rconcat_dim, ...Dn]
+
+ where
+
+ Rconcat_dim = sum(Dconcat_dim(i))
+
+ That is, the data from the input tensors is joined along the `concat_dim`
+ dimension.
+
+ The number of dimensions of the input tensors must match, and all dimensions
+ except `concat_dim` must be equal.
+
+ For example:
+
+ ```python
+ t1 = [[1, 2, 3], [4, 5, 6]]
+ t2 = [[7, 8, 9], [10, 11, 12]]
+ tf.concat(0, [t1, t2]) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
+ tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
+
+ # tensor t3 with shape [2, 3]
+ # tensor t4 with shape [2, 3]
+ tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
+ tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]
+ ```
+
+ Args:
+ concat_dim: 0-D `int32` `Tensor`. Dimension along which to concatenate.
+ values: A list of `Tensor` objects or a single `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` resulting from concatenation of the input tensors.
+ """
+  if not isinstance(values, (list, tuple)):
+ values = [values]
+ # TODO(mrry): Change to return values?
+ if len(values) == 1: # Degenerate case of one tensor.
+ return identity(values[0], name=name)
+ return gen_array_ops._concat(concat_dim=concat_dim,
+ values=values,
+ name=name)
+
+
+@ops.RegisterShape("Pack")
+def _PackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ for inp in op.inputs[1:]:
+ input_shape = input_shape.merge_with(inp.get_shape())
+ return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+
+
+@ops.RegisterShape("Unpack")
+def _UnpackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ return [input_shape[1:]] * op.get_attr("num")
+
+
+@ops.RegisterShape("Concat")
+def _ConcatShape(op):
+ concat_dim = tensor_util.ConstantValue(op.inputs[0])
+ if concat_dim is None:
+ # Return an unknown shape with the same rank as the inputs, or an
+ # unknown rank if no input's rank is known.
+ rank = None
+ for value in op.inputs[1:]:
+ if rank is not None:
+ value.get_shape().assert_has_rank(rank)
+ else:
+ rank = value.get_shape().ndims
+ return [tensor_shape.unknown_shape(ndims=max(rank, 1))]
+
+ else:
+ # Merge all the non-concat dims, and sum the concat dim to make an
+ # output shape.
+ concat_dim = int(concat_dim)
+ output_shape = op.inputs[1].get_shape()
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ if output_shape.ndims == 0:
+ output_shape = tensor_shape.TensorShape([1])
+ for value in op.inputs[2:]:
+ value_shape = value.get_shape()
+ if value_shape.ndims is not None and concat_dim >= value_shape.ndims:
+ if value_shape.ndims == 0 and concat_dim == 0:
+ # Let concat handle scalars
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ value_shape = tensor_shape.TensorShape([1])
+ else:
+ raise ValueError("concat_dim is out of range (values rank = %d)" %
+ value_shape.ndims)
+ before = output_shape[:concat_dim].merge_with(value_shape[:concat_dim])
+ at = output_shape[concat_dim] + value_shape[concat_dim]
+ after = output_shape[
+ concat_dim + 1:].merge_with(value_shape[concat_dim + 1:])
+ output_shape = before.concatenate(at).concatenate(after)
+ return [output_shape]
+
+
+def sparse_mask(a, mask_indices, name=None):
+ """Masks elements of `IndexedSlices`.
+
+ Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
+ contains a subset of the slices of `a`. Only the slices at indices specified
+ in `mask_indices` are returned.
+
+ This is useful when you need to extract a subset of slices in an
+ `IndexedSlices` object.
+
+ For example:
+
+ ```python
+ # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
+ # with shape [1000, 10]
+ a.indices => [12, 26, 37, 45]
+ tf.shape(a.values) => [4, 10]
+
+  # `b` will be the subset of `a` slices at its second and third indices, so
+  # we want to mask off its first and last indices (which are at absolute
+  # indices 12, 45).
+  b = tf.sparse_mask(a, [12, 45])
+
+  b.indices => [26, 37]
+  tf.shape(b.values) => [2, 10]
+  ```
+
+ Args:
+    a: An `IndexedSlices` instance.
+    mask_indices: Indices of elements to mask.
+    name: A name for the operation (optional).
+
+ Returns:
+ The masked `IndexedSlices` instance.
+ """
+ with ops.op_scope([a, mask_indices], name, "sparse_mask") as name:
+ indices = a.indices
+ out_indices, to_gather = listdiff(indices, mask_indices)
+ out_values = gather(a.values, to_gather, name=name)
+ return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
+
+
+def split(split_dim, num_split, value, name="split"):
+ """Splits a tensor into `num_split` tensors along one dimension.
+
+ Splits `value` along dimension `split_dim` into `num_split` smaller tensors.
+ Requires that `num_split` evenly divide `value.shape[split_dim]`.
+
+ For example:
+
+ ```python
+ # 'value' is a tensor with shape [5, 30]
+ # Split 'value' into 3 tensors along dimension 1
+ split0, split1, split2 = tf.split(1, 3, value)
+ tf.shape(split0) ==> [5, 10]
+ ```
+
+ Args:
+ split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
+ Must be in the range `[0, rank(value))`.
+ num_split: A 0-D `int32` `Tensor`. The number of ways to split.
+ value: The `Tensor` to split.
+ name: A name for the operation (optional).
+
+ Returns:
+ `num_split` `Tensor` objects resulting from splitting `value`.
+ """
+ return gen_array_ops._split(split_dim=split_dim,
+ num_split=num_split,
+ value=value,
+ name=name)
+
+
+@ops.RegisterShape("Reverse")
+def _ReverseShape(op):
+ return [op.inputs[0].get_shape().with_rank_at_most(8)]
+
+
+def transpose(a, perm=None, name="transpose"):
+ """Transposes `a`. Permutes the dimensions according to `perm`.
+
+ The returned tensor's dimension i will correspond to the input dimension
+ `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
+ the rank of the input tensor. Hence by default, this operation performs a
+ regular matrix transpose on 2-D input Tensors.
+
+ For example:
+
+ ```python
+ # 'x' is [[1 2 3]
+ # [4 5 6]]
+ tf.transpose(x) ==> [[1 4]
+ [2 5]
+ [3 6]]
+
+  # Equivalently
+  tf.transpose(x, perm=[1, 0]) ==> [[1 4]
+                                    [2 5]
+                                    [3 6]]
+
+ # 'perm' is more useful for n-dimensional tensors, for n > 2
+ # 'x' is [[[1 2 3]
+ # [4 5 6]]
+ # [[7 8 9]
+ # [10 11 12]]]
+ # Take the transpose of the matrices in dimension-0
+  tf.transpose(x, perm=[0, 2, 1]) ==> [[[1 4]
+ [2 5]
+ [3 6]]
+
+ [[7 10]
+ [8 11]
+ [9 12]]]
+ ```
+
+ Args:
+ a: A `Tensor`.
+ perm: A permutation of the dimensions of `a`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A transposed `Tensor`.
+ """
+ with ops.op_scope([a], name, "transpose") as name:
+ if perm is None:
+ dims = gen_math_ops._range(0, gen_array_ops.rank(a), 1)
+ perm = gen_array_ops.reverse(dims, [True])
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ # NOTE(mrry): Setting the shape explicitly because
+ # reverse is not handled by the shape function.
+ input_shape = ret.op.inputs[0].get_shape().dims
+ if input_shape is not None:
+ ret.set_shape(input_shape[::-1])
+ else:
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ return ret
+
+
+def zeros(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and
+ all elements set to zero.
+
+ For example:
+
+ ```python
+  tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([shape], name, "zeros") as name:
+ if isinstance(shape, list):
+ output = constant(0, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(0, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def zeros_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the
+ same type and shape as `tensor` with all elements set to zero. Optionally,
+ you can use `dtype` to specify a new type for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([tensor], name, "zeros_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ zeros_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return zeros(zeros_shape, dtype=dtype, name=name)
+
+
+def ones_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the same
+ type and shape as `tensor` with all elements set to 1. Optionally, you can
+ specify a new type (`dtype`) for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([tensor], name, "ones_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ ones_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return ones(ones_shape, dtype=dtype, name=name)
+
+
+def zeros_initializer(shape, dtype=types.float32):
+ """An adaptor for zeros() to match the Initializer spec."""
+ return zeros(shape, dtype)
+
+
+def ones(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and all
+ elements set to 1.
+
+ For example:
+
+ ```python
+  tf.ones([2, 3], tf.int32) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([shape], name, "ones") as name:
+ if isinstance(shape, list):
+ output = constant(1, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(1, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def placeholder(dtype, shape=None, name=None):
+ """Inserts a placeholder for a tensor that will be always fed.
+
+ **Important**: This tensor will produce an error if evaluated. Its value must
+ be fed using the `feed_dict` optional argument to `Session.run()`,
+ `Tensor.eval()`, or `Operation.run()`.
+
+ For example:
+
+ ```python
+  x = tf.placeholder(tf.float32, shape=(1024, 1024))
+ y = tf.matmul(x, x)
+
+ with tf.Session() as sess:
+ print sess.run(y) # ERROR: will fail because x was not fed.
+
+ rand_array = np.random.rand(1024, 1024)
+ print sess.run(y, feed_dict={x: rand_array}) # Will succeed.
+ ```
+
+ Args:
+ dtype: The type of elements in the tensor to be fed.
+ shape: The shape of the tensor to be fed (optional). If the shape is not
+ specified, you can feed a tensor of any shape.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` that may be used as a handle for feeding a value, but not
+ evaluated directly.
+ """
+ shape = tensor_shape.as_shape(shape)
+ if shape.is_fully_defined():
+ dim_list = shape.as_list()
+ else:
+ dim_list = []
+ ret = gen_array_ops._placeholder(
+ dtype=dtype,
+ shape=dim_list,
+ name=name)
+ ret.set_shape(shape)
+ return ret
+
+
+@ops.RegisterShape("Placeholder")
+def _PlaceholderShape(op):
+ given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
+ if given_shape:
+ return [tensor_shape.TensorShape(given_shape)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("CheckNumerics")
+@ops.RegisterShape("Identity")
+@ops.RegisterShape("RefIdentity")
+@ops.RegisterShape("StopGradient")
+def _UnchangedShape(op):
+ return [op.inputs[0].get_shape()]
+
+
+@ops.RegisterShape("Rank")
+@ops.RegisterShape("Size")
+def _ScalarShape(unused_op):
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("Slice")
+def _SliceShape(op):
+ """Shape function for array_ops.slice."""
+ input_shape = op.inputs[0].get_shape()
+ begin_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ sizes_shape = op.inputs[2].get_shape().with_rank_at_most(1)
+ rank_vector_shape = begin_shape.merge_with(sizes_shape)
+ ndims = rank_vector_shape.num_elements()
+ if ndims is not None:
+ input_shape.assert_has_rank(ndims)
+ begin_value = tensor_util.ConstantValue(op.inputs[1])
+ sizes_value = tensor_util.ConstantValue(op.inputs[2])
+ if sizes_value is not None:
+ returned_dims = []
+ for i, slice_size in enumerate(sizes_value.ravel()):
+ if slice_size != -1:
+ returned_dims.append(slice_size)
+ elif begin_value is not None:
+ returned_dims.append(input_shape[i] - begin_value[i])
+ else:
+ returned_dims.append(None)
+ return [tensor_shape.TensorShape(returned_dims)]
+ else:
+ if input_shape.ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ elif ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=ndims)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("Gather")
+def _GatherShape(op):
+ """Shape function for array_ops.gather."""
+ params_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ return [indices_shape.concatenate(params_shape[1:])]
+
+
+@ops.RegisterShape("Unique")
+def _UniqueShape(op):
+ """Shape function for array_ops.Unique."""
+ # The output is a vector with data-dependent length.
+ input_shape = op.inputs[0].get_shape()
+ input_shape.assert_has_rank(1)
+ return [tensor_shape.vector(None), input_shape]
+
+
+@ops.RegisterShape("Diag")
+def _DiagShape(op):
+ """Shape function for array_ops.diag.
+
+ This op has one input (of rank k <= 3), and one output (of rank 2k),
+ where the shape of the output is the concatenation of the input
+ shape with itself.
+
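+  For example (an illustrative sketch; `t` is a hypothetical input):
+
+      # 't' is a tensor of shape [2, 3].
+      tf.diag(t)  # ==> a tensor of shape [2, 3, 2, 3]
+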
+ Args:
+ op: A Diag Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank_at_most(3)
+ return [input_shape.concatenate(input_shape)]
+
+
+@ops.RegisterShape("ExpandDims")
+def _ExpandDimsShape(op):
+ """Determine shape for expand op's output tensor.
+
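+  For example (an illustrative sketch; `t` is a hypothetical input):
+
+      # 't' is a tensor of shape [2, 3]; dim=-1 wraps to 2.
+      tf.expand_dims(t, -1)  # ==> a tensor of shape [2, 3, 1]
+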
+  Args:
+    op: Operation for which to determine shape.
+      op.inputs[0] is the input tensor.
+      op.inputs[1] is the dimension in which to expand.
+
+  Returns:
+    Shape of op's output tensor.
+
+  Raises:
+    ValueError: If dim is outside of [-rank - 1, rank], where rank is the
+      number of dimensions in the input tensor.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+ dim = tensor_util.ConstantValue(op.inputs[1])
+ input_ndims = input_shape.ndims
+ if dim < -input_ndims - 1 or dim > input_ndims:
+ raise ValueError(
+ "dim %d not in [%d, %d]." % (dim, -input_ndims, input_ndims))
+ if dim < 0:
+ dim += (input_ndims + 1)
+ result_shape = list(input_shape.dims)
+ result_shape.insert(dim, 1)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Squeeze")
+def _SqueezeShape(op):
+ """Determine shape for squeeze op's output tensor.
+
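+  For example (an illustrative sketch; `t` is a hypothetical input):
+
+      # 't' is a tensor of shape [1, 2, 1, 3].
+      tf.squeeze(t)                    # ==> shape [2, 3]
+      tf.squeeze(t, squeeze_dims=[2])  # ==> shape [1, 2, 3]
+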
+  Args:
+    op: Operation for which to determine shape.
+
+  Returns:
+    Shape of op's output tensor.
+
+  Raises:
+    ValueError: If squeeze_dims includes a dimension outside of [-rank, rank),
+      where rank is the number of dimensions in the input tensor, or if
+      squeeze_dims includes a dimension for which the input shape has a value
+      not equal to 1.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+
+ squeeze_dims = op.get_attr("squeeze_dims") or []
+ wrapped_squeeze_dims = []
+ input_ndims = input_shape.ndims
+ for i, squeeze_dim in enumerate(squeeze_dims):
+ if squeeze_dim < -input_ndims or squeeze_dim >= input_ndims:
+ raise ValueError(
+ "squeeze_dims[%d]=%d not in [%d, %d)." % (
+ i, squeeze_dim, -input_ndims, input_ndims))
+ if squeeze_dim < 0:
+ squeeze_dim += input_ndims
+ wrapped_squeeze_dims.append(squeeze_dim)
+
+ result_shape = []
+ for i, dim in enumerate([d.value for d in input_shape.dims]):
+ is_explicit_match = i in wrapped_squeeze_dims
+ if is_explicit_match or not wrapped_squeeze_dims:
+ if dim is None:
+ return [tensor_shape.unknown_shape()]
+ if dim != 1:
+ if is_explicit_match:
+ raise ValueError(
+ "Can not squeeze dim[%d], expected a dimension of 1, got %d." % (
+ i, dim))
+ result_shape.append(dim)
+ else:
+ result_shape.append(dim)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Reshape")
+def _ReshapeShape(op):
+ """Shape function for Reshape op."""
+ input_shape = op.inputs[0].get_shape()
+ new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ new_shape = tensor_util.ConstantValue(op.inputs[1])
+ if new_shape is None:
+ # Attempt to infer the rank of the output from the length of
+ # new_shape.
+ return [tensor_shape.unknown_shape(ndims=new_shape_shape.num_elements())]
+ new_shape = np.reshape(new_shape, -1).tolist()
+ if -1 not in new_shape:
+ # The new shape is fully defined.
+ return [tensor_shape.TensorShape(new_shape)]
+ elif input_shape.is_fully_defined():
+ # We know the input shape, so we can calculate the missing
+ # dimension in the new_shape.
+ num_elements = 1
+ for dim in input_shape.dims:
+ num_elements *= dim.value
+ known_elements = 1
+ unknown_index = None
+ for i, dim in enumerate(new_shape):
+ if dim == -1:
+ unknown_index = i
+ else:
+ known_elements *= dim
+ if known_elements == 0:
+ raise ValueError("cannot infer the missing input size for "
+ "an empty tensor unless all specified "
+ "input sizes are non-zero")
+ if num_elements % known_elements != 0:
+ raise ValueError("input has %s elements, which isn't divisible by %d" %
+ (num_elements, known_elements))
+    new_shape[unknown_index] = num_elements // known_elements
+ return [tensor_shape.TensorShape(new_shape)]
+ else:
+ # We don't know the input shape, but we know n-1 of the dimensions
+ # in the new shape.
+ new_shape[new_shape.index(-1)] = None
+ return [tensor_shape.TensorShape(new_shape)]
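+
+# For example (an illustrative sketch of the -1 inference above; 't' is a
+# hypothetical tensor with 12 elements):
+#
+#   tf.reshape(t, [4, -1])  # -1 is inferred as 12 // 4 = 3 ==> shape [4, 3]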
+
+
+@ops.RegisterShape("BroadcastGradientArgs")
+def _BroadcastGradientArgsShape(op):
+ """Shape function for the BroadcastGradientArgs op."""
+ # TODO(mrry): Implement ConstantValue for BroadcastGradientArgs?
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ return [tensor_shape.vector(None), tensor_shape.vector(None)]
+
+
+@ops.RegisterShape("Fill")
+def _FillShape(op):
+ """Shape function for the Fill op.
+
+ This op takes a vector of dimensions and a scalar, and produces a
+ tensor with the given dimensions.
+
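+  For example (an illustrative sketch):
+
+      tf.fill([2, 3], 9)  # ==> a [2, 3] tensor filled with 9s
+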
+ Args:
+ op: A Fill Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ dimensions_shape = op.inputs[0].get_shape().with_rank_at_most(1)
+ op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+ fill_dims = tensor_util.ConstantValue(op.inputs[0])
+ if fill_dims is None:
+ # Attempt to infer the rank of the output from the length of
+ # dimensions.
+ return [tensor_shape.unknown_shape(ndims=dimensions_shape.num_elements())]
+ else:
+ return [tensor_shape.TensorShape(fill_dims.tolist())]
+
+
+@ops.RegisterShape("InvertPermutation")
+def _InvertPermutationShape(op):
+ """Shape function for the InvertPermutation op."""
+ return [op.inputs[0].get_shape().with_rank(1)]
+
+
+@ops.RegisterShape("ListDiff")
+def _ListDiffShape(op):
+ """Shape function for the ListDiff op."""
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ # TODO(mrry): Indicate that the length falls within an interval?
+ return [tensor_shape.vector(None)] * 2
+
+
+@ops.RegisterShape("Pad")
+def _PadShape(op):
+ """Shape function for the Pad op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * paddings: An N-by-2 matrix, in which the i^th row contains the
+ number of padding elements to add before and after `input` in the
+ i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+ elements according to the values in paddings.
+
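+  For example (an illustrative sketch; `t` is a hypothetical input):
+
+      # 't' is a tensor of shape [2, 3].
+      tf.pad(t, [[1, 1], [0, 2]])  # ==> a tensor of shape [4, 5]
+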
+ Args:
+ op: A Pad Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ paddings_shape = op.inputs[1].get_shape().with_rank(2)
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims == 0 and paddings_shape[0].value == 1:
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ input_shape = tensor_shape.TensorShape([1])
+ else:
+ input_shape = input_shape.with_rank(paddings_shape[0].value)
+ paddings_shape = paddings_shape.merge_with(
+ tensor_shape.matrix(input_shape.ndims, 2))
+ paddings = tensor_util.ConstantValue(op.inputs[1])
+ if paddings is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ if paddings[i, 0] < 0 or paddings[i, 1] < 0:
+ raise ValueError("paddings must be non-negative")
+ output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("ReverseSequence")
+def _ReverseSequenceShape(op):
+ """Shape function for the ReverseSequence op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor with size B in the 0th dimension.
+ * seq_lens: A vector of length B.
+
+ It has one output, with the same size as input.
+
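+  For example (an illustrative sketch; `t` and `lens` are hypothetical):
+
+      # 't' has shape [2, 4, 8]; 'lens' is [4, 2], one length per batch entry.
+      tf.reverse_sequence(t, lens, seq_dim=1)  # ==> a tensor of shape [2, 4, 8]
+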
+ Args:
+ op: A ReverseSequence Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape()
+ seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
+ batch_size = input_shape[0].merge_with(seq_lens_shape[0])
+ input_shape = tensor_shape.TensorShape([batch_size]).concatenate(
+ input_shape[1:])
+ seq_dim = op.get_attr("seq_dim")
+ if seq_dim >= input_shape.ndims:
+ raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
+ (seq_dim, input_shape.ndims))
+ return [input_shape]
+
+
+@ops.RegisterShape("Shape")
+def _ShapeShape(op):
+ """Shape function for the Shape op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.vector(input_shape.ndims)]
+
+
+@ops.RegisterShape("Transpose")
+def _TransposeShape(op):
+ """Shape function for the Transpose op.
+
+ This op takes two inputs:
+
+ * input: a rank-N tensor of arbitrary shape.
+ * shuffle: a length-N vector.
+
+ Its output is the rank-N tensor computed by permuting the dimensions
+ of input according to shuffle.
+
+ Args:
+ op: A Transpose op.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the shapes of input and shuffle are incompatible.
+ IndexError: If shuffle contains an index that is >= the rank of input.
+ """
+ input_shape = op.inputs[0].get_shape()
+ transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
+ input_shape.ndims))
+ transpose_vec = tensor_util.ConstantValue(op.inputs[1])
+ if transpose_vec is None:
+ return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
+ else:
+ return [tensor_shape.TensorShape([input_shape[i]
+ for i in transpose_vec.tolist()])]
+
+
+@ops.RegisterShape("Split")
+def _SplitShape(op):
+ """Shape function for the Split op."""
+ split_dim = tensor_util.ConstantValue(op.inputs[0])
+ num_split = len(op.outputs)
+ input_shape = op.inputs[1].get_shape()
+ if split_dim is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] * num_split
+ else:
+ split_dim = int(split_dim)
+ input_shape = input_shape.with_rank_at_least(split_dim + 1)
+ if not (input_shape[split_dim] % num_split).is_compatible_with(0):
+ raise ValueError(
+ "Number of ways to split should evenly divide the split "
+ "dimension but got split_dim %d (size = %d) and num_split %d" %
+ (split_dim, input_shape[split_dim].value, num_split))
+ prefix = input_shape[:split_dim]
+ size_in_split_dim = input_shape[split_dim] / num_split
+ suffix = input_shape[split_dim + 1:]
+ output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
+ return [output_shape] * num_split
+
+
+@ops.RegisterShape("Tile")
+def _TileShape(op):
+ """Shape function for the Tile op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * multiples: A length-N vector, in which the i^th element contains
+ the factor by which `input` will be tiled in the i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+  elements according to the values in `multiples`.
+
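+  For example (an illustrative sketch; `t` is a hypothetical input):
+
+      # 't' is a tensor of shape [2, 3].
+      tf.tile(t, [3, 2])  # ==> a tensor of shape [6, 6]
+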
+ Args:
+ op: A Tile Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+  input_shape = op.inputs[0].get_shape().with_rank(
+      multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ multiples = multiples.ravel()
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim * multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("TileGrad")
+def _TileGradShape(op):
+ """Shape function for the TileGrad op."""
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+  input_shape = op.inputs[0].get_shape().with_rank(
+      multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim / multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("Where")
+def _WhereShape(op):
+ """Shape function for the Where op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.matrix(None, input_shape.ndims)]
+
+
+@ops.RegisterShape("ZerosLike")
+def _ZerosLikeShape(op):
+ """Shape function for the ZerosLike op."""
+ return [op.inputs[0].get_shape()]
+
+
+def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
+ """Computes the Levenshtein distance between sequences.
+
+ This operation takes variable-length sequences (`hypothesis` and `truth`),
+ each provided as a `SparseTensor`, and computes the Levenshtein distance.
+ You can normalize the edit distance by length of `truth` by setting
+ `normalize` to true.
+
+ For example, given the following input:
+
+ ```python
+ # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
+ # (0,0) = ["a"]
+ # (1,0) = ["b"]
+ hypothesis = tf.SparseTensor(
+ [[0, 0, 0],
+ [1, 0, 0]],
+ ["a", "b"]
+ (2, 1, 1))
+
+ # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
+ # (0,0) = []
+ # (0,1) = ["a"]
+ # (1,0) = ["b", "c"]
+ # (1,1) = ["a"]
+ truth = tf.SparseTensor(
+ [[0, 1, 0],
+ [1, 0, 0],
+ [1, 0, 1],
+       [1, 1, 0]],
+ ["a", "b", "c", "a"],
+ (2, 2, 2))
+
+ normalize = True
+ ```
+
+ This operation would return the following:
+
+ ```python
+ # 'output' is a tensor of shape `[2, 2]` with edit distances normalized
+ # by 'truth' lengths.
+ output ==> [[inf, 1.0], # (0,0): no truth, (0,1): no hypothesis
+ [0.5, 1.0]] # (1,0): addition, (1,1): no hypothesis
+ ```
+
+ Args:
+ hypothesis: A `SparseTensor` containing hypothesis sequences.
+ truth: A `SparseTensor` containing truth sequences.
+ normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
+      length of `truth`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A dense `Tensor` with rank `R - 1`, where R is the rank of the
+ `SparseTensor` inputs `hypothesis` and `truth`.
+
+ Raises:
+ TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
+ """
+ if not isinstance(hypothesis, ops.SparseTensor):
+ raise TypeError("Hypothesis must be a SparseTensor")
+ if not isinstance(truth, ops.SparseTensor):
+ raise TypeError("Truth must be a SparseTensor")
+
+ return gen_array_ops._edit_distance(hypothesis.indices,
+ hypothesis.values,
+ hypothesis.shape,
+ truth.indices,
+ truth.values,
+ truth.shape,
+ normalize=normalize,
+ name=name)
+
+
+@ops.RegisterShape("EditDistance")
+def _EditDistanceShape(op):
+ """Shape function for the EditDistance op."""
+ hypothesis_shape = tensor_util.ConstantValue(op.inputs[2])
+ truth_shape = tensor_util.ConstantValue(op.inputs[5])
+ if hypothesis_shape is not None and truth_shape is not None:
+ if len(hypothesis_shape) != len(truth_shape):
+ raise ValueError(
+ "Inconsistent ranks in hypothesis and truth. Saw shapes: %s and %s" %
+ (str(hypothesis_shape), str(truth_shape)))
+ return [tensor_shape.TensorShape(
+ [max(h, t) for h, t in zip(hypothesis_shape[:-1], truth_shape[:-1])])]
+
+ return [tensor_shape.unknown_shape()]
+
+
+# The remaining ops do not change the shape of their inputs.
+@ops.RegisterShape("Quantize")
+@ops.RegisterShape("Dequantize")
+def _QuantizeDequantizeShape(op):
+ unused_min_range = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ unused_max_range = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
+ return common_shapes.unchanged_shape(op)