Diffstat (limited to 'tensorflow/python/ops')
-rwxr-xr-x  tensorflow/python/ops/__init__.py  0
-rw-r--r--  tensorflow/python/ops/array_grad.py  187
-rw-r--r--  tensorflow/python/ops/array_ops.py  1207
-rw-r--r--  tensorflow/python/ops/attention_ops.py  34
-rw-r--r--  tensorflow/python/ops/candidate_sampling_ops.py  365
-rw-r--r--  tensorflow/python/ops/clip_ops.py  234
-rw-r--r--  tensorflow/python/ops/common_shapes.py  371
-rw-r--r--  tensorflow/python/ops/constant_op.py  189
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py  100
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py  1561
-rw-r--r--  tensorflow/python/ops/control_flow_ops_test.py  88
-rw-r--r--  tensorflow/python/ops/data_flow_grad.py  37
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py  680
-rw-r--r--  tensorflow/python/ops/embedding_ops.py  197
-rw-r--r--  tensorflow/python/ops/gradients.py  661
-rw-r--r--  tensorflow/python/ops/gradients_test.py  337
-rw-r--r--  tensorflow/python/ops/image_ops.py  786
-rw-r--r--  tensorflow/python/ops/image_ops_test.py  771
-rw-r--r--  tensorflow/python/ops/init_ops.py  181
-rw-r--r--  tensorflow/python/ops/io_ops.py  541
-rw-r--r--  tensorflow/python/ops/linalg_grad.py  25
-rw-r--r--  tensorflow/python/ops/linalg_ops.py  62
-rw-r--r--  tensorflow/python/ops/logging_ops.py  58
-rw-r--r--  tensorflow/python/ops/math_grad.py  506
-rw-r--r--  tensorflow/python/ops/math_ops.py  1201
-rw-r--r--  tensorflow/python/ops/math_ops_test.py  68
-rw-r--r--  tensorflow/python/ops/nn.py  816
-rw-r--r--  tensorflow/python/ops/nn_grad.py  229
-rw-r--r--  tensorflow/python/ops/nn_ops.py  365
-rw-r--r--  tensorflow/python/ops/nn_test.py  882
-rw-r--r--  tensorflow/python/ops/numerics.py  50
-rw-r--r--  tensorflow/python/ops/op_def_library.py  640
-rw-r--r--  tensorflow/python/ops/op_def_library_test.py  1402
-rw-r--r--  tensorflow/python/ops/parsing_ops.py  390
-rw-r--r--  tensorflow/python/ops/random_ops.py  181
-rw-r--r--  tensorflow/python/ops/sparse_grad.py  12
-rw-r--r--  tensorflow/python/ops/sparse_ops.py  458
-rw-r--r--  tensorflow/python/ops/sparse_ops_test.py  212
-rw-r--r--  tensorflow/python/ops/standard_ops.py  41
-rw-r--r--  tensorflow/python/ops/state_grad.py  18
-rw-r--r--  tensorflow/python/ops/state_ops.py  189
-rw-r--r--  tensorflow/python/ops/string_ops.py  12
-rw-r--r--  tensorflow/python/ops/summary_ops.py  177
-rw-r--r--  tensorflow/python/ops/variable_scope.py  333
-rw-r--r--  tensorflow/python/ops/variables.py  569
45 files changed, 17423 insertions, 0 deletions
diff --git a/tensorflow/python/ops/__init__.py b/tensorflow/python/ops/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/ops/__init__.py
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
new file mode 100644
index 0000000000..2a463940d6
--- /dev/null
+++ b/tensorflow/python/ops/array_grad.py
@@ -0,0 +1,187 @@
+"""Gradients for operators defined in array_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import gen_array_ops
+
+
+@ops.RegisterGradient("Pack")
+def _PackGrad(op, grad):
+ """Gradient for pack op."""
+ return array_ops.unpack(grad, num=op.get_attr('N'))
+
+
+@ops.RegisterGradient("Unpack")
+def _UnpackGrad(_, *grads):
+ """Gradient for unpack op."""
+ return array_ops.pack(grads)
+
+
+@ops.RegisterGradient("Concat")
+def _ConcatGrad(op, grad):
+ """Gradient for concat op."""
+ assert isinstance(grad, ops.Tensor)
+ # Degenerate concatenation, just return grad.
+ if len(op.inputs) == 2:
+ return [None, grad]
+ # Get the inputs' tensor shapes
+ sizes = [array_ops.shape(x) for x in op.inputs[1:]]
+ concat_dim = op.inputs[0]
+ # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
+ shape_of_shape = array_ops.shape(sizes[0])
+ # Make a vector of length equal to the input's dimensions,
+ # with 0's everywhere and 1 in the concat dim position.
+ # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
+ mask = array_ops.concat(0,
+ [array_ops.fill(
+ array_ops.expand_dims(concat_dim, 0), 0), [1],
+ array_ops.fill(shape_of_shape - concat_dim - 1, 0)])
+ out_grads = []
+ begin = array_ops.fill(shape_of_shape, 0)
+ for i in range(len(sizes)):
+ out_grads.append(array_ops.slice(grad, begin, sizes[i]))
+ # Lint complains begin = begin + ...
+ begin = math_ops.add(begin, sizes[i] * mask)
+ return [None] + out_grads
+
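The gradient of a concatenation is the incoming gradient cut back into per-input pieces along the concat dimension, which is what the moving `begin` offset above computes. A minimal NumPy sketch of that idea (illustrative shapes only, not the op itself):

```python
import numpy as np

# Two inputs concatenated along dimension 1; the gradient w.r.t. each input is
# the matching slice of the gradient w.r.t. the concatenated output.
x1 = np.ones((2, 3))
x2 = np.ones((2, 5))
grad = np.arange(16, dtype=np.float64).reshape(2, 8)  # d(loss)/d(concat(1, [x1, x2]))

grad_x1 = grad[:, :3]   # same shape as x1
grad_x2 = grad[:, 3:8]  # same shape as x2
assert grad_x1.shape == x1.shape and grad_x2.shape == x2.shape
```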
+
+@ops.RegisterGradient("Slice")
+def _SliceGrad(op, grad):
+ """Gradient for Slice op."""
+ # Create an Nx2 padding where the first column represents how many
+ # zeros are to be prepended for each dimension, and the second
+ # column indicates how many zeros are appended.
+ #
+ # The number of zeros to append is the shape of the input
+ # elementwise-subtracted by both the begin vector and sizes vector.
+ #
+ # Some more reshaping is needed to assemble this tensor with the
+ # right dimensions.
+ input_vec = op.inputs[0]
+ begin_vec = op.inputs[1]
+ input_rank = array_ops.rank(input_vec)
+ slice_size = array_ops.shape(op.outputs[0])
+
+ shape = array_ops.pack([input_rank, 1])
+ before_pad = array_ops.reshape(begin_vec, shape)
+ after_pad = array_ops.reshape(
+ array_ops.shape(input_vec) - slice_size - begin_vec, shape)
+ paddings = array_ops.concat(1, [before_pad, after_pad])
+ return array_ops.pad(grad, paddings), None, None
+
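A small NumPy sketch of how the N-by-2 `paddings` matrix is assembled for a concrete slice (values chosen only for illustration):

```python
import numpy as np

# Gradient of slice(input, begin, size): pad `grad` back out to input's shape.
input_shape = np.array([4, 6])
begin = np.array([1, 2])
size = np.array([2, 3])                     # shape of the slice output and of `grad`

before_pad = begin                          # zeros prepended per dimension
after_pad = input_shape - size - begin      # zeros appended per dimension
paddings = np.stack([before_pad, after_pad], axis=1)   # [[1, 1], [2, 1]]
grad = np.ones(size)
assert np.pad(grad, paddings).shape == tuple(input_shape)
```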
+
+@ops.RegisterGradient("Split")
+def _SplitGrad(op, *grads):
+ return None, array_ops.concat(op.inputs[0], list(grads))
+
+
+ops.NoGradient("Const")
+
+# TODO(liqzhang): The gradient for Diag operator would be
+# the diagonal of the backprop. Implement if there is a need.
+ops.NoGradient("Diag")
+
+# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
+ops.NoGradient("EditDistance")
+
+ops.NoGradient("Fill")
+
+
+@ops.RegisterGradient("Gather")
+def _GatherGrad(op, grad):
+ return [
+ ops.IndexedSlices(grad, op.inputs[1], array_ops.shape(op.inputs[0])), None
+ ]
+
+
+@ops.RegisterGradient("Identity")
+def _IdGrad(_, grad):
+ return grad
+
+
+@ops.RegisterGradient("RefIdentity")
+def _RefIdGrad(_, grad):
+ return grad
+
+
+ops.NoGradient("StopGradient")
+
+
+@ops.RegisterGradient("Reshape")
+def _ReshapeGrad(op, grad):
+ return [array_ops.reshape(grad, array_ops.shape(op.inputs[0])), None]
+
+
+ops.NoGradient("InvertPermutation")
+
+
+def _ReshapeToInput(op, grad):
+ """Reshapes the gradient to the shape of the original input."""
+ return array_ops.reshape(grad, array_ops.shape(op.inputs[0]))
+
+
+@ops.RegisterGradient("ExpandDims")
+def _ExpandDimsGrad(op, grad):
+ return [_ReshapeToInput(op, grad), None]
+
+
+@ops.RegisterGradient("Squeeze")
+def _SqueezeGrad(op, grad):
+ return _ReshapeToInput(op, grad)
+
+
+@ops.RegisterGradient("Transpose")
+def _TransposeGrad(op, grad):
+ """Returns unshuffle(grad)."""
+ p = op.inputs[1]
+ return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]
+
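A quick NumPy check of the "unshuffle" idea: transposing the gradient by the inverse permutation restores the input's dimension order (illustrative shapes only):

```python
import numpy as np

perm = [0, 2, 1]
inv_perm = np.argsort(perm)                       # the inverse permutation
x = np.zeros((2, 3, 4))
y_grad = np.zeros(np.transpose(x, perm).shape)    # gradient has shape (2, 4, 3)
x_grad = np.transpose(y_grad, inv_perm)
assert x_grad.shape == x.shape                    # back to (2, 3, 4)
```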
+
+ops.NoGradient("Shape")
+
+
+ops.NoGradient("Rank")
+
+
+ops.NoGradient("Size")
+
+
+@ops.RegisterGradient("Tile")
+def _TileGrad(op, grad):
+ """Sum reduces grad along the tiled dimensions."""
+ assert isinstance(grad, ops.Tensor)
+ return [gen_array_ops._tile_grad(grad, op.inputs[1]), None]
+
+
+ops.NoGradient("TileGrad")
+
+
+ops.NoGradient("BroadcastGradientArgs")
+
+
+@ops.RegisterGradient("Pad")
+def _PadGrad(op, grad):
+ """Gradient for Pad."""
+ # Pad introduces values around the original tensor, so the gradient function
+ # slices the original shape out of the gradient.
+ x = op.inputs[0]
+ a = op.inputs[1] # [Rank(x), 2]
+ # Takes a slice of a. The 1st column. [Rank(x), 1].
+ pad_before = array_ops.slice(a, [0, 0],
+ array_ops.pack([array_ops.rank(x), 1]))
+ # Make it a 1-D tensor.
+ begin = array_ops.reshape(pad_before, [-1])
+ sizes = array_ops.shape(x)
+ return array_ops.slice(grad, begin, sizes), None
+
+
+# ReverseSequence is just a permutation. The gradient permutes back.
+@ops.RegisterGradient("ReverseSequence")
+def _ReverseSequenceGrad(op, grad):
+ seq_lengths = op.inputs[1]
+ return [array_ops.reverse_sequence(grad,
+ seq_dim=op.get_attr("seq_dim"),
+ seq_lengths=seq_lengths),
+ None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
new file mode 100644
index 0000000000..ed780db625
--- /dev/null
+++ b/tensorflow/python/ops/array_ops.py
@@ -0,0 +1,1207 @@
+"""## Casting
+
+TensorFlow provides several operations that you can use to cast tensor data
+types in your graph.
+
+@@string_to_number
+@@to_double
+@@to_float
+@@to_bfloat16
+@@to_int32
+@@to_int64
+@@cast
+
+## Shapes and Shaping
+
+TensorFlow provides several operations that you can use to determine the shape
+of a tensor and change the shape of a tensor.
+
+@@shape
+@@size
+@@rank
+@@reshape
+@@squeeze
+@@expand_dims
+
+## Slicing and Joining
+
+TensorFlow provides several operations to slice or extract parts of a tensor,
+or join multiple tensors together.
+
+@@slice
+@@split
+@@tile
+@@pad
+@@concat
+@@pack
+@@unpack
+@@reverse_sequence
+@@reverse
+@@transpose
+@@gather
+@@dynamic_partition
+@@dynamic_stitch
+"""
+import sys
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+# pylint: disable=wildcard-import
+# 'Constant' gets imported in the module 'array_ops'.
+from tensorflow.python.ops.constant_op import constant
+from tensorflow.python.ops.gen_array_ops import *
+
+
+# We override the 'slice' for the "slice" op, so we keep python's
+# existing 'slice' for later use in this module.
+_baseslice = slice
+
+
+# Aliases for some automatically-generated names.
+listdiff = gen_array_ops.list_diff
+
+
+# pylint: disable=undefined-variable,protected-access
+def _SliceHelper(tensor, slice_spec):
+ """Overload for Tensor.__getitem__.
+
+ Currently the size of the slice must be statically known in each dimension,
+ i.e. the "stop" of the slice must not be omitted.
+
+ TODO(mrry): Support slices where the sizes are not specified.
+ TODO(mrry): Support negative indices in slices with numpy/Python semantics.
+
+ Args:
+ tensor: An ops.Tensor object.
+ slice_spec: The arguments to Tensor.__getitem__.
+
+ Returns:
+ The appropriate slice of "tensor", based on "slice_spec".
+
+ Raises:
+ ValueError: If a slice range is negative size.
+ TypeError: If the slice indices aren't int, slice, or Ellipsis.
+ """
+ if not isinstance(slice_spec, (list, tuple)):
+ slice_spec = [slice_spec]
+ indices = []
+ sizes = []
+ squeeze_dims = []
+ for dim, s in enumerate(slice_spec):
+ if isinstance(s, int):
+ if s < 0:
+ raise NotImplementedError("Negative indices are currently unsupported")
+ indices.append(s)
+ sizes.append(1)
+ squeeze_dims.append(dim)
+ elif isinstance(s, _baseslice):
+ if s.step not in (None, 1):
+ raise NotImplementedError(
+ "Steps other than 1 are not currently supported")
+ start = s.start if s.start is not None else 0
+ if start < 0:
+ raise NotImplementedError(
+ "Negative start indices are not currently supported")
+ indices.append(start)
+ if s.stop is not None and s.stop < 0:
+ raise NotImplementedError(
+ "Negative stop indices are not currently supported")
+ # NOTE(mrry): If the stop is not specified, Python substitutes
+ # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice currently
+ # supports signed DT_INT32 arguments, we use -1 to specify that all
+ # elements should be captured.
+ if s.stop is None or s.stop == sys.maxsize:
+ sizes.append(-1)
+ else:
+ if start > s.stop:
+ raise ValueError("Stop must be at least start")
+ sizes.append(s.stop - start)
+ elif s is Ellipsis:
+ raise NotImplementedError("Ellipsis is not currently supported")
+ else:
+ raise TypeError("Bad slice index %s of type %s" % (s, type(s)))
+ sliced = slice(tensor, indices, sizes)
+ if squeeze_dims:
+ return squeeze(sliced, squeeze_dims=squeeze_dims)
+ else:
+ return sliced
+
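In NumPy terms, the helper turns an integer index into a size-1 slice followed by a squeeze of that dimension; a minimal sketch of the equivalence (illustrative array only):

```python
import numpy as np

t = np.arange(24).reshape(4, 6)
via_getitem = t[1, 0:2]                              # shape (2,)
via_slice = np.squeeze(t[1:1 + 1, 0:0 + 2], axis=0)  # unit slice, then squeeze dim 0
assert np.array_equal(via_getitem, via_slice)
```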
+
+def slice(input_, begin, size, name=None):
+ """Extracts a slice from a tensor.
+
+ This operation extracts a slice of size `size` from a tensor `input` starting
+ at the location specified by `begin`. The slice `size` is represented as a
+ tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
+ of `input` that you want to slice. The starting location (`begin`) for the
+ slice is represented as an offset in each dimension of `input`. In other
+ words, `begin[i]` is the offset into the 'i'th dimension of `input` that you
+ want to slice from.
+
+ `begin` is zero-based; `size` is one-based. If `size[i]` is -1,
+ all remaining elements in dimension i are included in the
+ slice. In other words, this is equivalent to setting:
+
+ `size[i] = input.dim_size(i) - begin[i]`
+
+ This operation requires that:
+
+ `0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]`
+
+ For example:
+
+ ```
+ # 'input' is [[[1, 1, 1], [2, 2, 2]],
+ # [[3, 3, 3], [4, 4, 4]],
+ # [[5, 5, 5], [6, 6, 6]]]
+ tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
+ tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
+ [4, 4, 4]]]
+ tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
+ [[5, 5, 5]]]
+ ```
+
+ Args:
+ input_: A `Tensor`.
+ begin: An `int32` or `int64` `Tensor`.
+ size: An `int32` or `int64` `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` the same type as `input`.
+ """
+ return gen_array_ops._slice(input_, begin, size, name=name)
+
+
+ops.Tensor._override_operator("__getitem__", _SliceHelper)
+
+
+def pack(values, name="pack"):
+ """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
+
+ Packs tensors in `values` into a tensor with rank one higher than each tensor
+ in `values` and shape `[len(values)] + values[0].shape`. The output satisfies
+ `output[i, ...] = values[i][...]`.
+
+ This is the opposite of unpack. The numpy equivalent is
+
+ tf.pack([x, y, z]) = np.asarray([x, y, z])
+
+ Args:
+ values: A list of `Tensor` objects with the same shape and type.
+ name: A name for this operation (optional).
+
+ Returns:
+ output: A packed `Tensor` with the same type as `values`.
+ """
+ return gen_array_ops._pack(values, name=name)
+
+
+def unpack(value, num=None, name="unpack"):
+ """Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
+
+ Unpacks `num` tensors from `value` along the first dimension.
+ If `num` is not specified (the default), it is inferred from `value`'s shape.
+ If `value.shape[0]` is not known, `ValueError` is raised.
+
+ The ith tensor in `output` is the slice `value[i, ...]`. Each tensor in
+ `output` has shape `value.shape[1:]`.
+
+ This is the opposite of pack. The numpy equivalent is
+
+ tf.unpack(x, n) = list(x)
+
+ Args:
+ value: A rank `R > 0` `Tensor` to be unpacked.
+ num: An `int`. The first dimension of value. Automatically inferred if
+ `None` (the default).
+ name: A name for the operation (optional).
+
+ Returns:
+ The list of `Tensor` objects unpacked from `value`.
+
+ Raises:
+ ValueError: If `num` is unspecified and cannot be inferred.
+ """
+ if num is None:
+ value = ops.convert_to_tensor(value)
+ shape = value.get_shape()
+ num = shape[0].value
+ if num is None:
+ raise ValueError("Cannot infer num from shape %s" % shape)
+ return gen_array_ops._unpack(value, num=num, name=name)
+
+
+def concat(concat_dim, values, name="concat"):
+ """Concatenates tensors along one dimension.
+
+ Concatenates the list of tensors `values` along dimension `concat_dim`. If
+ `values[i].shape = [D0, D1, ... Dconcat_dim(i), ...Dn]`, the concatenated
+ result has shape
+
+ [D0, D1, ... Rconcat_dim, ...Dn]
+
+ where
+
+ Rconcat_dim = sum(Dconcat_dim(i))
+
+ That is, the data from the input tensors is joined along the `concat_dim`
+ dimension.
+
+ The number of dimensions of the input tensors must match, and all dimensions
+ except `concat_dim` must be equal.
+
+ For example:
+
+ ```python
+ t1 = [[1, 2, 3], [4, 5, 6]]
+ t2 = [[7, 8, 9], [10, 11, 12]]
+ tf.concat(0, [t1, t2]) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
+ tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
+
+ # tensor t3 with shape [2, 3]
+ # tensor t4 with shape [2, 3]
+ tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
+ tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]
+ ```
+
+ Args:
+ concat_dim: 0-D `int32` `Tensor`. Dimension along which to concatenate.
+ values: A list of `Tensor` objects or a single `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` resulting from concatenation of the input tensors.
+ """
+ if not isinstance(values, (list)):
+ values = [values]
+ # TODO(mrry): Change to return values?
+ if len(values) == 1: # Degenerate case of one tensor.
+ return identity(values[0], name=name)
+ return gen_array_ops._concat(concat_dim=concat_dim,
+ values=values,
+ name=name)
+
+
+@ops.RegisterShape("Pack")
+def _PackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ for inp in op.inputs[1:]:
+ input_shape = input_shape.merge_with(inp.get_shape())
+ return [tensor_shape.TensorShape([len(op.inputs)]).concatenate(input_shape)]
+
+
+@ops.RegisterShape("Unpack")
+def _UnpackShape(op):
+ input_shape = op.inputs[0].get_shape()
+ return [input_shape[1:]] * op.get_attr("num")
+
+
+@ops.RegisterShape("Concat")
+def _ConcatShape(op):
+ concat_dim = tensor_util.ConstantValue(op.inputs[0])
+ if concat_dim is None:
+ # Return an unknown shape with the same rank as the inputs, or an
+ # unknown rank if no input's rank is known.
+ rank = None
+ for value in op.inputs[1:]:
+ if rank is not None:
+ value.get_shape().assert_has_rank(rank)
+ else:
+ rank = value.get_shape().ndims
+ return [tensor_shape.unknown_shape(ndims=max(rank, 1))]
+
+ else:
+ # Merge all the non-concat dims, and sum the concat dim to make an
+ # output shape.
+ concat_dim = int(concat_dim)
+ output_shape = op.inputs[1].get_shape()
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ if output_shape.ndims == 0:
+ output_shape = tensor_shape.TensorShape([1])
+ for value in op.inputs[2:]:
+ value_shape = value.get_shape()
+ if value_shape.ndims is not None and concat_dim >= value_shape.ndims:
+ if value_shape.ndims == 0 and concat_dim == 0:
+ # Let concat handle scalars
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ value_shape = tensor_shape.TensorShape([1])
+ else:
+ raise ValueError("concat_dim is out of range (values rank = %d)" %
+ value_shape.ndims)
+ before = output_shape[:concat_dim].merge_with(value_shape[:concat_dim])
+ at = output_shape[concat_dim] + value_shape[concat_dim]
+ after = output_shape[
+ concat_dim + 1:].merge_with(value_shape[concat_dim + 1:])
+ output_shape = before.concatenate(at).concatenate(after)
+ return [output_shape]
+
+
+def sparse_mask(a, mask_indices, name=None):
+ """Masks elements of `IndexedSlices`.
+
+ Given an `IndexedSlices` instance `a`, returns another `IndexedSlices` that
+ contains a subset of the slices of `a`. Only the slices at indices specified
+ in `mask_indices` are returned.
+
+ This is useful when you need to extract a subset of slices in an
+ `IndexedSlices` object.
+
+ For example:
+
+ ```python
+ # `a` contains slices at indices [12, 26, 37, 45] from a large tensor
+ # with shape [1000, 10]
+ a.indices => [12, 26, 37, 45]
+ tf.shape(a.values) => [4, 10]
+
+ # `b` will be the subset of `a` slices at its second and third indices, so
+ # we want to mask off its first and last indices (which are at absolute
+ # indices 12, 45)
+ b = tf.sparse_mask(a, [12, 45])
+
+ b.indices => [26, 37]
+ tf.shape(b.values) => [2, 10]
+
+ ```
+
+ Args:
+ a: An `IndexedSlices` instance.
+ mask_indices: Indices of elements to mask.
+ name: A name for the operation (optional).
+
+ Returns:
+ The masked `IndexedSlices` instance.
+ """
+ with ops.op_scope([a, mask_indices], name, "sparse_mask") as name:
+ indices = a.indices
+ out_indices, to_gather = listdiff(indices, mask_indices)
+ out_values = gather(a.values, to_gather, name=name)
+ return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
+
+
+def split(split_dim, num_split, value, name="split"):
+ """Splits a tensor into `num_split` tensors along one dimension.
+
+ Splits `value` along dimension `split_dim` into `num_split` smaller tensors.
+ Requires that `num_split` evenly divide `value.shape[split_dim]`.
+
+ For example:
+
+ ```python
+ # 'value' is a tensor with shape [5, 30]
+ # Split 'value' into 3 tensors along dimension 1
+ split0, split1, split2 = tf.split(1, 3, value)
+ tf.shape(split0) ==> [5, 10]
+ ```
+
+ Args:
+ split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
+ Must be in the range `[0, rank(value))`.
+ num_split: A 0-D `int32` `Tensor`. The number of ways to split.
+ value: The `Tensor` to split.
+ name: A name for the operation (optional).
+
+ Returns:
+ `num_split` `Tensor` objects resulting from splitting `value`.
+ """
+ return gen_array_ops._split(split_dim=split_dim,
+ num_split=num_split,
+ value=value,
+ name=name)
+
+
+@ops.RegisterShape("Reverse")
+def _ReverseShape(op):
+ return [op.inputs[0].get_shape().with_rank_at_most(8)]
+
+
+def transpose(a, perm=None, name="transpose"):
+ """Transposes `a`. Permutes the dimensions according to `perm`.
+
+ The returned tensor's dimension i will correspond to the input dimension
+ `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
+ the rank of the input tensor. Hence by default, this operation performs a
+ regular matrix transpose on 2-D input Tensors.
+
+ For example:
+
+ ```python
+ # 'x' is [[1 2 3]
+ # [4 5 6]]
+ tf.transpose(x) ==> [[1 4]
+ [2 5]
+ [3 6]]
+
+ # Equivalently
+ tf.transpose(x, perm=[1, 0]) ==> [[1 4]
+ [2 5]
+ [3 6]]
+
+ # 'perm' is more useful for n-dimensional tensors, for n > 2
+ # 'x' is [[[1 2 3]
+ # [4 5 6]]
+ # [[7 8 9]
+ # [10 11 12]]]
+ # Take the transpose of the matrices in dimension-0
+ tf.transpose(b, perm=[0, 2, 1]) ==> [[[1 4]
+ [2 5]
+ [3 6]]
+
+ [[7 10]
+ [8 11]
+ [9 12]]]
+ ```
+
+ Args:
+ a: A `Tensor`.
+ perm: A permutation of the dimensions of `a`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A transposed `Tensor`.
+ """
+ with ops.op_scope([a], name, "transpose") as name:
+ if perm is None:
+ dims = gen_math_ops._range(0, gen_array_ops.rank(a), 1)
+ perm = gen_array_ops.reverse(dims, [True])
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ # NOTE(mrry): Setting the shape explicitly because
+ # reverse is not handled by the shape function.
+ input_shape = ret.op.inputs[0].get_shape().dims
+ if input_shape is not None:
+ ret.set_shape(input_shape[::-1])
+ else:
+ ret = gen_array_ops.transpose(a, perm, name=name)
+ return ret
+
+
+def zeros(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and
+ all elements set to zero.
+
+ For example:
+
+ ```python
+ tf.zeros([3, 4], int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([shape], name, "zeros") as name:
+ if isinstance(shape, list):
+ output = constant(0, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(0, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def zeros_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to zero.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the
+ same type and shape as `tensor` with all elements set to zero. Optionally,
+ you can use `dtype` to specify a new type for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to zero.
+ """
+ with ops.op_scope([tensor], name, "zeros_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ zeros_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return zeros(zeros_shape, dtype=dtype, name=name)
+
+
+def ones_like(tensor, dtype=None, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ Given a single tensor (`tensor`), this operation returns a tensor of the same
+ type and shape as `tensor` with all elements set to 1. Optionally, you can
+ specify a new type (`dtype`) for the returned tensor.
+
+ For example:
+
+ ```python
+ # 'tensor' is [[1, 2, 3], [4, 5, 6]]
+ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ tensor: A `Tensor`.
+ dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
+ `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([tensor], name, "ones_like") as name:
+ tensor = ops.convert_to_tensor(tensor, name="tensor")
+ ones_shape = shape(tensor)
+ if dtype is None:
+ dtype = tensor.dtype
+ return ones(ones_shape, dtype=dtype, name=name)
+
+
+def zeros_initializer(shape, dtype=types.float32):
+ """An adaptor for zeros() to match the Initializer spec."""
+ return zeros(shape, dtype)
+
+
+def ones(shape, dtype=types.float32, name=None):
+ """Creates a tensor with all elements set to 1.
+
+ This operation returns a tensor of type `dtype` with shape `shape` and all
+ elements set to 1.
+
+ For example:
+
+ ```python
+ tf.ones([2, 3], int32) ==> [[1, 1, 1], [1, 1, 1]]
+ ```
+
+ Args:
+ shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
+ dtype: The type of an element in the resulting `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with all elements set to 1.
+ """
+ with ops.op_scope([shape], name, "ones") as name:
+ if isinstance(shape, list):
+ output = constant(1, shape=shape, dtype=dtype, name=name)
+ else:
+ shape = ops.convert_to_tensor(shape, name="shape")
+ output = fill(shape, constant(1, dtype=dtype), name=name)
+ assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
+ return output
+
+
+def placeholder(dtype, shape=None, name=None):
+ """Inserts a placeholder for a tensor that will be always fed.
+
+ **Important**: This tensor will produce an error if evaluated. Its value must
+ be fed using the `feed_dict` optional argument to `Session.run()`,
+ `Tensor.eval()`, or `Operation.run()`.
+
+ For example:
+
+ ```python
+ x = tf.placeholder(float, shape=(1024, 1024))
+ y = tf.matmul(x, x)
+
+ with tf.Session() as sess:
+ print sess.run(y) # ERROR: will fail because x was not fed.
+
+ rand_array = np.random.rand(1024, 1024)
+ print sess.run(y, feed_dict={x: rand_array}) # Will succeed.
+ ```
+
+ Args:
+ dtype: The type of elements in the tensor to be fed.
+ shape: The shape of the tensor to be fed (optional). If the shape is not
+ specified, you can feed a tensor of any shape.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` that may be used as a handle for feeding a value, but not
+ evaluated directly.
+ """
+ shape = tensor_shape.as_shape(shape)
+ if shape.is_fully_defined():
+ dim_list = shape.as_list()
+ else:
+ dim_list = []
+ ret = gen_array_ops._placeholder(
+ dtype=dtype,
+ shape=dim_list,
+ name=name)
+ ret.set_shape(shape)
+ return ret
+
+
+@ops.RegisterShape("Placeholder")
+def _PlaceholderShape(op):
+ given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
+ if given_shape:
+ return [tensor_shape.TensorShape(given_shape)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("CheckNumerics")
+@ops.RegisterShape("Identity")
+@ops.RegisterShape("RefIdentity")
+@ops.RegisterShape("StopGradient")
+def _UnchangedShape(op):
+ return [op.inputs[0].get_shape()]
+
+
+@ops.RegisterShape("Rank")
+@ops.RegisterShape("Size")
+def _ScalarShape(unused_op):
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("Slice")
+def _SliceShape(op):
+ """Shape function for array_ops.slice."""
+ input_shape = op.inputs[0].get_shape()
+ begin_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ sizes_shape = op.inputs[2].get_shape().with_rank_at_most(1)
+ rank_vector_shape = begin_shape.merge_with(sizes_shape)
+ ndims = rank_vector_shape.num_elements()
+ if ndims is not None:
+ input_shape.assert_has_rank(ndims)
+ begin_value = tensor_util.ConstantValue(op.inputs[1])
+ sizes_value = tensor_util.ConstantValue(op.inputs[2])
+ if sizes_value is not None:
+ returned_dims = []
+ for i, slice_size in enumerate(sizes_value.ravel()):
+ if slice_size != -1:
+ returned_dims.append(slice_size)
+ elif begin_value is not None:
+ returned_dims.append(input_shape[i] - begin_value[i])
+ else:
+ returned_dims.append(None)
+ return [tensor_shape.TensorShape(returned_dims)]
+ else:
+ if input_shape.ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ elif ndims is not None:
+ return [tensor_shape.unknown_shape(ndims=ndims)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("Gather")
+def _GatherShape(op):
+ """Shape function for array_ops.gather."""
+ params_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ return [indices_shape.concatenate(params_shape[1:])]
+
+
+@ops.RegisterShape("Unique")
+def _UniqueShape(op):
+ """Shape function for array_ops.Unique."""
+ # The output is a vector with data-dependent length.
+ input_shape = op.inputs[0].get_shape()
+ input_shape.assert_has_rank(1)
+ return [tensor_shape.vector(None), input_shape]
+
+
+@ops.RegisterShape("Diag")
+def _DiagShape(op):
+ """Shape function for array_ops.diag.
+
+ This op has one input (of rank k <= 3), and one output (of rank 2k),
+ where the shape of the output is the concatenation of the input
+ shape with itself.
+
+ Args:
+ op: A Diag Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank_at_most(3)
+ return [input_shape.concatenate(input_shape)]
+
+
+@ops.RegisterShape("ExpandDims")
+def _ExpandDimsShape(op):
+ """Determine shape for expand op's output tensor.
+
+ Args:
+ op: Operation for which to determine shape.
+ op.inputs[0] is the input tensor.
+ op.inputs[1] is the dimension in which to expand.
+ Returns:
+ Shape of op's output tensor.
+ Raises:
+ ValueError: If dim is outside of [-rank - 1, rank], where rank is the number
+ of dimensions in the input tensor.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+ dim = tensor_util.ConstantValue(op.inputs[1])
+ input_ndims = input_shape.ndims
+ if dim < -input_ndims - 1 or dim > input_ndims:
+ raise ValueError(
+ "dim %d not in [%d, %d]." % (dim, -input_ndims, input_ndims))
+ if dim < 0:
+ dim += (input_ndims + 1)
+ result_shape = list(input_shape.dims)
+ result_shape.insert(dim, 1)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Squeeze")
+def _SqueezeShape(op):
+ """Determine shape for squeeze op's output tensor.
+
+ Args:
+ op: Operation for which to determine shape.
+ Returns:
+ Shape of op's output tensor.
+ Raises:
+ ValueError: if squeeze_dims includes a dimension outside of [-rank, rank),
+ where rank is the number of dimensions in the input tensor. Or, if
+ squeeze_dims includes a dimension for which input shape has a value
+ not equal to 1.
+ """
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.dims is None:
+ return [tensor_shape.unknown_shape()]
+
+ squeeze_dims = op.get_attr("squeeze_dims") or []
+ wrapped_squeeze_dims = []
+ input_ndims = input_shape.ndims
+ for i, squeeze_dim in enumerate(squeeze_dims):
+ if squeeze_dim < -input_ndims or squeeze_dim >= input_ndims:
+ raise ValueError(
+ "squeeze_dims[%d]=%d not in [%d, %d)." % (
+ i, squeeze_dim, -input_ndims, input_ndims))
+ if squeeze_dim < 0:
+ squeeze_dim += input_ndims
+ wrapped_squeeze_dims.append(squeeze_dim)
+
+ result_shape = []
+ for i, dim in enumerate([d.value for d in input_shape.dims]):
+ is_explicit_match = i in wrapped_squeeze_dims
+ if is_explicit_match or not wrapped_squeeze_dims:
+ if dim is None:
+ return [tensor_shape.unknown_shape()]
+ if dim != 1:
+ if is_explicit_match:
+ raise ValueError(
+ "Can not squeeze dim[%d], expected a dimension of 1, got %d." % (
+ i, dim))
+ result_shape.append(dim)
+ else:
+ result_shape.append(dim)
+ return [tensor_shape.TensorShape(result_shape)]
+
+
+@ops.RegisterShape("Reshape")
+def _ReshapeShape(op):
+ """Shape function for Reshape op."""
+ input_shape = op.inputs[0].get_shape()
+ new_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ new_shape = tensor_util.ConstantValue(op.inputs[1])
+ if new_shape is None:
+ # Attempt to infer the rank of the output from the length of
+ # new_shape.
+ return [tensor_shape.unknown_shape(ndims=new_shape_shape.num_elements())]
+ new_shape = np.reshape(new_shape, -1).tolist()
+ if -1 not in new_shape:
+ # The new shape is fully defined.
+ return [tensor_shape.TensorShape(new_shape)]
+ elif input_shape.is_fully_defined():
+ # We know the input shape, so we can calculate the missing
+ # dimension in the new_shape.
+ num_elements = 1
+ for dim in input_shape.dims:
+ num_elements *= dim.value
+ known_elements = 1
+ unknown_index = None
+ for i, dim in enumerate(new_shape):
+ if dim == -1:
+ unknown_index = i
+ else:
+ known_elements *= dim
+ if known_elements == 0:
+ raise ValueError("cannot infer the missing input size for "
+ "an empty tensor unless all specified "
+ "input sizes are non-zero")
+ if num_elements % known_elements != 0:
+ raise ValueError("input has %s elements, which isn't divisible by %d" %
+ (num_elements, known_elements))
+ new_shape[unknown_index] = num_elements / known_elements
+ return [tensor_shape.TensorShape(new_shape)]
+ else:
+ # We don't know the input shape, but we know n-1 of the dimensions
+ # in the new shape.
+ new_shape[new_shape.index(-1)] = None
+ return [tensor_shape.TensorShape(new_shape)]
+
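A worked example of the inference above, assuming a fully known input shape (illustrative numbers):

```python
import numpy as np

input_shape = (2, 3, 4)          # 24 elements
new_shape = [4, -1]
num_elements = int(np.prod(input_shape))
known = int(np.prod([d for d in new_shape if d != -1]))
assert num_elements % known == 0
inferred = [num_elements // known if d == -1 else d for d in new_shape]
assert inferred == [4, 6]        # the -1 entry becomes 24 / 4 = 6
```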
+
+@ops.RegisterShape("BroadcastGradientArgs")
+def _BroadcastGradientArgsShape(op):
+ """Shape function for the BroadcastGradientArgs op."""
+ # TODO(mrry): Implement ConstantValue for BroadcastGradientArgs?
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ return [tensor_shape.vector(None), tensor_shape.vector(None)]
+
+
+@ops.RegisterShape("Fill")
+def _FillShape(op):
+ """Shape function for the Fill op.
+
+ This op takes a vector of dimensions and a scalar, and produces a
+ tensor with the given dimensions.
+
+ Args:
+ op: A Fill Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ dimensions_shape = op.inputs[0].get_shape().with_rank_at_most(1)
+ op.inputs[1].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+ fill_dims = tensor_util.ConstantValue(op.inputs[0])
+ if fill_dims is None:
+ # Attempt to infer the rank of the output from the length of
+ # dimensions.
+ return [tensor_shape.unknown_shape(ndims=dimensions_shape.num_elements())]
+ else:
+ return [tensor_shape.TensorShape(fill_dims.tolist())]
+
+
+@ops.RegisterShape("InvertPermutation")
+def _InvertPermutationShape(op):
+ """Shape function for the InvertPermutation op."""
+ return [op.inputs[0].get_shape().with_rank(1)]
+
+
+@ops.RegisterShape("ListDiff")
+def _ListDiffShape(op):
+ """Shape function for the ListDiff op."""
+ op.inputs[0].get_shape().assert_has_rank(1)
+ op.inputs[1].get_shape().assert_has_rank(1)
+ # TODO(mrry): Indicate that the length falls within an interval?
+ return [tensor_shape.vector(None)] * 2
+
+
+@ops.RegisterShape("Pad")
+def _PadShape(op):
+ """Shape function for the Pad op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * paddings: An N-by-2 matrix, in which the i^th row contains the
+ number of padding elements to add before and after `input` in the
+ i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+ elements according to the values in paddings.
+
+ Args:
+ op: A Pad Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ paddings_shape = op.inputs[1].get_shape().with_rank(2)
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims == 0 and paddings_shape[0].value == 1:
+ # TODO(irving): Remove once !kAllowLegacyScalars.
+ input_shape = tensor_shape.TensorShape([1])
+ else:
+ input_shape = input_shape.with_rank(paddings_shape[0].value)
+ paddings_shape = paddings_shape.merge_with(
+ tensor_shape.matrix(input_shape.ndims, 2))
+ paddings = tensor_util.ConstantValue(op.inputs[1])
+ if paddings is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ if paddings[i, 0] < 0 or paddings[i, 1] < 0:
+ raise ValueError("paddings must be non-negative")
+ output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
+ return [tensor_shape.TensorShape(output_dims)]
+
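A small worked example of the output-shape rule, `dim + paddings[i, 0] + paddings[i, 1]` (illustrative values):

```python
input_shape = [2, 3]
paddings = [[1, 1],   # 1 row before, 1 row after
            [2, 0]]   # 2 columns before, none after
output_shape = [d + p[0] + p[1] for d, p in zip(input_shape, paddings)]
assert output_shape == [4, 5]
```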
+
+@ops.RegisterShape("ReverseSequence")
+def _ReverseSequenceShape(op):
+ """Shape function for the ReverseSequence op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor with size B in the 0th dimension.
+ * seq_lens: A vector of length B.
+
+ It has one output, with the same size as input.
+
+ Args:
+ op: A ReverseSequence Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the input shapes are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape()
+ seq_lens_shape = op.inputs[1].get_shape().with_rank(1)
+ batch_size = input_shape[0].merge_with(seq_lens_shape[0])
+ input_shape = tensor_shape.TensorShape([batch_size]).concatenate(
+ input_shape[1:])
+ seq_dim = op.get_attr("seq_dim")
+ if seq_dim >= input_shape.ndims:
+ raise ValueError("seq_dim must be < input.dims() (%d vs %d)" %
+ (seq_dim, input_shape.ndims))
+ return [input_shape]
+
+
+@ops.RegisterShape("Shape")
+def _ShapeShape(op):
+ """Shape function for the Shape op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.vector(input_shape.ndims)]
+
+
+@ops.RegisterShape("Transpose")
+def _TransposeShape(op):
+ """Shape function for the Transpose op.
+
+ This op takes two inputs:
+
+ * input: a rank-N tensor of arbitrary shape.
+ * shuffle: a length-N vector.
+
+ Its output is the rank-N tensor computed by permuting the dimensions
+ of input according to shuffle.
+
+ Args:
+ op: A Transpose op.
+
+ Returns:
+ A single-element list containing the shape of the output.
+
+ Raises:
+ ValueError: If the shapes of input and shuffle are incompatible.
+ IndexError: If shuffle contains an index that is >= the rank of input.
+ """
+ input_shape = op.inputs[0].get_shape()
+ transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
+ input_shape.ndims))
+ transpose_vec = tensor_util.ConstantValue(op.inputs[1])
+ if transpose_vec is None:
+ return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
+ else:
+ return [tensor_shape.TensorShape([input_shape[i]
+ for i in transpose_vec.tolist()])]
+
+
+@ops.RegisterShape("Split")
+def _SplitShape(op):
+ """Shape function for the Split op."""
+ split_dim = tensor_util.ConstantValue(op.inputs[0])
+ num_split = len(op.outputs)
+ input_shape = op.inputs[1].get_shape()
+ if split_dim is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] * num_split
+ else:
+ split_dim = int(split_dim)
+ input_shape = input_shape.with_rank_at_least(split_dim + 1)
+ if not (input_shape[split_dim] % num_split).is_compatible_with(0):
+ raise ValueError(
+ "Number of ways to split should evenly divide the split "
+ "dimension but got split_dim %d (size = %d) and num_split %d" %
+ (split_dim, input_shape[split_dim].value, num_split))
+ prefix = input_shape[:split_dim]
+ size_in_split_dim = input_shape[split_dim] / num_split
+ suffix = input_shape[split_dim + 1:]
+ output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
+ return [output_shape] * num_split
+
+
+@ops.RegisterShape("Tile")
+def _TileShape(op):
+ """Shape function for the Tile op.
+
+ This op has two inputs:
+
+ * input: A rank-N tensor.
+ * multiples: A length-N vector, in which the i^th element contains
+ the factor by which `input` will be tiled in the i^th dimension.
+
+ It has one output, which has the same rank as input, and additional
+ elements according to the values in multiples.
+
+ Args:
+ op: A Tile Operation.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ multiples = multiples.ravel()
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim * multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
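A tiny sketch of the per-dimension multiplication above (illustrative values):

```python
input_shape = [2, 3]
multiples = [3, 2]
output_shape = [d * m for d, m in zip(input_shape, multiples)]
assert output_shape == [6, 6]
```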
+
+@ops.RegisterShape("TileGrad")
+def _TileGradShape(op):
+ """Shape function for the TileGrad op."""
+ multiples_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ input_shape = op.inputs[0].get_shape().with_rank(multiples_shape.num_elements())
+ multiples = tensor_util.ConstantValue(op.inputs[1])
+ if multiples is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ output_dims = []
+ for i, dim in enumerate(input_shape.dims):
+ output_dims.append(dim / multiples[i])
+ return [tensor_shape.TensorShape(output_dims)]
+
+
+@ops.RegisterShape("Where")
+def _WhereShape(op):
+ """Shape function for the Where op."""
+ input_shape = op.inputs[0].get_shape()
+ return [tensor_shape.matrix(None, input_shape.ndims)]
+
+
+@ops.RegisterShape("ZerosLike")
+def _ZerosLikeShape(op):
+ """Shape function for the ZerosLike op."""
+ return [op.inputs[0].get_shape()]
+
+
+def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
+ """Computes the Levenshtein distance between sequences.
+
+ This operation takes variable-length sequences (`hypothesis` and `truth`),
+ each provided as a `SparseTensor`, and computes the Levenshtein distance.
+ You can normalize the edit distance by length of `truth` by setting
+ `normalize` to true.
+
+ For example, given the following input:
+
+ ```python
+ # 'hypothesis' is a tensor of shape `[2, 1]` with variable-length values:
+ # (0,0) = ["a"]
+ # (1,0) = ["b"]
+ hypothesis = tf.SparseTensor(
+ [[0, 0, 0],
+ [1, 0, 0]],
+ ["a", "b"]
+ (2, 1, 1))
+
+ # 'truth' is a tensor of shape `[2, 2]` with variable-length values:
+ # (0,0) = []
+ # (0,1) = ["a"]
+ # (1,0) = ["b", "c"]
+ # (1,1) = ["a"]
+ truth = tf.SparseTensor(
+ [[0, 1, 0],
+ [1, 0, 0],
+ [1, 0, 1],
+ [1, 1, 0]],
+ ["a", "b", "c", "a"],
+ (2, 2, 2))
+
+ normalize = True
+ ```
+
+ This operation would return the following:
+
+ ```python
+ # 'output' is a tensor of shape `[2, 2]` with edit distances normalized
+ # by 'truth' lengths.
+ output ==> [[inf, 1.0], # (0,0): no truth, (0,1): no hypothesis
+ [0.5, 1.0]] # (1,0): addition, (1,1): no hypothesis
+ ```
+
+ Args:
+ hypothesis: A `SparseTensor` containing hypothesis sequences.
+ truth: A `SparseTensor` containing truth sequences.
+ normalize: A `bool`. If `True`, normalizes the Levenshtein distance by
+ length of `truth`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A dense `Tensor` with rank `R - 1`, where R is the rank of the
+ `SparseTensor` inputs `hypothesis` and `truth`.
+
+ Raises:
+ TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
+ """
+ if not isinstance(hypothesis, ops.SparseTensor):
+ raise TypeError("Hypothesis must be a SparseTensor")
+ if not isinstance(truth, ops.SparseTensor):
+ raise TypeError("Truth must be a SparseTensor")
+
+ return gen_array_ops._edit_distance(hypothesis.indices,
+ hypothesis.values,
+ hypothesis.shape,
+ truth.indices,
+ truth.values,
+ truth.shape,
+ normalize=normalize,
+ name=name)
+
+
+@ops.RegisterShape("EditDistance")
+def _EditDistanceShape(op):
+ """Shape function for the EditDistance op."""
+ hypothesis_shape = tensor_util.ConstantValue(op.inputs[2])
+ truth_shape = tensor_util.ConstantValue(op.inputs[5])
+ if hypothesis_shape is not None and truth_shape is not None:
+ if len(hypothesis_shape) != len(truth_shape):
+ raise ValueError(
+ "Inconsistent ranks in hypothesis and truth. Saw shapes: %s and %s" %
+ (str(hypothesis_shape), str(truth_shape)))
+ return [tensor_shape.TensorShape(
+ [max(h, t) for h, t in zip(hypothesis_shape[:-1], truth_shape[:-1])])]
+
+ return [tensor_shape.unknown_shape()]
+
+
+# The remaining ops do not change the shape of their inputs.
+@ops.RegisterShape("Quantize")
+@ops.RegisterShape("Dequantize")
+def _QuantizeDequantizeShape(op):
+ unused_min_range = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ unused_max_range = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
+ return common_shapes.unchanged_shape(op)
diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py
new file mode 100644
index 0000000000..4829bcd7cd
--- /dev/null
+++ b/tensorflow/python/ops/attention_ops.py
@@ -0,0 +1,34 @@
+"""Operations for implementing attention.
+"""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_attention_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_attention_ops import *
+
+
+# TODO(bsteiner): Implement the gradient function for extract_glimpse
+ops.NoGradient("ExtractGlimpse")
+
+
+@ops.RegisterShape("ExtractGlimpse")
+def _ExtractGlimpseShape(op):
+ """Shape function for ExtractGlimpse op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(2))
+ offsets_shape = op.inputs[2].get_shape().merge_with(
+ input_shape[:1].concatenate([2]))
+ offsets_shape = offsets_shape
+ size_value = tensor_util.ConstantValue(op.inputs[1])
+ if size_value is not None:
+ height = size_value[0]
+ width = size_value[1]
+ else:
+ height = None
+ width = None
+ return [tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])]
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
new file mode 100644
index 0000000000..06857c0adc
--- /dev/null
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -0,0 +1,365 @@
+"""Wrappers for primitive Neural Net (NN) Operations."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_candidate_sampling_ops
+from tensorflow.python.ops import math_ops
+
+
+def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, seed=None, name=None):
+ """Samples a set of classes using a uniform base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max]`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is the uniform distribution
+ over the range of integers `[0, range_max]`.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._uniform_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
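In the with-replacement case the expected counts have a simple closed form; a small NumPy sketch of the base distribution (illustrative numbers, not the op's implementation):

```python
import numpy as np

range_max, num_sampled = 10000, 64
# Each draw is uniform over the range_max possible classes, so over
# num_sampled draws every class is expected num_sampled / range_max times.
sampled_candidates = np.random.randint(0, range_max, size=num_sampled)
expected_count = num_sampled / float(range_max)   # 0.0064 for every class
```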
+
+def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, seed=None, name=None):
+ """Samples a set of classes using a log-uniform (Zipfian) base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max]`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is an approximately log-uniform
+ or Zipfian distribution:
+
+ `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
+
+ This sampler is useful when the target classes approximately follow such
+ a distribution - for example, if the classes represent words in a lexicon
+ sorted in decreasing order of frequency. If your classes are not ordered by
+ decreasing frequency, do not use this op.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._log_uniform_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
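A quick pure-Python check of the formula above: the probabilities telescope to 1 over the `range_max` possible classes, and smaller class IDs are more likely (small `range_max` for illustration):

```python
import math

range_max = 5
p = [(math.log(c + 2) - math.log(c + 1)) / math.log(range_max + 1)
     for c in range(range_max)]
assert abs(sum(p) - 1.0) < 1e-12
assert all(p[i] > p[i + 1] for i in range(len(p) - 1))
```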
+
+def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
+ unique, range_max, seed=None, name=None):
+ """Samples a set of classes from a distribution learned during training.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max]`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution for this operation is constructed on the fly
+ during training. It is a unigram distribution over the target
+ classes seen so far during training. Every integer in `[0, range_max]`
+ begins with a weight of 1, and is incremented by 1 each time it is
+ seen as a target class. The base distribution is not saved to checkpoints,
+ so it is reset when the model is reloaded.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._learned_unigram_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max, seed=seed1,
+ seed2=seed2, name=name)
+
+
+def fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ range_max, vocab_file='', distortion=0.0,
+ num_reserved_ids=0, num_shards=1, shard=0,
+ unigrams=[], seed=None, name=None):
+ """Samples a set of classes using the provided (fixed) base distribution.
+
+ This operation randomly samples a tensor of sampled classes
+ (`sampled_candidates`) from the range of integers `[0, range_max]`.
+
+ The elements of `sampled_candidates` are drawn without replacement
+ (if `unique=True`) or with replacement (if `unique=False`) from
+ the base distribution.
+
+ The base distribution is read from a file or passed in as an
+ in-memory array. There is also an option to skew the distribution by
+ applying a distortion power to the weights.
+
+ In addition, this operation returns tensors `true_expected_count`
+ and `sampled_expected_count` representing the number of times each
+ of the target classes (`true_classes`) and the sampled
+ classes (`sampled_candidates`) is expected to occur in an average
+ tensor of sampled classes. These values correspond to `Q(y|x)`
+ defined in [this
+ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ If `unique=True`, then these are post-rejection probabilities and we
+ compute them approximately.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ unique: A `bool`. Determines whether all sampled classes in a batch are
+ unique.
+ range_max: An `int`. The number of possible classes.
+ vocab_file: Each valid line in this file (which should have a CSV-like
+ format) corresponds to a valid word ID. IDs are in sequential order,
+ starting from num_reserved_ids. The last entry in each line is expected
+ to be a value corresponding to the count or relative probability. Exactly
+ one of `vocab_file` and `unigrams` needs to be passed to this operation.
+ distortion: The distortion is used to skew the unigram probability
+ distribution. Each weight is first raised to the distortion's power
+ before adding to the internal unigram distribution. As a result,
+ `distortion = 1.0` gives regular unigram sampling (as defined by the vocab
+ file), and `distortion = 0.0` gives a uniform distribution.
+ num_reserved_ids: Optionally some reserved IDs can be added in the range
+      `[0, num_reserved_ids)` by the user. One use case is that a special
+ unknown word token is used as ID 0. These IDs will have a sampling
+ probability of 0.
+ num_shards: A sampler can be used to sample from a subset of the original
+ range in order to speed up the whole computation through parallelism. This
+ parameter (together with `shard`) indicates the number of partitions that
+ are being used in the overall computation.
+ shard: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This
+ parameter (together with `num_shards`) indicates the particular partition
+ number of the operation, when partitioning is being used.
+ unigrams: A list of unigram counts or probabilities, one per ID in
+ sequential order. Exactly one of `vocab_file` and `unigrams` should be
+ passed to this operation.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled classes.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`.
+
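+  For example, a minimal sketch using an in-memory `unigrams` list (the
+  counts below are invented for illustration):
+
+  ```python
+  true_classes = tf.constant([[0]], dtype=tf.int64)
+  # Five classes with counts 10, 5, 3, 1, 1; distortion=0.5 flattens the
+  # distribution by taking the square root of each count.
+  sampled, true_exp, sampled_exp = fixed_unigram_candidate_sampler(
+      true_classes, num_true=1, num_sampled=3, unique=True, range_max=5,
+      distortion=0.5, unigrams=[10, 5, 3, 1, 1])
+  ```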
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._fixed_unigram_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, range_max,
+ vocab_file=vocab_file, distortion=distortion,
+ num_reserved_ids=num_reserved_ids, num_shards=num_shards, shard=shard,
+ unigrams=unigrams, seed=seed1, seed2=seed2, name=name)
+
+
+def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
+ seed=None, name=None):
+ """Generate the set of all classes.
+
+ Deterministically generates and returns the set of all possible classes.
+  This is intended for testing purposes. There is no need to use it in
+  practice, since you might as well use full softmax or full logistic
+  regression.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_true: An `int`. The number of target classes per training example.
+ num_sampled: An `int`. The number of possible classes.
+    unique: A `bool`. Ignored.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ This operation deterministically returns the entire range
+      `[0, num_sampled)`.
+ true_expected_count: A tensor of type `float`. Same shape as
+ `true_classes`. The expected counts under the sampling distribution
+ of each of `true_classes`. All returned values are 1.0.
+ sampled_expected_count: A tensor of type `float`. Same shape as
+ `sampled_candidates`. The expected counts under the sampling distribution
+ of each of `sampled_candidates`. All returned values are 1.0.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._all_candidate_sampler(
+ true_classes, num_true, num_sampled, unique, seed=seed1, seed2=seed2,
+ name=name)
+
+
+def compute_accidental_hits(true_classes, sampled_candidates, num_true,
+ seed=None, name=None):
+ """Compute the ids of positions in sampled_candidates matching true_classes.
+
+ In Candidate Sampling, this operation facilitates virtually removing
+ sampled classes which happen to match target classes. This is done
+ in Sampled Softmax and Sampled Logistic.
+
+ See our [Candidate Sampling Algorithms
+ Reference](http://www.tensorflow.org/extras/candidate_sampling.pdf).
+
+ We presuppose that the `sampled_candidates` are unique.
+
+ We call it an 'accidental hit' when one of the target classes
+ matches one of the sampled classes. This operation reports
+ accidental hits as triples `(index, id, weight)`, where `index`
+ represents the row number in `true_classes`, `id` represents the
+  position in `sampled_candidates`, and `weight` is `-FLOAT_MAX`.
+
+ The result of this op should be passed through a `sparse_to_dense`
+ operation, then added to the logits of the sampled classes. This
+ removes the contradictory effect of accidentally sampling the true
+ target classes as noise classes for the same example.
+
+ Args:
+ true_classes: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
+ The sampled_candidates output of CandidateSampler.
+ num_true: An `int`. The number of target classes per training example.
+ seed: An `int`. An operation-specific seed. Default is 0.
+ name: A name for the operation (optional).
+
+ Returns:
+ indices: A `Tensor` of type `int32` and shape `[num_accidental_hits]`.
+ Values indicate rows in `true_classes`.
+ ids: A `Tensor` of type `int64` and shape `[num_accidental_hits]`.
+ Values indicate positions in `sampled_candidates`.
+ weights: A `Tensor` of type `float` and shape `[num_accidental_hits]`.
+ Each value is `-FLOAT_MAX`.
+
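+  For example, a rough sketch of the intended use (`true_classes` and
+  `sampled_candidates` are assumed to come from a candidate sampler above):
+
+  ```python
+  indices, ids, weights = compute_accidental_hits(
+      true_classes, sampled_candidates, num_true=1)
+  # Scatter `weights` (-FLOAT_MAX) into a [batch_size, num_sampled] mask via
+  # sparse_to_dense, then add the mask to the sampled logits.
+  ```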
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_candidate_sampling_ops._compute_accidental_hits(
+ true_classes, sampled_candidates, num_true, seed=seed1, seed2=seed2,
+ name=name)
+
+
+@ops.RegisterShape("AllCandidateSampler")
+@ops.RegisterShape("FixedUnigramCandidateSampler")
+@ops.RegisterShape("LearnedUnigramCandidateSampler")
+@ops.RegisterShape("LogUniformCandidateSampler")
+@ops.RegisterShape("ThreadUnsafeUnigramCandidateSampler")
+@ops.RegisterShape("UniformCandidateSampler")
+def _CandidateSamplerShape(op):
+ true_classes_shape = op.inputs[0].get_shape().with_rank(2)
+ batch_size = true_classes_shape[0]
+ num_sampled = op.get_attr("num_sampled")
+ num_true = op.get_attr("num_true")
+ return [tensor_shape.vector(num_sampled),
+ tensor_shape.matrix(batch_size, num_true),
+ tensor_shape.vector(num_sampled)]
+
+
+@ops.RegisterShape("ComputeAccidentalHits")
+def _ComputeAccidentalHitsShape(op):
+ num_true = op.get_attr("num_true")
+ # Validate that the input shape matches the attrs, even though it
+ # does not influence the shape of the output.
+ true_candidates_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.matrix(None, num_true))
+ output_shape = tensor_shape.vector(None)
+ return [output_shape] * 3
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
new file mode 100644
index 0000000000..08781932f9
--- /dev/null
+++ b/tensorflow/python/ops/clip_ops.py
@@ -0,0 +1,234 @@
+"""Operations for clipping (gradient, weight) tensors to min/max values."""
+
+import collections
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+
+
+def clip_by_value(t, clip_value_min, clip_value_max,
+ name=None):
+ """Clips tensor values to a specified min and max.
+
+ Given a tensor `t`, this operation returns a tensor of the same type and
+ shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
+ Any values less than `clip_value_min` are set to `clip_value_min`. Any values
+ greater than `clip_value_max` are set to `clip_value_max`.
+
+ Args:
+ t: A `Tensor`.
+ clip_value_min: A 0-D (scalar) `Tensor`. The minimum value to clip by.
+ clip_value_max: A 0-D (scalar) `Tensor`. The maximum value to clip by.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
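+
+  For example, a small illustrative sketch (assuming the op is exposed as
+  `tf.clip_by_value`):
+
+  ```python
+  t = tf.constant([-2.0, 0.5, 3.0])
+  clipped = tf.clip_by_value(t, clip_value_min=0.0, clip_value_max=1.0)
+  # clipped evaluates to [0.0, 0.5, 1.0]
+  ```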
+ """
+ with ops.op_scope([t, clip_value_min, clip_value_max], name,
+ "clip_by_value") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Go through list of tensors, for each value in each tensor clip
+ t_min = math_ops.minimum(
+ t, array_ops.fill(array_ops.shape(t), clip_value_max))
+ t_max = math_ops.maximum(
+ t_min, array_ops.fill(array_ops.shape(t), clip_value_min),
+ name=name)
+
+ return t_max
+
+
+def clip_by_norm(t, clip_norm, name=None):
+ """Clips tensor values to a maximum L2-norm.
+
+ Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
+  normalizes `t` so that its L2-norm is less than or equal to `clip_norm`.
+ Specifically, if the L2-norm is already less than or equal to `clip_norm`,
+ then `t` is not modified. If the L2-norm is greater than `clip_norm`, then
+ this operation returns a tensor of the same type and shape as `t` with its
+ values set to:
+
+ `t * clip_norm / l2norm(t)`
+
+ In this case, the L2-norm of the output tensor is `clip_norm`.
+
+ This operation is typically used to clip gradients before applying them with
+ an optimizer.
+
+ Args:
+ t: A `Tensor`.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
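+
+  For example, a small illustrative sketch (assuming the op is exposed as
+  `tf.clip_by_norm`):
+
+  ```python
+  t = tf.constant([3.0, 4.0])          # L2-norm is 5.0
+  clipped = tf.clip_by_norm(t, clip_norm=1.0)
+  # clipped evaluates to [0.6, 0.8], whose L2-norm is 1.0
+  ```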
+ """
+ with ops.op_scope([t, clip_norm], name, "clip_by_norm") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+ l2norm_inv = math_ops.rsqrt(
+ math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
+ l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name)
+
+ return tclip
+
+def global_norm(t_list, name=None):
+ """Computes the global norm of multiple tensors.
+
+ Given a tuple or list of tensors `t_list`, this operation returns the
+ global norm of the elements in all tensors in `t_list`. The global norm is
+ computed as:
+
+ `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
+
+  Any entries in `t_list` that are `None` are ignored.
+
+ Args:
+ t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 0-D (scalar) `Tensor` of type `float`.
+
+ Raises:
+ TypeError: If `t_list` is not a sequence.
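+
+  For example, a small illustrative sketch:
+
+  ```python
+  t_list = [tf.constant([3.0, 4.0]), tf.constant([12.0])]
+  norm = global_norm(t_list)
+  # norm evaluates to sqrt(5.0**2 + 12.0**2) = 13.0
+  ```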
+ """
+ if (not isinstance(t_list, collections.Sequence)
+ or isinstance(t_list, basestring)):
+ raise TypeError("t_list should be a sequence")
+ t_list = list(t_list)
+ with ops.op_scope(t_list, name, "global_norm") as name:
+ values = [
+ ops.convert_to_tensor(
+ t.values if isinstance(t, ops.IndexedSlices) else t,
+ name="t_%d" % i)
+ if t is not None else t
+ for i, t in enumerate(t_list)]
+ squared_norms = array_ops.pack(
+        [math_ops.reduce_sum(v * v) for v in values if v is not None])
+
+ norm = math_ops.sqrt(
+ math_ops.reduce_sum(squared_norms), name="global_norm")
+
+ return norm
+
+def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
+ """Clips values of multiple tensors by the ratio of the sum of their norms.
+
+ Given a tuple or list of tensors `t_list`, and a clipping ratio `clip_norm`,
+ this operation returns a list of clipped tensors `list_clipped`
+ and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
+ if you've already computed the global norm for `t_list`, you can specify
+ the global norm with `use_norm`.
+
+  To perform the clipping, the values `t_list[i]` are set to:
+
+ `t_list[i] * clip_norm / max(global_norm, clip_norm)`
+
+ where:
+
+ `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
+
+ If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
+ otherwise they're all shrunk by the global ratio.
+
+  Any entries of `t_list` that are `None` are ignored.
+
+ This is the correct way to perform gradient clipping (for example, see
+ R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training
+ Recurrent Neural Networks". http://arxiv.org/abs/1211.5063)
+
+ However, it is slower than `clip_by_norm()` because all the parameters must be
+ ready before the clipping operation can be performed.
+
+ Args:
+ t_list: A tuple or list of mixed `Tensors`, `IndexedSlices`, or None.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. The clipping ratio.
+ use_norm: A 0-D (scalar) `Tensor` of type `float` (optional). The global
+ norm to use. If not provided, `global_norm()` is used to compute the norm.
+ name: A name for the operation (optional).
+
+ Returns:
+    list_clipped: A list of `Tensors` of the same type as `t_list`.
+ global_norm: A 0-D (scalar) `Tensor` representing the global norm.
+
+ Raises:
+ TypeError: If `t_list` is not a sequence.
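+
+  For example, a typical gradient-clipping sketch (`loss` and `params` are
+  placeholders assumed to be defined elsewhere, and the ops are assumed to be
+  exposed as `tf.gradients` and `tf.clip_by_global_norm`):
+
+  ```python
+  grads = tf.gradients(loss, params)   # `loss`, `params`: assumed given
+  clipped, norm = tf.clip_by_global_norm(grads, clip_norm=5.0)
+  # `clipped` can then be handed to the optimizer in place of `grads`.
+  ```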
+ """
+ if (not isinstance(t_list, collections.Sequence)
+ or isinstance(t_list, basestring)):
+ raise TypeError("t_list should be a sequence")
+ t_list = list(t_list)
+ if use_norm is None:
+ use_norm = global_norm(t_list, name)
+
+ with ops.op_scope(t_list + [clip_norm], name, "clip_by_global_norm") as name:
+ # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+ scale = clip_norm * math_ops.minimum(
+ 1.0 / use_norm, constant_op.constant(1.0 / clip_norm))
+
+ values = [
+ ops.convert_to_tensor(
+ t.values if isinstance(t, ops.IndexedSlices) else t,
+ name="t_%d" % i)
+ if t is not None else t
+ for i, t in enumerate(t_list)]
+
+ values_clipped = [
+ array_ops.identity(v * scale, name="%s_%d" % (name, i))
+ if v is not None else None
+ for i, v in enumerate(values)]
+
+ list_clipped = [
+ ops.IndexedSlices(c_v, t.indices)
+ if isinstance(t, ops.IndexedSlices)
+ else c_v
+ for (c_v, t) in zip(values_clipped, t_list)]
+
+ return list_clipped, use_norm
+
+
+def clip_by_average_norm(t, clip_norm, name=None):
+ """Clips tensor values to a maximum average L2-norm.
+
+ Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
+ normalizes `t` so that its average L2-norm is less than or equal to
+  `clip_norm`. Specifically, if the average L2-norm is already less than or
+ equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
+ greater than `clip_norm`, then this operation returns a tensor of the same
+ type and shape as `t` with its values set to:
+
+ `t * clip_norm / l2norm_avg(t)`
+
+ In this case, the average L2-norm of the output tensor is `clip_norm`.
+
+ This operation is typically used to clip gradients before applying them with
+ an optimizer.
+
+ Args:
+ t: A `Tensor`.
+ clip_norm: A 0-D (scalar) `Tensor` > 0. A maximum clipping value.
+ name: A name for the operation (optional).
+
+ Returns:
+ A clipped `Tensor`.
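+
+  For example, a small illustrative sketch (assuming the op is exposed as
+  `tf.clip_by_average_norm`):
+
+  ```python
+  t = tf.constant([3.0, 4.0])     # L2-norm 5.0, average L2-norm 2.5
+  clipped = tf.clip_by_average_norm(t, clip_norm=1.0)
+  # clipped evaluates to [1.2, 1.6], whose average L2-norm is 1.0
+  ```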
+ """
+ with ops.op_scope([t, clip_norm], name, "clip_by_average_norm") as name:
+ t = ops.convert_to_tensor(t, name="t")
+
+ # Calculate L2-norm per element, clip elements by ratio of clip_norm to
+ # L2-norm per element
+ n_element = math_ops.cast(array_ops.size(t), types.float32)
+ l2norm_inv = math_ops.rsqrt(
+ math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+ tclip = array_ops.identity(
+ t * clip_norm * math_ops.minimum(
+ l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)),
+ name=name)
+
+ return tclip
diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py
new file mode 100644
index 0000000000..c41d1ff71d
--- /dev/null
+++ b/tensorflow/python/ops/common_shapes.py
@@ -0,0 +1,371 @@
+"""A library of common shape functions."""
+import math
+
+from tensorflow.python.framework import tensor_shape
+
+
+def scalar_shape(unused_op):
+ """Shape function for ops that output a scalar value."""
+ return [tensor_shape.scalar()]
+
+
+def unchanged_shape(op):
+ """Shape function for ops that output an tensor like their first input."""
+ return [op.inputs[0].get_shape()]
+
+
+def unchanged_shape_with_rank(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: The exact rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with a particular rank.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank(rank)]
+ return _ShapeFunction
+
+
+def unchanged_shape_with_rank_at_least(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: A lower bound on the rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with a particular rank.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank_at_least(rank)]
+ return _ShapeFunction
+
+
+def unchanged_shape_with_rank_at_most(rank):
+ """Returns a shape function for ops that constrain the rank of their input.
+
+ Args:
+ rank: An upper bound on the rank of the input and output.
+
+ Returns:
+ A shape function for ops that output a tensor of the same size as their
+ input, with a particular rank.
+ """
+ def _ShapeFunction(op):
+ return [op.inputs[0].get_shape().with_rank_at_most(rank)]
+ return _ShapeFunction
+
+
+def matmul_shape(op):
+ """Shape function for a MatMul op."""
+ a_shape = op.inputs[0].get_shape().with_rank(2)
+ transpose_a = op.get_attr("transpose_a")
+ b_shape = op.inputs[1].get_shape().with_rank(2)
+ transpose_b = op.get_attr("transpose_b")
+ output_rows = a_shape[1] if transpose_a else a_shape[0]
+ output_cols = b_shape[0] if transpose_b else b_shape[1]
+ inner_a = a_shape[0] if transpose_a else a_shape[1]
+ inner_b = b_shape[1] if transpose_b else b_shape[0]
+ inner_a.assert_is_compatible_with(inner_b)
+ return [tensor_shape.TensorShape([output_rows, output_cols])]
+
+
+def bias_add_shape(op):
+ """Shape function for a BiasAdd op."""
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+ bias_shape = op.inputs[1].get_shape().with_rank(1)
+ if input_shape.ndims is not None:
+ # Output has the same shape as input, and matches the length of
+ # bias in its last dimension.
+ output_shape = input_shape[0:-1].concatenate(
+ input_shape[-1].merge_with(bias_shape[0]))
+ else:
+ output_shape = tensor_shape.unknown_shape()
+ return [output_shape]
+
+
+def _Get2DOutputSize(input_height, input_width, filter_height, filter_width,
+ row_stride, col_stride, padding_type):
+ """Returns the number of rows and columns in a convolution/pooling output."""
+ input_height = tensor_shape.as_dimension(input_height)
+ input_width = tensor_shape.as_dimension(input_width)
+ filter_height = tensor_shape.as_dimension(filter_height)
+ filter_width = tensor_shape.as_dimension(filter_width)
+ row_stride = int(row_stride)
+ col_stride = int(col_stride)
+
+ if filter_height.value == 1 and filter_width.value == 1 and (
+ row_stride == 1 and col_stride == 1):
+ return input_height, input_width
+ else:
+ if filter_height > input_height or filter_width > input_width:
+ raise ValueError("filter must not be larger than the input: ",
+ "Filter: [", filter_height, "x", filter_width, "] ",
+ "Input: [", input_height, "x", input_width, "] ")
+ if row_stride > filter_height or col_stride > filter_width:
+ raise ValueError("stride must be less than or equal to filter size",
+ "stride: [", row_stride, "x", col_stride, "] ",
+ "filter: [", filter_height, "x", filter_width, "] ")
+
+ # Compute number of rows in the output, based on the padding.
+ if input_height.value is None or filter_height.value is None:
+ out_rows = None
+ elif padding_type == "VALID":
+ out_rows = int(
+ math.ceil((input_height.value - filter_height.value + 1.0)
+ / row_stride))
+ elif padding_type == "SAME":
+ out_rows = int(math.ceil(input_height.value * 1.0
+ / row_stride))
+ else:
+ raise ValueError("Invalid value for padding: %r" % padding_type)
+
+ # Compute number of columns in the output, based on the padding.
+ if input_width.value is None or filter_width.value is None:
+ out_cols = None
+ elif padding_type == "VALID":
+ out_cols = int(
+ math.ceil((input_width.value - filter_width.value + 1.0)
+ / col_stride))
+ elif padding_type == "SAME":
+ out_cols = int(math.ceil(input_width.value * 1.0 / col_stride))
+
+ return out_rows, out_cols
+
+
+def conv2d_shape(op):
+ """Shape function for a Conv2D op.
+
+ This op has two inputs:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+ * filter, a 4D tensor with shape = [filter_rows, filter_cols,
+ depth_in, depth_out]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "padding" and "strides" attrs.
+
+ Args:
+ op: A Conv2D Operation.
+
+ Returns:
+ A list containing the Shape of the Conv2D output.
+
+ Raises:
+ ValueError: If the shapes of the input or filter are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ filter_shape = op.inputs[1].get_shape().with_rank(4)
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+
+ filter_rows = filter_shape[0]
+ filter_cols = filter_shape[1]
+ depth_out = filter_shape[3]
+ # Check that the input depths are compatible.
+ input_shape[3].assert_is_compatible_with(filter_shape[2])
+
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not yet support "
+ "strides in the batch and depth dimensions.")
+ if stride_r != stride_c:
+ # TODO(shlens): Add support for this.
+ raise ValueError("Current implementation only supports equal length "
+ "strides in the row and column dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ stride = stride_r
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
+
+
+def separable_conv2d_shape(op):
+ """Shape function for a SeparableConv2D op.
+
+ This op has three inputs:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+
+ * depthwise_filter, a 4D tensor with shape = [filter_rows,
+ filter_cols, depth_in, depth_multiplier]
+
+ * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in *
+ depth_multiplier, depth_out]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "padding" and "strides" attrs.
+
+ Args:
+ op: A SeparableConv2D Operation.
+
+ Returns:
+ A list containing the Shape of the SeparableConv2D output.
+
+ Raises:
+ ValueError: If the shapes of the input or filter are incompatible.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ depthwise_filter_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.TensorShape([None, None, input_shape[3], None]))
+ pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3]
+
+ pointwise_filter_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.TensorShape([1, 1, pointwise_depth_in, None]))
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+
+ filter_rows = depthwise_filter_shape[0]
+ filter_cols = depthwise_filter_shape[1]
+ depth_out = pointwise_filter_shape[3]
+
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not yet support "
+ "strides in the batch and depth dimensions.")
+ if stride_r != stride_c:
+ # TODO(shlens): Add support for this.
+ raise ValueError("Current implementation only supports equal length "
+ "strides in the row and column dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ stride = stride_r
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])]
+
+
+def avg_pool_shape(op):
+ """Shape function for an AvgPool op.
+
+ This op has one input:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows and out_cols depend on the
+ value of the op's "ksize", "strides", and "padding" attrs.
+
+ Args:
+ op: An AvgPool Operation.
+
+ Returns:
+ A single-element list containing the Shape of the AvgPool output.
+
+ Raises:
+ ValueError: If the shape of the input is invalid or incompatible with
+ the values of the attrs.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+ depth = input_shape[3]
+
+ if ksize_b != 1 or ksize_d != 1:
+ raise ValueError("Current implementation does not support pooling "
+ "in the batch and depth dimensions.")
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not support strides "
+ "in the batch and depth dimensions.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ padding = op.get_attr("padding")
+
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
+
+
+def max_pool_shape(op):
+ """Shape function for a MaxPool op.
+
+ This op has one input:
+
+ * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
+
+ The output is a 4D tensor with shape = [batch_size, out_rows,
+ out_cols, depth_out], where out_rows, out_cols, and depth_out depend
+ on the value of the op's "ksize", "strides", and "padding" attrs.
+
+ Args:
+ op: A MaxPool Operation.
+
+ Returns:
+ A single-element list containing the Shape of the MaxPool output.
+
+ Raises:
+ ValueError: If the shape of the input is invalid or incompatible with
+ the values of the attrs.
+ """
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+ batch_size = input_shape[0]
+ in_rows = input_shape[1]
+ in_cols = input_shape[2]
+ depth = input_shape[3]
+
+ if ksize_b != 1:
+ raise ValueError("Current implementation does not support pooling "
+ "in the batch dimension.")
+ if stride_b != 1:
+ raise ValueError("Current implementation does not support strides "
+ "in the batch dimension.")
+
+ if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1):
+ raise ValueError("MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height.")
+
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ if ksize_d == 1:
+ padding = op.get_attr("padding")
+ out_rows, out_cols = _Get2DOutputSize(
+ in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding)
+ return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth])]
+ else:
+ if depth % ksize_d > 0:
+ raise ValueError("Depthwise max pooling requires the depth window "
+ "to evenly divide the input depth.")
+ if stride_d != ksize_d:
+ raise ValueError("Depthwise max pooling requires the depth window "
+ "to equal the depth stride.")
+ return [tensor_shape.TensorShape(
+ [batch_size, in_rows, in_cols, depth / ksize_d])]
+
+
+def no_outputs(unused_op):
+ """Shape function for use with ops that have no outputs."""
+ return []
+
+
+def unknown_shape(op):
+ """Shape function for use with ops whose output shapes are unknown."""
+ return [tensor_shape.unknown_shape() for _ in op.outputs]
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
new file mode 100644
index 0000000000..7d9044b689
--- /dev/null
+++ b/tensorflow/python/ops/constant_op.py
@@ -0,0 +1,189 @@
+"""## Constant Value Tensors
+
+TensorFlow provides several operations that you can use to generate constants.
+
+@@zeros
+@@zeros_like
+
+@@ones
+@@ones_like
+
+@@fill
+
+@@constant
+
+## Sequences
+
+@@linspace
+
+@@range
+
+## Random Tensors
+
+TensorFlow has several ops that create random tensors with different
+distributions. The random ops are stateful, and create new random values each
+time they are evaluated.
+
+The `seed` keyword argument in these functions acts in conjunction with
+the graph-level random seed. Changing either the graph-level seed using
+[`set_random_seed`](constant_op.md#set_random_seed) or the op-level seed
+will change the underlying seed of these operations. Setting neither a
+graph-level nor an op-level seed results in a random seed for all operations.
+See [`set_random_seed`](constant_op.md#set_random_seed) for details on the
+interaction between operation-level and graph-level random seeds.
+
+### Examples:
+
+```python
+# Create a tensor of shape [2, 3] consisting of random normal values, with mean
+# -1 and standard deviation 4.
+norm = tf.random_normal([2, 3], mean=-1, stddev=4)
+
+# Shuffle the first dimension of a tensor
+c = tf.constant([[1, 2], [3, 4], [5, 6]])
+shuff = tf.random_shuffle(c)
+
+# Each time we run these ops, different results are generated
+sess = tf.Session()
+print sess.run(norm)
+print sess.run(norm)
+
+# Set an op-level seed to generate repeatable sequences across sessions.
+sess = tf.Session()
+norm = tf.random_normal([2, 3], seed=1234)
+print sess.run(norm)
+print sess.run(norm)
+```
+
+Another common use of random values is the initialization of variables. Also see
+the [Variables How To](../../how_tos/variables/index.md).
+
+```python
+# Use random uniform values in [0, 1) as the initializer for a variable of shape
+# [2, 3]. The default type is float32.
+var = tf.Variable(tf.random_uniform([2, 3]), name="var")
+init = tf.initialize_all_variables()
+
+sess = tf.Session()
+sess.run(init)
+print sess.run(var)
+```
+
+@@random_normal
+@@truncated_normal
+@@random_uniform
+@@random_shuffle
+@@set_random_seed
+
+"""
+"""Constant Operation.
+
+Has to be separate from array_ops to avoid a cyclic dependency.
+"""
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+
+
+def constant(value, dtype=None, shape=None, name="Const"):
+ """Creates a constant tensor.
+
+ The resulting tensor is populated with values of type `dtype`, as
+ specified by arguments `value` and (optionally) `shape` (see examples
+ below).
+
+ The argument `value` can be a constant value, or a list of values of type
+ `dtype`. If `value` is a list, then the length of the list must be less
+ than or equal to the number of elements implied by the `shape` argument (if
+ specified). In the case where the list length is less than the number of
+ elements specified by `shape`, the last element in the list will be used
+ to fill the remaining entries.
+
+ The argument `shape` is optional. If present, it specifies the dimensions
+ of the resulting tensor. If not present, then the tensor is a scalar (0-D)
+ if `value` is a scalar, or 1-D otherwise.
+
+ If the argument `dtype` is not specified, then the type is inferred from
+ the type of `value`.
+
+ For example:
+
+ ```python
+ # Constant 1-D Tensor populated with value list.
+ tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]
+
+ # Constant 2-D tensor populated with scalar value -1.
+ tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
+ [-1. -1. -1.]]
+ ```
+
+ Args:
+ value: A constant value (or list) of output type `dtype`.
+
+ dtype: The type of the elements of the resulting tensor.
+
+ shape: Optional dimensions of resulting tensor.
+
+ name: Optional name for the tensor.
+
+ Returns:
+ A Constant Tensor.
+ """
+ g = ops.get_default_graph()
+ tensor_value = attr_value_pb2.AttrValue()
+ tensor_value.tensor.CopyFrom(
+ tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
+ dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
+ const_tensor = g.create_op(
+ "Const", [], [dtype_value.type],
+ attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
+ return const_tensor
+
+
+@ops.RegisterShape("Const")
+def _ConstantShape(op):
+ return [tensor_shape.TensorShape(
+ [d.size for d in op.get_attr("value").tensor_shape.dim])]
+
+
+ops.register_tensor_conversion_function((list, tuple), constant, 100)
+ops.register_tensor_conversion_function(np.ndarray, constant, 100)
+ops.register_tensor_conversion_function(np.generic, constant, 100)
+ops.register_tensor_conversion_function(object, constant, 200)
+
+def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
+ if not s.is_fully_defined():
+ raise ValueError(
+ "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
+ if dtype is not None:
+ if dtype not in (types.int32, types.int64):
+ raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
+ else:
+ dtype = types.int32
+ if name is None:
+ name = "shape_as_tensor"
+ return constant(s.as_list(), dtype=dtype, name=name)
+
+ops.register_tensor_conversion_function(
+ tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
+
+def _dimension_tensor_conversion_function(d, dtype=None, name=None):
+ if d.value is None:
+ raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
+ if dtype is not None:
+ if dtype not in (types.int32, types.int64):
+ raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
+ else:
+ dtype = types.int32
+ if name is None:
+ name = "shape_as_tensor"
+ return constant(d.value, dtype=dtype, name=name)
+
+ops.register_tensor_conversion_function(
+ tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
new file mode 100644
index 0000000000..3a1a5b91c0
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -0,0 +1,100 @@
+"""Gradients for operators defined in control_flow_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.control_flow_ops import *
+from tensorflow.python.ops.gen_control_flow_ops import *
+
+
+@ops.RegisterGradient("Switch")
+def _SwitchGrad(op, *grad):
+ op = GetRealOp(op)
+ ctxt = op._get_control_flow_context() # pylint: disable=protected-access
+ if isinstance(ctxt, WhileContext):
+ merge_op = ctxt.switch_map.get(op)
+ if merge_op:
+ merge_op._update_input(1, grad[1])
+ return None, None
+ else:
+ merge_op = merge(grad, name="b_switch")[0]
+ ctxt.switch_map[op] = merge_op.op
+ return merge_op, None
+ elif isinstance(ctxt, CondContext):
+ good_grad = grad[ctxt.branch]
+ zero_grad = grad[1 - ctxt.branch]
+ zero_grad = switch(zero_grad, ctxt.pred, name="grad_0")[1 - ctxt.branch]
+ return merge([good_grad, zero_grad], name="switch_grad")[0], None
+ else:
+ false_grad = switch(grad[0], op.inputs[1])[0]
+ true_grad = switch(grad[1], op.inputs[1])[1]
+ return merge([false_grad, true_grad])[0], None
+
+
+@ops.RegisterGradient("RefSwitch")
+def _RefSwitchGrad(op, *grad):
+ return _SwitchGrad(op, *grad)
+
+
+@ops.RegisterGradient("Merge")
+def _MergeGrad(op, grad, _):
+ op = GetRealOp(op)
+ input_op = op.inputs[0].op
+ # pylint: disable=protected-access
+ ctxt = input_op._get_control_flow_context()
+ # pylint: enable=protected-access
+ if isinstance(ctxt, WhileContext):
+ grad_ctxt = ctxt.grad_context
+ return switch(grad, grad_ctxt.pivot)
+ elif isinstance(ctxt, CondContext):
+ return switch(grad, ctxt.pred, name="merge_grad")
+ else:
+ num_inputs = len(op.inputs)
+ cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
+    return [switch(grad, cond[i])[1] for i in xrange(num_inputs)]
+
+
+@ops.RegisterGradient("Exit")
+def _ExitGrad(op, grad):
+ # pylint: disable=protected-access
+ forward_ctxt = op._get_control_flow_context()
+ # pylint: enable=protected-access
+ if not forward_ctxt.back_prop:
+ return None
+ grad_ctxt = forward_ctxt.grad_context
+ grad_ctxt.AddName(grad.name)
+ return enter(grad, grad_ctxt.name, is_constant=False,
+ parallel_iterations=forward_ctxt.parallel_iterations,
+ name="b_exit")
+
+
+@ops.RegisterGradient("NextIteration")
+def _NextIterationGrad(_, grad):
+ return next_iteration(grad)
+
+
+@ops.RegisterGradient("Enter")
+def _EnterGrad(op, grad):
+ op = GetRealOp(op)
+ # pylint: disable=protected-access
+ forward_ctxt = op._get_control_flow_context()
+ # pylint: enable=protected-access
+ grad_ctxt = forward_ctxt.grad_context
+ if grad_ctxt:
+ if op.get_attr("is_constant"):
+ # Add a gradient accumulator for every loop invariant.
+ result = grad_ctxt.AddBackPropAccumulateLoop(grad)
+ else:
+ result = exit(grad)
+ return result
+ else:
+ return grad
+
+
+@ops.RegisterGradient("RefEnter")
+def _RefEnterGrad(op, grad):
+ return _EnterGrad(op, grad)
+
+
+@ops.RegisterGradient("LoopCond")
+def _LoopCondGrad(_):
+ return None
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
new file mode 100644
index 0000000000..068e3b5553
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -0,0 +1,1561 @@
+"""## Control Flow Operations
+
+TensorFlow provides several operations and classes that you can use to control
+the execution of operations and add conditional dependencies to your graph.
+
+@@identity
+@@tuple
+@@group
+@@no_op
+@@count_up_to
+
+## Logical Operators
+
+TensorFlow provides several operations that you can use to add logical operators
+to your graph.
+
+@@logical_and
+@@logical_not
+@@logical_or
+@@logical_xor
+
+## Comparison Operators
+
+TensorFlow provides several operations that you can use to add comparison
+operators to your graph.
+
+@@equal
+@@not_equal
+@@less
+@@less_equal
+@@greater
+@@greater_equal
+@@select
+@@where
+
+## Debugging Operations
+
+TensorFlow provides several operations that you can use to validate values and
+debug your graph.
+
+@@is_finite
+@@is_inf
+@@is_nan
+@@verify_tensor_all_finite
+@@check_numerics
+@@add_check_numerics_ops
+@@Assert
+@@Print
+"""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_control_flow_ops import *
+
+
+# We override the 'tuple' for a control flow op, so we keep python's
+# existing 'tuple' for later use in this module.
+_basetuple = tuple
+
+
+# pylint: disable=protected-access
+def _Identity(data, name=None):
+ """Return a tensor with the same shape and contents as the input tensor.
+
+ Args:
+ data: A Tensor.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor with the same type and value as the input Tensor.
+ """
+ if not data.dtype.is_ref_dtype:
+ return array_ops.identity(data, name=name)
+ else:
+ return gen_array_ops._ref_identity(data, name=name)
+
+
+def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
+ name=None):
+ """Creates or finds a child frame, and makes 'data' available to it.
+
+ The unique `frame_name` is used by the `Executor` to identify frames. If
+ `is_constant` is true, `output` is a constant in the child frame; otherwise
+ it may be changed in the child frame. At most `parallel_iterations` iterations
+ are run in parallel in the child frame.
+
+ Args:
+ data: The tensor to be made available to the child frame.
+ frame_name: The name of the child frame.
+ is_constant: If true, the output is constant within the child frame.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as 'data'.
+ """
+ if not data.dtype.is_ref_dtype:
+ return enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+ else:
+ return ref_enter(data, frame_name, is_constant, parallel_iterations,
+ name=name)
+
+
+def exit(data, name=None):
+ """Exits the current frame to its parent frame.
+
+ Exit makes its input `data` available to the parent frame.
+
+ Args:
+ data: The tensor to be made available to the parent frame.
+ name: A name for this operation (optional).
+
+ Returns:
+ The same tensor as `data`.
+ """
+ return gen_control_flow_ops._exit(data, name)
+
+
+def switch(data, pred, name=None):
+ """Forwards `data` to an output determined by `pred`.
+
+  If `pred` is true, the `data` input is forwarded to the first output.
+ Otherwise, the data goes to the second output.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+ `(output_true, output_false)`: If `pred` is true, data will be forwarded to
+ `output_true`, otherwise it goes to `output_false`.
+ """
+ with ops.op_scope([data, pred], name, "Switch") as name:
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ pred = ops.convert_to_tensor(pred, name="pred")
+ if isinstance(data, ops.Tensor):
+ return gen_control_flow_ops._switch(data, pred, name=name)
+ else:
+ val, ind, dense_shape = data.values, data.indices, data.dense_shape
+ val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
+ ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
+ if dense_shape:
+ dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
+ dense_shape, pred, name="dense_shape")
+ else:
+ dense_shape_f, dense_shape_t = None, None
+ return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
+ ops.IndexedSlices(val_t, ind_t, dense_shape_t))
+
+
+def merge(inputs, name=None):
+ """Returns the value of an available element of `inputs`.
+
+ This op tests each of the tensors in `inputs` in turn to determine if any of
+ them is available. If it finds an available tensor, it returns it and its
+ index in `inputs`.
+
+ It is an error if more than one tensor in `inputs` is available. If no tensor
+ in `inputs` is available, the returned tensor and index are not set.
+
+ This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
+ `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
+ before merging.
+
+ Args:
+ inputs: The input tensors, at most one of which is available.
+ name: A name for this operation (optional).
+
+ Returns:
+ A tuple containing the chosen input tensor and its index in `inputs`.
+
+ Raises:
+ ValueError: If inputs are IndexedSlices and some but not all have a
+ dense_shape property.
+ """
+ with ops.op_scope(inputs, name, "Merge") as name:
+ inputs = [ops.convert_to_tensor_or_indexed_slices(inp) for inp in inputs]
+ if all([isinstance(inp, ops.Tensor) for inp in inputs]):
+ return gen_control_flow_ops._merge(inputs, name=name)
+ else:
+ inputs = math_ops._as_indexed_slices_list(inputs)
+ values, _ = gen_control_flow_ops._merge([inp.values for inp in inputs],
+ name=name)
+ indices, chosen_index = gen_control_flow_ops._merge(
+ [inp.indices for inp in inputs], name="indices")
+ if any(inp.dense_shape for inp in inputs):
+ if not all(inp.dense_shape for inp in inputs):
+ raise ValueError("Either all merged IndexedSlices must have a "
+ "dense_shape, or none must have a dense_shape.")
+ dense_shape, _ = gen_control_flow_ops._merge(
+ [inp.dense_shape for inp in inputs], name="dense_shape")
+ else:
+ dense_shape = None
+ return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+
+
+def _SwitchRefOrTensor(data, pred, name="Switch"):
+ """Forwards `data` to an output determined by `pred`.
+
+  If `pred` is true, the `data` input is forwarded to the first output.
+ Otherwise, the data goes to the second output.
+
+ This op handles `Tensor`s and `IndexedSlices`.
+
+ Args:
+ data: The tensor to be forwarded to the appropriate output.
+ pred: A scalar that specifies which output port will receive data.
+ name: A name for this operation (optional).
+
+ Returns:
+    `(output_true, output_false)`: If `pred` is true, data will be forwarded to
+ `output_true`, otherwise it goes to `output_false`.
+
+ Raises:
+ TypeError: if data is not a Tensor or IndexedSlices
+ """
+ data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
+ if isinstance(data, ops.Tensor):
+ if not data.dtype.is_ref_dtype:
+ return switch(data, pred, name=name)
+ else:
+ return ref_switch(data, pred, name=name)
+ else:
+ return switch(data, pred, name=name)
+
+
+class ControlFlowOpInputs(object):
+ """An indirection to capture the input tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+
+ def __len__(self):
+ return len(self._op._inputs)
+
+ def __getitem__(self, index):
+ if self._inputs is None:
+ self._inputs = [None for _ in self._op.inputs]
+ if isinstance(index, int):
+ val = self._inputs[index]
+ if val is None:
+ f_val = self._op.inputs[index]
+ val = _GetRealValue(f_val)
+ self._inputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpOutputs(object):
+ """An indirection to capture the output tensors needed in backprop."""
+
+ def __init__(self, op):
+ self._op = op
+ self._outputs = None
+
+ def __len__(self):
+ return len(self._op._outputs)
+
+ def __getitem__(self, index):
+ if self._outputs is None:
+ self._outputs = [None for _ in self._op.outputs]
+ if isinstance(index, int):
+ val = self._outputs[index]
+ if val is None:
+ f_val = self._op.outputs[index]
+ val = _GetRealValue(f_val)
+ self._outputs[index] = val
+ return val
+ elif isinstance(index, slice):
+ start, stop, step = index.indices(len(self))
+ vals = [self[i] for i in xrange(start, stop, step)]
+ return vals
+ else:
+ raise TypeError("index must be an integer or slice")
+
+
+class ControlFlowOpWrapper(object):
+ """A wrapper class for Operation."""
+
+ def __init__(self, op):
+ self._op = op
+ self._inputs = None
+ self._outputs = None
+
+ @property
+ def inputs(self):
+ if self._inputs is None:
+ self._inputs = ControlFlowOpInputs(self._op)
+ return self._inputs
+
+ @property
+ def outputs(self):
+ if self._outputs is None:
+ self._outputs = ControlFlowOpOutputs(self._op)
+ return self._outputs
+
+ @property
+ def op(self):
+ return self._op
+
+ @property
+ def name(self):
+ """Returns the name of this instance of op."""
+ return self._op.name
+
+ @property
+ def _id(self):
+ """Returns the unique id of this operation."""
+ return self._op._id
+
+ @property
+ def device(self):
+ """Returns the device of this operation.
+
+ Returns:
+ a string or None if the device was not set.
+ """
+ return self._op.device
+
+ @property
+ def output_types(self):
+ return self._op.output_types
+
+ @property
+ def input_types(self):
+ return self._op._input_types
+
+ @property
+ def type(self):
+ """Returns the type of the op."""
+ return self._op.type
+
+ @property
+ def graph(self):
+ """Returns the parent graph."""
+ return self._op.graph
+
+ def GetAttr(self, attr_name):
+ """Returns the value of attribute 'attr_name' of NodeDef."""
+ return self._op.get_attr(attr_name)
+
+ def _get_control_flow_context(self):
+ return self._op._get_control_flow_context()
+
+
+def GetRealOp(op):
+ while isinstance(op, ControlFlowOpWrapper):
+ op = op.op
+ return op
+
+
+def MakeWrapper(op):
+ """Make a wrapper for op if it is in a WhileContext."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ return ControlFlowOpWrapper(op)
+ return op
+
+
+def EnterGradWhileContext(op):
+ """Enter the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ grad_ctxt = forward_ctxt.CreateGradWhileContext()
+ grad_ctxt.Enter()
+
+
+def ExitGradWhileContext(op):
+ """Exit the WhileContext for gradient computation."""
+ forward_ctxt = op._get_control_flow_context()
+ if forward_ctxt and isinstance(forward_ctxt, WhileContext):
+ assert forward_ctxt.grad_context
+ forward_ctxt.grad_context.Exit()
+
+
+def _GetRealValue(value):
+ """Get the real value.
+
+ If backprop "uses" a value produced by forward inference, an
+ accumulator is added in the forward loop to accumulate its values,
+ so we use the accumulated value, indexed by the backprop counter.
+
+ Args:
+ value: A tensor to be captured.
+
+ Returns:
+ The same tensor value from the saved history.
+ """
+ real_value = value
+ forward_ctxt = value.op._get_control_flow_context()
+ real_value = forward_ctxt.history_map.get(value.name)
+ assert value.op.type != "Variable"
+ if real_value is None:
+ if value.op.type == "Enter" and value.op.get_attr("is_constant"):
+ # Use the input of this Enter node
+ real_value = GetRealOp(value.op).inputs[0]
+ else:
+ # Accumulate the history of this value.
+ # NOTE(yuanbyu): Don't accumulate for constants. One approach is
+ # to deepcopy the constants for the grad while context.
+ history_value = forward_ctxt.AddForwardAccumulateLoop(value)
+
+ # The shapes of the whole history and a single event element.
+ forward_ctxt.grad_context.Exit()
+ elem_rank = array_ops.rank(history_value) - 1
+ elem_rank_vec = array_ops.expand_dims(elem_rank, 0)
+ elem_shape = array_ops.slice(array_ops.shape(history_value), [1],
+ elem_rank_vec)
+ slice_shape = array_ops.concat(0, [[1], elem_shape])
+ forward_ctxt.grad_context.Enter()
+
+ # The begin position of the slice at slice_index.
+ slice_index = forward_ctxt.grad_context.index
+ b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32)
+ b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1])
+
+ # The slice at slice_index.
+ # TODO(irving): Replace with gather once that's GPU accelerated
+ real_value = array_ops.squeeze(
+ array_ops.slice(history_value,
+ b,
+ slice_shape,
+ name="real"),
+ squeeze_dims=[0])
+ forward_ctxt.history_map[value.name] = real_value
+ return real_value
+
+
+def IsLoopSwitch(op):
+ """Returns true if `op` is the Switch for a While loop."""
+ if op.type == "Switch":
+ ctxt = op._get_control_flow_context()
+ return ctxt and isinstance(ctxt, WhileContext)
+ return False
+
+
+class ControlFlowContext(object):
+ """The base class for control flow context.
+
+ The usage pattern is a sequence of (Enter, Exit) followed by a final
+ ExitResult.
+ """
+
+ def AddName(self, name):
+ self._values.add(name)
+
+ # pylint: disable=protected-access
+ def Enter(self):
+ """Enter the current context."""
+ self._outer_context = ops.get_default_graph()._get_control_flow_context()
+ ops.get_default_graph()._set_control_flow_context(self)
+
+ def Exit(self):
+ """Exit the current context."""
+ ops.get_default_graph()._set_control_flow_context(self._outer_context)
+ # pylint: enable=protected-access
+
+ def ExitResult(self, result):
+ """Make a list of tensors available in the outer context."""
+ if self._outer_context is not None:
+ for x in result:
+ self._outer_context.AddName(x.name)
+
+ def GetWhileContext(self):
+ """Get the current while context."""
+ if self._outer_context is not None:
+ return self._outer_context.GetWhileContext()
+ return None
+
+ def AddToWhileContext(self, op):
+ """Add a control dependency to the containing WhileContext.
+
+ The added control dependency ensures that the outputs of this op
+ belong to the WhileContext.
+
+ Args:
+ op: An operation.
+ """
+ while_ctxt = self.GetWhileContext()
+ if while_ctxt is not None:
+ # pylint: disable=protected-access
+ op._add_control_input(while_ctxt.GetControlPivot().op)
+ # pylint: enable=protected-access
+
+
+class CondContext(ControlFlowContext):
+ """The context for the conditional construct."""
+
+ def __init__(self, pred, pivot, branch):
+ self._pred = pred
+ self._outer_context = None
+ self._pivot = pivot
+ self._branch = branch
+ self._values = set()
+ self._values.add(pred.name)
+ self._values.add(pivot.name)
+ self._external_values = {}
+
+ @property
+ def pred(self):
+ return self._pred
+
+ @property
+ def pivot(self):
+ return self._pivot
+
+ @property
+ def branch(self):
+ return self._branch
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ result = with_dependencies([self._pivot], result)
+ self._external_values[val.name] = result
+ return result
+
+ def AddOp(self, op):
+ """Add 'op' to the current context."""
+ if not op.inputs:
+ # Add this op to the enclosing while context
+ self.AddToWhileContext(op)
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot.op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ if x.name not in self._values:
+ self._values.add(x.name)
+ # Add this value to the parent contexts up to the context that
+ # creates this value.
+ real_x = x
+ if self._outer_context is not None:
+ real_x = self._outer_context.AddValue(x)
+ real_x = _SwitchRefOrTensor(real_x, self._pred)[self._branch]
+ self._external_values[x.name] = real_x
+ x = self._external_values.get(x.name)
+ if x is not None:
+ op._update_input(index, x)
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def BuildCondBranch(self, fn):
+ """Add the subgraph defined by fn() to the graph."""
+ r = fn()
+ result = []
+ if r is not None:
+ if not isinstance(r, list) and not isinstance(r, _basetuple):
+ r = [r]
+ for v in r:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], self._pivot)
+ elif v.name not in self._values:
+ self._values.add(v.name)
+ if self._outer_context is not None:
+ v = self._outer_context.AddValue(v)
+ v = _SwitchRefOrTensor(v, self._pred)[self._branch]
+ else:
+ external_v = self._external_values.get(v.name)
+ if external_v is not None:
+ v = external_v
+ result.append(v)
+ return result
+
+
+def cond(pred, fn1, fn2, name=None):
+ """Return either 'fn1()' or 'fn2()' based on the boolean predicate 'pred'.
+
+ `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
+ the same number and type of outputs.
+
+ Args:
+ pred: A scalar determining whether to return the result of `fn1` or `fn2`.
+ fn1: The function to be performed if pred is true.
+    fn2: The function to be performed if pred is false.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ Tensors returned by the call to either `fn1` or `fn2`. If the functions
+ return a singleton list, the element is extracted from the list.
+
+ Raises:
+ TypeError: if `fn1` or `fn2` is not callable.
+ ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
+ return tensors of different types.
+
+ Example:
+ ```python
+ x = constant(2)
+ y = constant(5)
+ def f1(): return constant(17)
+ def f2(): return constant(23)
+ r = cond(math_ops.less(x, y), f1, f2)
+ # r is set to f1()
+ ```
+ """
+ with ops.op_scope([pred], name, "Cond") as name:
+ if not callable(fn1):
+ raise TypeError("fn1 must be callable.")
+ if not callable(fn2):
+ raise TypeError("fn2 must be callable.")
+
+ # Add the Switch to the graph.
+ p_2, p_1 = switch(pred, pred)
+ pivot_1 = array_ops.identity(p_1, name="switch_t")
+ pivot_2 = array_ops.identity(p_2, name="switch_f")
+ pred = array_ops.identity(pred, name="pred_id")
+
+ # Build the graph for the true branch in a new context.
+ context_t = CondContext(pred, pivot_1, 1)
+ context_t.Enter()
+ res_t = context_t.BuildCondBranch(fn1)
+ context_t.ExitResult(res_t)
+ context_t.Exit()
+
+ # Build the graph for the false branch in a new context.
+ context_f = CondContext(pred, pivot_2, 0)
+ context_f.Enter()
+ res_f = context_f.BuildCondBranch(fn2)
+    context_f.ExitResult(res_f)
+ context_f.Exit()
+
+ # Add the final merge to the graph.
+ if len(res_t) != len(res_f):
+ raise ValueError("fn1 and fn2 must return the same number of tensors.")
+ for x, y in zip(res_f, res_t):
+ assert ((isinstance(x, ops.IndexedSlices) and
+ isinstance(y, ops.IndexedSlices)) or
+ (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
+ val_x = x if isinstance(x, ops.Tensor) else x.values
+ val_y = y if isinstance(y, ops.Tensor) else y.values
+ if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
+ raise ValueError("Outputs of fn1 and fn2 must have the same type: "
+ "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
+ merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
+ return merges[0] if len(merges) == 1 else merges
+
+
+# TODO(yuanbyu): We should probably separate the notion of context so it
+# could be used not only for conditionals and loops but also subgraphs.
+class WhileContext(ControlFlowContext):
+ """The context for the loop construct."""
+
+ def __init__(self, parallel_iterations, back_prop, name):
+ self._name = ops.get_default_graph().unique_name(name)
+ self._parallel_iterations = parallel_iterations
+ self._back_prop = back_prop
+ self._outer_context = None
+ # We use this node to control constants created by the pred lambda.
+ self._pivot_for_pred = None
+ # We use this node to control constants created by the body lambda.
+ self._pivot_for_body = None
+    # The boolean tensor for the loop termination condition. Used in code
+    # generation for gradient computation.
+ self._pivot = None
+
+ # The tensors for the counters added by AddForwardCounterLoop or
+ # AddBackPropCounterLoop
+ self._index = None
+
+ # Information needed by backprop
+ self._grad_context = None
+ self._total_iterations = None
+ self._history_map = {}
+ self._switch_map = {}
+
+ # values considered to have been already seen in this context
+ self._values = set()
+
+ # values referenced by but external to this context
+ self._external_values = {}
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def parallel_iterations(self):
+ """The number of iterations allowed to run in parallel."""
+ return self._parallel_iterations
+
+ @property
+ def back_prop(self):
+ """True iff backprop is enabled for this While loop."""
+ return self._back_prop
+
+ @property
+ def pivot(self):
+ """The boolean tensor representing the loop termination condition."""
+ return self._pivot
+
+ @property
+ def index(self):
+ """The loop index representing the current iteration."""
+ return self._index
+
+ @property
+ def grad_context(self):
+ """The corresponding WhileContext for gradient."""
+ return self._grad_context
+
+ @property
+ def history_map(self):
+ """The map that records all the tensors needed for backprop."""
+ return self._history_map
+
+ @property
+ def switch_map(self):
+ """The map that records all the Switch ops in the While loop."""
+ return self._switch_map
+
+ @property
+ def total_iterations(self):
+ """The total number of iterations of the while loop."""
+ return self._total_iterations
+
+ def GetWhileContext(self):
+ return self
+
+ def GetControlPivot(self):
+ if self._pivot_for_body:
+ return self._pivot_for_body
+ return self._pivot_for_pred
+
+ def AddValue(self, val):
+ """Add 'val' to the current context and its outer context recursively."""
+ result = val
+ if val.name not in self._values:
+ self._values.add(val.name)
+ if self._outer_context is not None:
+ result = self._outer_context.AddValue(val)
+ # Create an Enter that makes 'result' known to this context.
+ enter = _Enter(result, self._name, is_constant=True,
+ parallel_iterations=self._parallel_iterations)
+ self._values.add(enter.name)
+ self._external_values[val.name] = enter
+ result = enter
+ else:
+ actual_val = self._external_values.get(val.name)
+ if actual_val is not None:
+ result = actual_val
+ return result
+
+ def AddOp(self, op):
+ """Adds 'op' to the current context."""
+ if not op.inputs:
+ if not op.control_inputs:
+ # Add a control edge from the control pivot to this op.
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ else:
+ # Control edges must be in the same context.
+ for x in op.control_inputs:
+ assert x._get_control_flow_context() == self, (
+ "Control inputs must come from Operations in the same while "
+ "loop context (not an outer context).")
+ for x in op.outputs:
+ self._values.add(x.name)
+ else:
+ for index in range(len(op.inputs)):
+ x = op.inputs[index]
+ self.AddValue(x)
+ real_x = self._external_values.get(x.name)
+ if real_x is not None:
+ op._update_input(index, real_x)
+ # Add a control dependency to prevent loop invariants from
+ # enabling ops that should not be executed.
+ if real_x.op.type == "RefEnter" and real_x.op.get_attr("is_constant"):
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot().op)
+ # pylint: enable=protected-access
+ for x in op.outputs:
+ self._values.add(x.name)
+
+ def CreateGradWhileContext(self):
+ """Creates the WhileContext for backprop gradient computation."""
+ if self._grad_context is None:
+ cnt = self.AddForwardCounterLoop()
+ self._grad_context = WhileContext(self._parallel_iterations,
+ self._back_prop, self._name)
+ self._grad_context.AddBackPropCounterLoop(cnt)
+ return self._grad_context
+
+ def AddForwardCounterLoop(self):
+ """Adds a loop that counts the number of iterations.
+
+ This is added to the forward loop at the time when we start to
+ create the loop for backprop gradient computation.
+
+ The pseudocode is:
+ `n = 0; while (_pivot) { n++; }`
+
+ Returns:
+ The number of iterations taken by the forward loop.
+ """
+ n = constant_op.constant(0, name="f_count")
+ self.Enter()
+ self.AddName(n.name)
+ enter_n = _Enter(n, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_count")
+ merge_n = merge([enter_n, enter_n])[0]
+ switch_n = switch(merge_n, self._pivot)
+ self._index = switch_n[1]
+
+ add_n = math_ops.add(self._index, 1)
+ next_n = next_iteration(add_n)
+ merge_n.op._update_input(1, next_n)
+
+ self._total_iterations = exit(switch_n[0], name="f_count")
+ self.Exit()
+ return self._total_iterations
+
+ def AddForwardAccumulateLoop(self, value):
+ """Add an accumulation loop for each value needed in backprop.
+
+    This is added to the forward loop the first time a value in the
+    forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ if (index == 0) [value] else Concat(acc, [value]);
+ }
+ ```
+
+ Args:
+ value: The tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+
+ Raises:
+ ValueError: If the shape of "value" is not known statically.
+ """
+ if not value.get_shape().is_fully_defined():
+ raise ValueError("Must have known shape: %s" % value)
+ self._grad_context.Exit()
+ # TODO(irving): Now that acc starts out empty, most of the
+ # conditional logic can go away.
+ acc = constant_op.constant([],
+ value.dtype,
+ shape=[0] + value.get_shape().as_list(),
+ name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ # If index = 0 then [value] else Concat(acc, [value]).
+ cond = math_ops.greater(self._index, 0)
+ switch_add_acc = switch(switch_acc[1], cond)
+ expand_value = array_ops.expand_dims(value, 0)
+ true_branch = array_ops.concat(0, [switch_add_acc[1], expand_value])
+ false_branch = array_ops.identity(switch_add_acc[0])
+ false_branch = with_dependencies([false_branch], expand_value)
+ add_acc = merge([false_branch, true_branch])[0]
+
+ next_acc = next_iteration(add_acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddForwardAccumulateCondLoop(self, value):
+ """Add an accumulation loop for each conditional switch.
+
+    This is added to the forward loop the first time a conditional switch
+    in the forward loop is used by the backprop gradient computation loop.
+
+ The pseudocode is:
+ ```
+ acc;
+ while (_pivot) {
+ Concat(acc, value);
+ }
+ ```
+
+ Args:
+ value: The boolean tensor that is accumulated.
+
+ Returns:
+ The accumulated history of value.
+ """
+ self._grad_context.Exit()
+ acc = constant_op.constant(False, name="f_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_acc")
+ merge_acc = merge([enter_acc, enter_acc])[0]
+ switch_acc = switch(merge_acc, self._pivot)
+    acc = array_ops.concat(0, [switch_acc[1], value])
+ next_acc = next_iteration(acc)
+ merge_acc.op._update_input(1, next_acc)
+
+ exit_acc = exit(switch_acc[0], name="f_acc")
+ self.Exit()
+ self._grad_context.Enter()
+ return exit_acc
+
+ def AddBackPropCounterLoop(self, count):
+ """Add the backprop loop that controls the iterations.
+
+ This is added to the backprop loop. It is used to control the loop
+ termination and the slice index.
+
+ The pseudocode is:
+ `n = count; while (n >= 1) { n--; }`
+
+ Args:
+ count: The number of iterations for backprop.
+
+ Returns:
+ always 0.
+ """
+ one = constant_op.constant(1, name="b_count")
+ self.Enter()
+ self.AddName(count.name)
+ enter_count = _Enter(count, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_count")
+ merge_count = merge([enter_count, enter_count])[0]
+ self._pivot_for_pred = merge_count
+
+ cond = math_ops.greater_equal(merge_count, one)
+ self._pivot = loop_cond(cond, name="b_count")
+ switch_count = switch(merge_count, self._pivot)
+
+ # Add next_iteration right after Switch to match the gradient function.
+ next_count = next_iteration(switch_count[1])
+ self._pivot_for_body = next_count
+ self._index = math_ops.sub(next_count, one)
+ merge_count.op._update_input(1, self._index)
+
+ exit_count = exit(switch_count[0], name="b_count")
+ self.Exit()
+ return exit_count
+
+ def AddBackPropAccumulateLoop(self, value):
+ """Add an accumulation loop for every loop invariant.
+
+ This is added to the backprop loop. It is used to accumulate partial
+ gradients for each loop iteration. Called when in the while context
+ for gradient.
+
+ The pseudocode is:
+ ```
+ acc = 0;
+ while (_pivot) {
+ acc += value;
+ }
+ ```
+
+ Args:
+ value: The partial gradient of an iteration for a loop invariant.
+
+ Returns:
+ The gradient for a loop invariant.
+ """
+ self.Exit()
+ acc = constant_op.constant(0, value.dtype, name="b_acc")
+ self.Enter()
+ self.AddName(acc.name)
+ enter_acc = _Enter(acc, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_acc")
+ merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
+ switch_acc = switch(merge_acc, self._pivot)
+
+ next_acc = next_iteration(switch_acc[1])
+ add_acc = math_ops.add(next_acc, value)
+ merge_acc.op._update_input(1, add_acc)
+
+ exit_acc = exit(switch_acc[0], name="b_acc")
+ return exit_acc
+
+ def BuildLoop(self, pred, body, loop_vars):
+ """Add the loop termination condition and body to the graph."""
+
+ loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
+    # Let the context know the loop variables so the _Enter nodes below
+    # will be added to the context correctly.
+ self._values = set([x.name for x in loop_vars])
+ if self._outer_context is not None:
+ real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
+ else:
+ real_vars = loop_vars
+ enter_vars = [_Enter(x, self._name, is_constant=False,
+ parallel_iterations=self._parallel_iterations)
+ for x in real_vars]
+ self._values = set([x.name for x in enter_vars])
+
+ merge_vars = [merge([x, x])[0] for x in enter_vars]
+ self._pivot_for_pred = merge_vars[0]
+
+ # Build the graph for pred.
+ c = ops.convert_to_tensor(pred(*merge_vars))
+ self._pivot = loop_cond(c, name="LoopCond")
+ switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
+
+ # Build the graph for body.
+ vars_for_body = [_Identity(x[1]) for x in switch_vars]
+ self._pivot_for_body = vars_for_body[0]
+
+ body_result = body(*vars_for_body)
+ if not isinstance(body_result, (list, _basetuple)):
+ body_result = [body_result]
+ result = ops.convert_n_to_tensor_or_indexed_slices(body_result)
+ next_vars = [next_iteration(x) for x in result]
+
+ # Add the back edges to complete the loop.
+ assert len(merge_vars) == len(next_vars)
+ for x in zip(merge_vars, next_vars):
+ x[0].op._update_input(1, x[1])
+
+ # Add the exit ops.
+ exit_vars = [exit(x[0]) for x in switch_vars]
+
+ for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
+ if m_var.get_shape().is_compatible_with(n_var.get_shape()):
+ e_var.set_shape(m_var.get_shape().merge_with(n_var.get_shape()))
+
+ # Exit the loop.
+ self.ExitResult(exit_vars)
+ self.Exit()
+ return exit_vars[0] if len(exit_vars) == 1 else exit_vars
+
+
+def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
+ name=None):
+ """Repeat `body` while the condition `cond` is true.
+
+ `cond` is a function taking a list of tensors and returning a boolean scalar
+ tensor. `body` is a function taking a list of tensors and returning a list of
+ tensors of the same length and with the same types as the input. `loop_vars`
+ is a list of tensors that is passed to both `cond` and `body`.
+
+ While `cond` evaluates to true, `body` is executed.
+
+ Args:
+ cond: The termination condition of the loop.
+ body: A function that represents the loop body.
+ loop_vars: The list of variable input tensors.
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ back_prop: Whether backprop is enabled for this while loop.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ The output tensors for the loop variables after the loop.
+
+ Raises:
+ TypeError: if `cond` or `body` is not callable.
+    ValueError: if `loop_vars` is empty.
+
+ Example:
+ ```python
+ i = Constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: math_ops.add(i, 1)
+ r = While(c, b, [i])
+ ```
+ """
+ with ops.op_scope(loop_vars, name, "While") as name:
+ if not loop_vars:
+ raise ValueError("No loop variables provided")
+ if not callable(cond):
+ raise TypeError("cond must be callable.")
+ if not callable(body):
+ raise TypeError("body must be callable.")
+
+ context = WhileContext(parallel_iterations, back_prop, name)
+ context.Enter()
+ return context.BuildLoop(cond, body, loop_vars)
+
+
+def _AsTensorList(x, p):
+ """Return x as a list of Tensors or IndexedSlices.
+
+ For entries of `x` that are Operations, this returns an Identity of `p`
+ with a dependency on the operation.
+
+ Args:
+ x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
+ p: A Tensor to return for entries in `x` that are Operations.
+
+ Returns:
+ A list of Tensors or IndexedSlices.
+ """
+ if not isinstance(x, list) and not isinstance(x, _basetuple):
+ x = [x]
+
+ l = []
+ for v in x:
+ if isinstance(v, ops.Operation):
+ v = with_dependencies([v], p)
+ v = ops.convert_to_tensor_or_indexed_slices(v)
+ if isinstance(v, ops.Tensor):
+ l.append(array_ops.identity(v))
+ else:
+ l.append(ops.IndexedSlices(array_ops.identity(v.values),
+ array_ops.identity(v.indices)))
+ return l
+
+
+def _CheckResults(a, b):
+ assert len(a) == len(b), (
+ "Values returned by a() and b() must have the same length.")
+ for x, y in zip(a, b):
+ assert x.dtype == y.dtype, (
+ "Values returned by a() [%s] and b() [%s] must have "
+ "the same type: %s, %s." %
+ (x.name, y.name, x.dtype.name, y.dtype.name))
+
+
+def with_dependencies(dependencies, output_tensor, name=None):
+ """Produces the content of `output_tensor` only after `dependencies`.
+
+ In some cases, a user may want the output of an operation to be
+ consumed externally only after some other dependencies have run
+  first. This function returns `output_tensor`, but only after all
+  operations in `dependencies` have run. Note, however, that this does not
+  force the op that produces `output_tensor` to run after the
+  `dependencies`; only the returned, gated value is ordered after them.
+
+ See also `tuple` and `group`.
+
+ Args:
+ dependencies: A list of operations to run before this op finishes.
+ output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
+ name: (Optional) A name for this operation.
+
+ Returns:
+ Same as `output_tensor`.
+
+ Raises:
+ TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
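+
+  Example (a minimal illustrative sketch; `update_weights` and `total` are
+  assumed to be an operation and a tensor defined elsewhere):
+  ```python
+    out = with_dependencies([update_weights], total)
+    # `out` yields the value of `total`, but is only produced after
+    # `update_weights` has run.
+  ```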
+ """
+ with ops.op_scope(dependencies + [output_tensor], name,
+ "control_dependency") as name:
+ with ops.device(output_tensor.device
+ or ops.get_default_graph().get_default_device()):
+ with ops.control_dependencies(dependencies):
+ output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
+ if isinstance(output_tensor, ops.Tensor):
+ return _Identity(output_tensor, name=name)
+ else:
+ return ops.IndexedSlices(_Identity(output_tensor.values, name=name),
+ output_tensor.indices,
+ output_tensor.dense_shape)
+
+
+def _GroupControlDeps(dev, deps, name=None):
+ with ops.control_dependencies(deps):
+ if dev is None:
+ return no_op(name=name)
+ else:
+ with ops.device(dev):
+ return no_op(name=name)
+
+
+# TODO(mdevin): Accept "inputs" as a list.
+def group(*inputs, **kwargs):
+ """Create an op that groups multiple operations.
+
+  When this op finishes, all ops in `inputs` have finished. This op has no
+ output.
+
+ See also `tuple` and `with_dependencies`.
+
+ Args:
+ *inputs: One or more tensors to group.
+ **kwargs: Optional parameters to pass when constructing the NodeDef.
+ name: A name for this operation (optional).
+
+ Returns:
+ An Operation that executes all its inputs.
+
+ Raises:
+ ValueError: If an unknown keyword argument is provided, or if there are
+ no inputs.
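+
+  Example (a minimal illustrative sketch; `update_a` and `update_b` are
+  assumed to be operations defined elsewhere):
+  ```python
+    update_all = group(update_a, update_b)
+    # Running `update_all` runs both updates; it produces no output.
+  ```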
+ """
+ name = kwargs.pop("name", None)
+ if kwargs:
+ raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
+ if not inputs:
+ # TODO(mdevin): Would make sense to return a NoOp.
+ raise ValueError("No inputs provided")
+ with ops.op_scope(inputs, name, "group_deps") as name:
+ # Sorts *inputs according to their devices.
+ ops_on_device = {} # device -> operations specified on the device.
+ for inp in inputs:
+ dev = inp.device
+ if dev in ops_on_device:
+ ops_on_device[dev].append(inp)
+ else:
+ ops_on_device[dev] = [inp]
+ if len(ops_on_device) == 1:
+ # 1-level tree. The root node is the returned NoOp node.
+ dev, deps = ops_on_device.items()[0]
+ return _GroupControlDeps(dev, deps, name=name)
+ # 2-level tree. The root node is the returned NoOp node.
+ # deps contains 1 NoOp node for each device.
+ deps = []
+ for dev in sorted(ops_on_device.iterkeys()):
+ deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
+ return _GroupControlDeps(None, deps, name=name)
+
+
+def tuple(tensors, name=None, control_inputs=None):
+ """Group tensors together.
+
+ This creates a tuple of tensors with the same values as the `tensors`
+ argument, except that the value of each tensor is only returned after the
+ values of all tensors have been computed.
+
+ `control_inputs` contains additional ops that have to finish before this op
+ finishes, but whose outputs are not returned.
+
+ This can be used as a "join" mechanism for parallel computations: all the
+ argument tensors can be computed in parallel, but the values of any tensor
+ returned by `tuple` are only available after all the parallel computations
+ are done.
+
+ See also `group` and `with_dependencies`.
+
+ Args:
+    tensors: A list of `Tensor`s or `IndexedSlices`; some entries can be `None`.
+ name: (optional) A name to use as a `name_scope` for the operation.
+ control_inputs: List of additional ops to finish before returning.
+
+ Returns:
+ Same as `tensors`.
+
+ Raises:
+ ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
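+
+  Example (a minimal illustrative sketch; `t0` and `t1` are assumed to be
+  tensors defined elsewhere):
+  ```python
+    v0, v1 = tuple([t0, t1])
+    # `v0` and `v1` have the same values as `t0` and `t1`, but neither is
+    # available until both have been computed.
+  ```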
+
+ """
+ with ops.op_scope(tensors, name, "tuple") as name:
+ gating_ops = [t.op for t in tensors if t]
+ if control_inputs:
+ gating_ops += control_inputs
+    # Note that in order to get a deterministic order in the pbtxt, the ops
+    # must be sorted here.
+ gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops.
+ if not gating_ops:
+ raise ValueError("Must have at least one Tensor: %s" % tensors)
+ gate = group(*gating_ops)
+ tpl = []
+ for t in tensors:
+ if t:
+ tpl.append(with_dependencies([gate], t))
+ else:
+ tpl.append(None)
+ return tpl
+
+
+# TODO(yuanbyu): It would be nicer if we could have the distributed list
+# support that Derek has been proposing.
+# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
+def fold(fn, elems, elem_shape, name=None):
+ """The fold operator on slices of a tensor.
+
+ This fold operator applies the function `fn` to slices of `elems` on
+ dimension 0. The shape of the slices is specified by `elem_shape`. `elems`
+ must contain at least one slice (`shape(elems)[0] / elem_shape[0] > 0`).
+
+ Args:
+ fn: The function to be performed on each slice of the tensor.
+ elems: The tensor to whose slices we want to apply `fn`.
+ elem_shape: The shape definition for the slices.
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ A tensor resulting from applying `fn` consecutively on each slice of
+ `elems`.
+
+ Raises:
+ TypeError: if `fn` is not callable.
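+
+  Example (a minimal illustrative sketch that sums row slices):
+  ```python
+    elems = constant_op.constant([[1., 2.], [3., 4.], [5., 6.]])
+    r = fold(math_ops.add, elems, [1, 2])
+    # r is [[9., 12.]], the elementwise sum of the three row slices.
+  ```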
+ """
+ with ops.op_scope([elems], name, "Fold") as name:
+ if not callable(fn):
+ raise TypeError("fn must be callable.")
+
+ s0 = array_ops.shape(elems)[0]
+ d0 = elem_shape[0]
+ n = math_ops.div(s0, d0)
+ b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0),
+ dtype=types.int32)
+ # Initialize the output with slice 0
+ b = array_ops.concat(0, [[0], b1])
+ o = array_ops.slice(elems, b, elem_shape)
+ i = ops.convert_to_tensor(d0)
+
+ def Compute(i, o):
+ b = array_ops.concat(0, [array_ops.expand_dims(i, 0), b1])
+ x = array_ops.slice(elems, b, elem_shape)
+ o = fn(o, x)
+ i = math_ops.add(i, d0)
+ return [i, o]
+ r = While(lambda i, o: math_ops.less(i, n), Compute, [i, o])
+ return r[1]
+
+
+def case(pred_fn_pairs, default, exclusive=False, name="Case"):
+ """Create a Case operation.
+
+ The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
+ Each pair contains a boolean scalar tensor and a python callable that
+ creates the tensors to be returned if the boolean evaluates to True. `default`
+ is a callable generating a list of tensors. All the callables in
+ `pred_fn_pairs` as well as `default` should return the same number and types
+ of tensors.
+
+ If `exclusive==True`, all predicates are evaluated, and a logging operation
+ with an error is returned if more than one of the predicates evaluates to
+  True. If `exclusive==False`, execution stops at the first predicate which
+ evaluates to True, and the tensors generated by the corresponding function
+ are returned immediately. If none of the predicates evaluate to True, this
+ operation returns the tensors generated by `default`.
+
+ Example 1:
+ Pseudocode:
+ ```
+ if (x < y) return 17;
+ else return 23;
+ ```
+
+ Expressions:
+ ```
+ f1 = lambda: Constant(17)
+ f2 = lambda: Constant(23)
+ r = Case([(math_ops.less(x, y), f1)], default=f2)
+ ```
+
+ Example 2:
+ Pseudocode:
+ ```
+ if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
+ if (x < y) return 17;
+ else if (x > z) return 23;
+ else return -1;
+ ```
+
+ Expressions:
+ ```
+ def f1(): return Constant(17)
+ def f2(): return Constant(23)
+ def f3(): return Constant(-1)
+ r = Case({math_ops.less(x, y): f1, math_ops.greater(x, z): f2},
+ default=f3, exclusive=True)
+ ```
+
+ Args:
+ pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
+ callable which returns a list of tensors.
+ default: A callable that returns a list of tensors.
+    exclusive: True iff at most one predicate is allowed to evaluate to True.
+ name: A name for this operation (optional).
+
+ Returns:
+ The tensors returned by the first pair whose predicate evaluated to True, or
+ those returned by `default` if none does.
+
+ Raises:
+ TypeError: If `pred_fn_pairs` is not a list/dictionary.
+ TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
+ TypeError: If `fns[i]` is not callable for any i, or `default` is not
+ callable.
+ """
+ pfp = pred_fn_pairs # For readability
+ if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)
+ or isinstance(pfp, dict)):
+    raise TypeError("pred_fn_pairs must be a list, tuple, or dict")
+ if isinstance(pfp, dict):
+ pfp = pfp.items()
+ if not exclusive:
+ logging.warn("%s: Provided dictionary of predicate/fn pairs, but "
+ "exclusive=False. Order of conditional tests is "
+ "not guaranteed." % name)
+ for tup in pfp:
+ if not isinstance(tup, _basetuple) or len(tup) != 2:
+ raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
+ pred, fn = tup
+ if pred.dtype != types.bool:
+      raise TypeError("pred must be of type bool: %s" % pred.name)
+ if not callable(fn):
+ raise TypeError("fn for pred %s must be callable." % pred.name)
+ if not callable(default):
+ raise TypeError("default must be callable.")
+
+ preds, fns = map(list, zip(*pfp))
+ with ops.op_scope([[f() for f in fns] + preds + [default()]], name, "Case"):
+ if not preds:
+ return default()
+ not_preds = []
+ for i, p in enumerate(preds):
+ with ops.name_scope("not_%d" % i):
+ not_preds.append(math_ops.logical_not(p))
+ and_not_preds = [constant_op.constant(True, name="and_not_true")]
+ for i, notp in enumerate(not_preds[:-1]):
+ with ops.name_scope("and_not_%d" % i):
+ and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp))
+
+ # preds = [p1, p2, p3]
+ # fns = [f1, f2, f3]
+ # not_preds = [~p1, ~p2, ~p3]
+ # case_preds = [p1 & True,
+ # p2 & ~p1,
+ # p3 & ~p1 & ~ p2]
+ case_preds = []
+ for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds)):
+ with ops.name_scope("case_%d" % i):
+ case_preds.append(math_ops.logical_and(p, and_not_p_prev))
+
+ # case_sequence = [Cond(p3 & ..., f3, default),
+ # Cond(p2 & ..., f2, lambda: case_sequence[0]),
+ # ...
+ # Cond(p1 & True, f1, lambda: case_sequence[i-1])]
+ # and prev_case_seq will loop from case_sequence[0] to case_sequence[-1]
+ if exclusive:
+ # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))
+ preds_c = array_ops.concat(0, preds, name="preds_c")
+ num_true_conditions = math_ops.reduce_sum(
+ math_ops.cast(preds_c, types.int32), name="num_true_conds")
+ at_most_one_true_condition = math_ops.less(
+ num_true_conditions, constant_op.constant(2, name="two_true_conds"))
+
+ error_msg = [
+ ("More than one condition evaluated as True but "
+ "exclusive=True. Conditions: (%s), Values:"
+ % ", ".join([p.name for p in preds])),
+ preds_c]
+ with ops.control_dependencies([
+ logging_ops.Assert(condition=at_most_one_true_condition,
+ data=error_msg, summarize=len(preds))]):
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+ else:
+ prev_case_seq = default()
+ for i, (cp, fn) in enumerate(zip(case_preds, fns)[::-1]):
+ prev_case_seq = cond(cp, fn, lambda: prev_case_seq, name="If_%d" % i)
+
+ return prev_case_seq
+
+
+ops.RegisterShape("Enter")(common_shapes.unchanged_shape)
+ops.RegisterShape("Exit")(common_shapes.unknown_shape)
+ops.RegisterShape("NextIteration")(common_shapes.unchanged_shape)
+ops.RegisterShape("RefEnter")(common_shapes.unchanged_shape)
+ops.RegisterShape("ControlTrigger")(common_shapes.no_outputs)
+ops.RegisterShape("NoOp")(common_shapes.no_outputs)
+
+
+@ops.RegisterShape("LoopCond")
+def _LoopCondShape(op):
+ """Shape function for the LoopCond op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("Merge")
+def _MergeShape(op):
+ """Shape function for the Merge op.
+
+ The Merge op takes many inputs of arbitrary shapes, and produces a
+ first output that is one of those inputs, and a second scalar
+ output.
+
+ This function conservatively assumes that if any of its inputs is
+ not fully defined, the output shape is unknown. If all of the inputs
+ have the exact same known shape, the output must have that shape.
+
+ Args:
+ op: A Merge Operation.
+
+ Returns:
+ A single-element list containing the Shape of the Merge op.
+
+ """
+ first_input_shape = op.inputs[0].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[1:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+ return [first_input_shape, tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape(), tensor_shape.scalar()]
+
+
+@ops.RegisterShape("RefSelect")
+def _RefSelectShape(op):
+ """Shape function for the RefSelect op.
+
+ The RefSelect takes one scalar input and N inputs of arbitrary
+ shapes, and produces one output, which is one of those N inputs.
+
+ This function conservatively assumes that if any of the N inputs is
+ not fully defined, the output shape is unknown. If all of the N
+ inputs have the exact same known shape, the output must have that
+ shape.
+
+ Args:
+ op: A RefSelect Operation.
+
+ Returns:
+ A single-element list containing the Shape of the RefSelect op.
+ """
+ unused_shape = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ first_input_shape = op.inputs[1].get_shape()
+ if first_input_shape.is_fully_defined():
+ for input_ in op.inputs[2:]:
+ input_shape = input_.get_shape()
+ if (not input_shape.is_fully_defined()
+ or not input_shape.is_compatible_with(first_input_shape)):
+ return [tensor_shape.unknown_shape()]
+ return [first_input_shape]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("RefSwitch")
+@ops.RegisterShape("Switch")
+def _SwitchShape(op):
+ input_shape = op.inputs[0].get_shape()
+ unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
+ return [input_shape] * 2
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
new file mode 100644
index 0000000000..34b1ab0a25
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -0,0 +1,88 @@
+"""Tests for control_flow_ops.py."""
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops as tf
+from tensorflow.python.platform import googletest
+
+
+class GroupTestCase(TensorFlowTestCase):
+
+ def _StripNode(self, nd):
+ snode = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
+ if nd.device:
+ snode.device = nd.device
+ return snode
+
+ def _StripGraph(self, gd):
+    """Copy gd, keeping only node.name, node.op, node.input, and node.device."""
+ return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
+
+ def testGroup_NoDevices(self):
+ with ops.Graph().as_default() as g:
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ c = tf.constant(0, name="c")
+ tf.group(a.op, b.op, c.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const"}
+ node { name: "b" op: "Const"}
+ node { name: "c" op: "Const"}
+ node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
+ """, self._StripGraph(gd))
+
+ def testGroup_OneDevice(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/task:0"):
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ tf.group(a.op, b.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const" device: "/task:0" }
+ node { name: "b" op: "Const" device: "/task:0" }
+ node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
+ """, self._StripGraph(gd))
+
+ def testGroup_MultiDevice(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/task:0"):
+ a = tf.constant(0, name="a")
+ b = tf.constant(0, name="b")
+ with g.device("/task:1"):
+ c = tf.constant(0, name="c")
+ d = tf.constant(0, name="d")
+ with g.device("/task:2"):
+ tf.group(a.op, b.op, c.op, d.op, name="root")
+ gd = g.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "a" op: "Const" device: "/task:0"}
+ node { name: "b" op: "Const" device: "/task:0"}
+ node { name: "c" op: "Const" device: "/task:1"}
+ node { name: "d" op: "Const" device: "/task:1"}
+ node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
+ device: "/task:0" }
+ node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
+ device: "/task:1" }
+ node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
+ device: "/task:2" }
+ """, self._StripGraph(gd))
+
+
+class ShapeTestCase(TensorFlowTestCase):
+
+ def testShape(self):
+ with ops.Graph().as_default():
+ tensor = tf.constant([1.0, 2.0])
+ self.assertEquals([2], tensor.get_shape())
+ self.assertEquals([2],
+ control_flow_ops.with_dependencies(
+ [tf.constant(1.0)], tensor).get_shape())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
new file mode 100644
index 0000000000..d2473490ce
--- /dev/null
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -0,0 +1,37 @@
+"""Gradients for operators defined in data_flow_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+@ops.RegisterGradient("DynamicStitch")
+def _DynamicStitchGrads(op, grad):
+ """Gradients for DynamicStitch."""
+
+ num_values = len(op.inputs) / 2
+ indices_grad = [None] * num_values
+
+ def AsInt32(x):
+ return (x if op.inputs[0].dtype == types.int32 else
+ math_ops.cast(x, types.int32))
+ inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]
+ if isinstance(grad, ops.IndexedSlices):
+ output_shape = array_ops.shape(op.outputs[0])
+ output_rows = output_shape[0]
+ grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)
+ values_grad = [array_ops.gather(grad, inp) for inp in inputs]
+ return indices_grad + values_grad
+
+
+ops.NoGradient("Queue")
+ops.NoGradient("QueueEnqueue")
+ops.NoGradient("QueueEnqueueMany")
+ops.NoGradient("QueueDequeue")
+ops.NoGradient("QueueDequeueMany")
+ops.NoGradient("QueueClose")
+ops.NoGradient("QueueSize")
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
new file mode 100644
index 0000000000..5c8ab66297
--- /dev/null
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -0,0 +1,680 @@
+"""Data Flow Operations."""
+# pylint: disable=g-bad-name
+import re
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_data_flow_ops import *
+
+
+def _as_type_list(dtypes):
+ """Convert dtypes to a list of types."""
+ assert dtypes is not None
+ if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
+ # We have a single type.
+ return [dtypes]
+ else:
+ # We have a list or tuple of types.
+ return list(dtypes)
+
+
+def _as_shape_list(shapes, dtypes):
+ """Convert shapes to a list of tuples of int (or None)."""
+ if shapes is None: return None
+ if isinstance(shapes, tensor_shape.TensorShape):
+ shapes = [shapes]
+ if not isinstance(shapes, (tuple, list)):
+ raise TypeError(
+ "shapes must be a TensorShape or a list or tuple of TensorShapes.")
+ if all(isinstance(shape, int) for shape in shapes):
+ # We have a single shape.
+ shapes = [shapes]
+ shapes = [tensor_shape.as_shape(shape) for shape in shapes]
+ if any(not shape.is_fully_defined() for shape in shapes):
+ raise ValueError("All shapes must be fully defined.")
+ return shapes
+
+
+# pylint: disable=protected-access
+class QueueBase(object):
+ """Base class for queue implementations.
+
+ A queue is a TensorFlow data structure that stores tensors across
+ multiple steps, and exposes operations that enqueue and dequeue
+ tensors.
+
+ Each queue element is a tuple of one or more tensors, where each
+ tuple component has a static dtype, and may have a static shape. The
+ queue implementations support versions of enqueue and dequeue that
+  handle single elements, and versions that support enqueuing and
+ dequeuing a batch of elements at once.
+
+ See [`tf.FIFOQueue`](#FIFOQueue) and
+ [`tf.RandomShuffleQueue`](#RandomShuffleQueue) for concrete
+ implementations of this class, and instructions on how to create
+ them.
+
+ @@enqueue
+ @@enqueue_many
+
+ @@dequeue
+ @@dequeue_many
+
+ @@size
+
+ @@close
+
+ """
+
+ def __init__(self, dtypes, shapes, queue_ref):
+ """Constructs a queue object from a queue reference.
+
+ Args:
+ dtypes: A list of types. The length of dtypes must equal the number
+ of tensors in each element.
+ shapes: Constraints on the shapes of tensors in an element:
+ A list of shape tuples or None. This list is the same length
+        as dtypes. If the shape of any tensor in the element is constrained,
+ all must be; shapes can be None if the shapes should not be constrained.
+ queue_ref: The queue reference, i.e. the output of the queue op.
+ """
+ self._dtypes = dtypes
+ if shapes is not None:
+ self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
+ else:
+ self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
+ self._queue_ref = queue_ref
+ self._name = self._queue_ref.op.name.split("/")[-1]
+
+ @staticmethod
+ def from_list(index, queues):
+ """Create a queue using the queue reference from `queues[index]`.
+
+ Args:
+ index: An integer scalar tensor that determines the input that gets
+ selected.
+ queues: A list of `QueueBase` objects.
+
+ Returns:
+ A `QueueBase` object.
+
+ Raises:
+ TypeError: when `queues` is not a list of `QueueBase` objects,
+ or when the data types of `queues` are not all the same.
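+
+    Example (a minimal illustrative sketch; `which`, `q0` and `q1` are
+    assumed to be defined elsewhere, with matching component dtypes):
+    ```python
+      q = QueueBase.from_list(which, [q0, q1])
+      value = q.dequeue()
+    ```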
+ """
+ if ((not queues) or
+ (not isinstance(queues, list)) or
+ (not all([isinstance(x, QueueBase) for x in queues]))):
+ raise TypeError("A list of queues expected")
+
+ dtypes = queues[0].dtypes
+ if not all([dtypes == q.dtypes for q in queues[1:]]):
+ raise TypeError("Queues do not have matching component dtypes.")
+
+ queue_refs = [x.queue_ref for x in queues]
+ selected_queue = control_flow_ops.ref_select(index, queue_refs)
+ # TODO(josh11b): Unify the shapes of the queues too?
+ return QueueBase(dtypes=dtypes, shapes=None, queue_ref=selected_queue)
+
+ @property
+ def queue_ref(self):
+ """The underlying queue reference."""
+ return self._queue_ref
+
+ @property
+ def name(self):
+ """The name of the underlying queue."""
+ return self._queue_ref.op.name
+
+ @property
+ def dtypes(self):
+ """The list of dtypes for each component of a queue element."""
+ return self._dtypes
+
+ def enqueue(self, vals, name=None):
+ """Enqueues one element to this queue.
+
+ If the queue is full when this operation executes, it will block
+ until the element has been enqueued.
+
+ Args:
+ vals: The tuple of `Tensor` objects to be enqueued.
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that enqueues a new tuple of tensors to the queue.
+ """
+ if name is None:
+ name = "%s_enqueue" % self._name
+ ret = gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the Queue object.
+ for val, shape in zip(ret.inputs[1:], self._shapes):
+ val.get_shape().assert_is_compatible_with(shape)
+
+ return ret
+
+ def enqueue_many(self, vals, name=None):
+    """Enqueues zero or more elements to this queue.
+
+ This operation slices each component tensor along the 0th dimension to
+ make multiple queue elements. All of the tensors in `vals` must have the
+ same size in the 0th dimension.
+
+ If the queue is full when this operation executes, it will block
+ until all of the elements have been enqueued.
+
+ Args:
+ vals: The tensor or tuple of tensors from which the queue elements
+ are taken.
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that enqueues a batch of tuples of tensors to the queue.
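+
+    Example (a minimal illustrative sketch; `q` is assumed to be a
+    single-component float queue created elsewhere):
+    ```python
+      enq = q.enqueue_many([[1.0, 2.0, 3.0]])
+      # Enqueues three elements; the 0th dimension of the component tensor
+      # is the batch dimension.
+    ```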
+ """
+ if name is None:
+ name = "%s_EnqueueMany" % self._name
+
+ ret = gen_data_flow_ops._queue_enqueue_many(
+ self._queue_ref, vals, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ batch_dim = ret.inputs[1].get_shape()[0]
+ for val, shape in zip(ret.inputs[1:], self._shapes):
+ batch_dim.merge_with(val.get_shape()[0])
+ val.get_shape()[1:].assert_is_compatible_with(shape)
+
+ return ret
+
+ def dequeue(self, name=None):
+ """Dequeues one element from this queue.
+
+ If the queue is empty when this operation executes, it will block
+ until there is an element to dequeue.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of tensors that was dequeued.
+ """
+ if name is None:
+ name = "%s_Dequeue" % self._name
+ ret = gen_data_flow_ops._queue_dequeue(
+ self._queue_ref, self._dtypes, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the `QueueBase` object.
+ op = ret[0].op
+ for output, shape in zip(op.values(), self._shapes):
+ output.set_shape(shape)
+
+ return ret if len(ret) != 1 else ret[0]
+
+ def dequeue_many(self, n, name=None):
+ """Dequeues and concatenates `n` elements from this queue.
+
+ This operation concatenates queue-element component tensors along
+ the 0th dimension to make a single component tensor. All of the
+ components in the dequeued tuple will have size `n` in the 0th dimension.
+
+ If the queue contains fewer than `n` elements when this operation
+ executes, it will block until `n` elements have been dequeued.
+
+ Args:
+ n: A scalar `Tensor` containing the number of elements to dequeue.
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of concatenated tensors that was dequeued.
+ """
+ if name is None:
+ name = "%s_DequeueMany" % self._name
+
+ ret = gen_data_flow_ops._queue_dequeue_many(
+ self._queue_ref, n, self._dtypes, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the Queue object.
+ op = ret[0].op
+ batch_dim = tensor_shape.Dimension(tensor_util.ConstantValue(op.inputs[1]))
+ for output, shape in zip(op.values(), self._shapes):
+ output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))
+
+ return ret if len(ret) != 1 else ret[0]
+
+ def close(self, cancel_pending_enqueues=False, name=None):
+ """Closes this queue.
+
+ This operation signals that no more elements will be enqueued in
+ the given queue. Subsequent `enqueue` and `enqueue_many`
+ operations will fail. Subsequent `dequeue` and `dequeue_many`
+ operations will continue to succeed if sufficient elements remain
+ in the queue. Subsequent `dequeue` and `dequeue_many` operations
+ that would block will fail immediately.
+
+ If `cancel_pending_enqueues` is `True`, all pending requests will also
+ be cancelled.
+
+ Args:
+ cancel_pending_enqueues: (Optional.) A boolean, defaulting to
+ `False` (described above).
+ name: A name for the operation (optional).
+
+ Returns:
+ The operation that closes the queue.
+ """
+ if name is None:
+ name = "%s_Close" % self._name
+ return gen_data_flow_ops._queue_close(
+ self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ name=name)
+
+ def size(self, name=None):
+ """Compute the number of elements in this queue.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this queue.
+ """
+ if name is None:
+ name = "%s_Size" % self._name
+ return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
+
+
+class RandomShuffleQueue(QueueBase):
+ """A queue implementation that dequeues elements in a random order.
+
+ See [`tf.QueueBase`](#QueueBase) for a description of the methods on
+ this class.
+
+ @@__init__
+ """
+
+ def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
+ seed=None, shared_name=None, name="random_shuffle_queue"):
+ """Create a queue that dequeues elements in a random order.
+
+ A `RandomShuffleQueue` has bounded capacity; supports multiple
+ concurrent producers and consumers; and provides exactly-once
+ delivery.
+
+ A `RandomShuffleQueue` holds a list of up to `capacity`
+ elements. Each element is a fixed-length tuple of tensors whose
+ dtypes are described by `dtypes`, and whose shapes are optionally
+ described by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a queue
+ element must have the respective fixed shape. If it is
+ unspecified, different queue elements may have different shapes,
+ but the use of `dequeue_many` is disallowed.
+
+ The `min_after_dequeue` argument allows the caller to specify a
+ minimum number of elements that will remain in the queue after a
+ `dequeue` or `dequeue_many` operation completes, to ensure a
+ minimum level of mixing of elements. This invariant is maintained
+ by blocking those operations until sufficient elements have been
+ enqueued. The `min_after_dequeue` argument is ignored after the
+ queue has been closed.
+
+ Args:
+ capacity: An integer. The upper bound on the number of elements
+ that may be stored in this queue.
+ min_after_dequeue: An integer (described above).
+ dtypes: A list of `DType` objects. The length of `dtypes` must equal
+ the number of tensors in each queue element.
+ shapes: (Optional.) A list of fully-defined `TensorShape` objects,
+ with the same length as `dtypes` or `None`.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ shared_name: (Optional.) If non-empty, this queue will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the queue operation.
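+
+    Example (a minimal illustrative sketch):
+    ```python
+      q = RandomShuffleQueue(capacity=100, min_after_dequeue=10,
+                             dtypes=[types.float32])
+      # At least 10 elements remain after each dequeue, which keeps the
+      # dequeued values well mixed.
+    ```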
+ """
+ dtypes = _as_type_list(dtypes)
+ shapes = _as_shape_list(shapes, dtypes)
+ seed1, seed2 = random_seed.get_seed(seed)
+ queue_ref = gen_data_flow_ops._random_shuffle_queue(
+ component_types=dtypes, shapes=shapes, capacity=capacity,
+ min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
+ shared_name=shared_name, name=name)
+
+ super(RandomShuffleQueue, self).__init__(dtypes, shapes, queue_ref)
+
+
+class FIFOQueue(QueueBase):
+ """A queue implementation that dequeues elements in first-in-first out order.
+
+ See [`tf.QueueBase`](#QueueBase) for a description of the methods on
+ this class.
+
+ @@__init__
+ """
+
+ def __init__(self, capacity, dtypes, shapes=None, shared_name=None,
+ name="fifo_queue"):
+ """Creates a queue that dequeues elements in a first-in first-out order.
+
+ A `FIFOQueue` has bounded capacity; supports multiple concurrent
+ producers and consumers; and provides exactly-once delivery.
+
+ A `FIFOQueue` holds a list of up to `capacity` elements. Each
+ element is a fixed-length tuple of tensors whose dtypes are
+ described by `dtypes`, and whose shapes are optionally described
+ by the `shapes` argument.
+
+ If the `shapes` argument is specified, each component of a queue
+ element must have the respective fixed shape. If it is
+ unspecified, different queue elements may have different shapes,
+ but the use of `dequeue_many` is disallowed.
+
+ Args:
+ capacity: An integer. The upper bound on the number of elements
+ that may be stored in this queue.
+ dtypes: A list of `DType` objects. The length of `dtypes` must equal
+ the number of tensors in each queue element.
+ shapes: (Optional.) A list of fully-defined `TensorShape` objects,
+ with the same length as `dtypes` or `None`.
+ shared_name: (Optional.) If non-empty, this queue will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the queue operation.
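+
+    Example (a minimal illustrative sketch; `x` is assumed to be a scalar
+    float tensor defined elsewhere):
+    ```python
+      q = FIFOQueue(capacity=10, dtypes=[types.float32])
+      enqueue_op = q.enqueue([x])
+      value = q.dequeue()
+    ```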
+ """
+ dtypes = _as_type_list(dtypes)
+ shapes = _as_shape_list(shapes, dtypes)
+ queue_ref = gen_data_flow_ops._fifo_queue(
+ component_types=dtypes, shapes=shapes, capacity=capacity,
+ shared_name=shared_name, name=name)
+
+ super(FIFOQueue, self).__init__(dtypes, shapes, queue_ref)
+
+
+# TODO(josh11b): class BatchQueue(QueueBase):
+
+
+# pylint: disable=protected-access
+class LookupTableBase(object):
+ """Represents a lookup table that persists across different steps."""
+
+ def __init__(self, key_dtype, value_dtype, default_value, table_ref):
+ """Construct a table object from a table reference.
+
+ Args:
+ key_dtype: The key data type of the table.
+      value_dtype: The value data type of the table.
+ default_value: The scalar tensor to be used when a key is not present in
+ the table.
+ table_ref: The table reference, i.e. the output of the lookup table ops.
+ """
+ self._key_dtype = types.as_dtype(key_dtype)
+ self._value_dtype = types.as_dtype(value_dtype)
+ self._shapes = [tensor_shape.TensorShape([1])]
+ self._table_ref = table_ref
+ self._name = self._table_ref.op.name.split("/")[-1]
+ self._default_value = ops.convert_to_tensor(default_value,
+ dtype=self._value_dtype)
+ self._default_value.get_shape().merge_with(tensor_shape.scalar())
+
+ @property
+ def table_ref(self):
+ """Get the underlying table reference."""
+ return self._table_ref
+
+ @property
+ def key_dtype(self):
+ """The key dtype supported by the table."""
+ return self._key_dtype
+
+ @property
+ def value_dtype(self):
+ """The value dtype supported by the table."""
+ return self._value_dtype
+
+ @property
+ def name(self):
+ """The name of the table."""
+ return self._name
+
+ @property
+ def default_value(self):
+ """The default value of the table."""
+ return self._default_value
+
+ def size(self, name=None):
+ """Compute the number of elements in this table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar tensor containing the number of elements in this table.
+ """
+ if name is None:
+ name = "%s_Size" % self._name
+ return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
+
+ def lookup(self, keys, name=None):
+ """Returns the values for the given 'keys' tensor.
+
+    If an element of the key tensor is not found in the table, the default_value
+ is used.
+
+ Args:
+ keys: The tensor for the keys.
+ name: Optional name for the op.
+
+ Returns:
+      A tensor of the values for the given `keys`.
+
+ Raises:
+ TypeError: when 'keys' or 'default_value' doesn't match the table data
+ types.
+ """
+ if name is None:
+ name = "%s_lookup_table_find" % self._name
+
+ if keys.dtype != self._key_dtype:
+ raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
+ self._key_dtype, keys.dtype))
+
+ return gen_data_flow_ops._lookup_table_find(
+ self._table_ref, keys, self._default_value, name=name)
+
+ def initialize_from(self, keys, values, name=None):
+ """Initialize the lookup table with the provided keys and values tensors.
+
+ Construct an initializer object from keys and value tensors.
+
+ Args:
+ keys: The tensor for the keys.
+ values: The tensor for the values.
+ name: Optional name for the op.
+
+ Returns:
+ The operation that initializes a lookup table.
+
+ Raises:
+ TypeError: when the 'keys' and 'values' data type do not match the table
+ key and value data types.
+ """
+ if name is None:
+ name = "%s_initialize_table" % self.name
+ with ops.op_scope([keys, values], None, name):
+ keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
+ values = ops.convert_to_tensor(values, dtype=self.value_dtype,
+ name="values")
+
+ init_op = gen_data_flow_ops._initialize_table(
+ self.table_ref, keys, values, name=name)
+ ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
+ return init_op
+
+ def _check_table_dtypes(self, key_dtype, value_dtype):
+    """Check that the given key_dtype and value_dtype match the table dtypes.
+
+ Args:
+ key_dtype: The key data type to check.
+ value_dtype: The value data type to check.
+
+ Raises:
+ TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
+ types.
+ """
+ if key_dtype != self.key_dtype:
+ raise TypeError("Invalid key dtype, expected %s but got %s." % (
+ self.key_dtype, key_dtype))
+ if value_dtype != self.value_dtype:
+ raise TypeError("Invalid value dtype, expected %s but got %s." % (
+ self.value_dtype, value_dtype))
+
+
+class HashTable(LookupTableBase):
+ """A generic hash table implementation."""
+
+ def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
+ name="hash_table"):
+ """Create a generic hash table.
+
+    A table holds key-value pairs. The key and value types are
+ described by key_dtype and value_dtype respectively.
+
+ Args:
+ key_dtype: The key data type of the table.
+      value_dtype: The value data type of the table.
+ default_value: The scalar tensor to be used when a key is not present in
+ the table.
+ shared_name: Optional. If non-empty, this table will be shared under
+ the given name across multiple sessions.
+ name: Optional name for the hash table op.
+
+ Returns:
+ A table object that can be used to lookup data.
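+
+    Example (a minimal illustrative sketch; `keys`, `values` and `query` are
+    assumed to be tensors defined elsewhere):
+    ```python
+      table = HashTable(types.string, types.int64, default_value=-1)
+      init = table.initialize_from(keys, values)
+      found = table.lookup(query)
+    ```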
+ """
+ table_ref = gen_data_flow_ops._hash_table(
+ shared_name=shared_name, key_dtype=key_dtype,
+ value_dtype=value_dtype, name=name)
+
+ super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
+ table_ref)
+
+
+def initialize_all_tables(name="init_all_tables"):
+ """Returns an Op that initializes all tables of the default graph.
+
+ Returns:
+ An Op that initializes all tables. Note that if there are
+    no tables the returned Op is a NoOp.
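+
+  Example (a minimal illustrative sketch; `sess` is assumed to be a running
+  `Session`):
+  ```python
+    sess.run(initialize_all_tables())
+  ```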
+ """
+ initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
+ if initializers:
+ return control_flow_ops.group(*initializers, name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+ops.NoGradient("LookupTableFind")
+ops.NoGradient("LookupTableSize")
+ops.NoGradient("HashTable")
+ops.NoGradient("InitializeTable")
+
+
+ops.RegisterShape("QueueSize")(common_shapes.scalar_shape)
+ops.RegisterShape("Queue")(common_shapes.scalar_shape)
+ops.RegisterShape("FIFOQueue")(common_shapes.scalar_shape)
+ops.RegisterShape("RandomShuffleQueue")(common_shapes.scalar_shape)
+
+
+# NOTE(mrry): The following ops use higher-level information in the
+# Queue class to provide shape information.
+ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)
+
+
+@ops.RegisterShape("QueueClose")
+def _ScalarToVoidShape(op):
+ """Shape function for ops that take a scalar and produce no outputs."""
+ unused_input_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("DynamicPartition")
+def _DynamicPartitionShape(op):
+ """Shape function for data_flow_ops.dynamic_partition."""
+ data_shape = op.inputs[0].get_shape()
+ partitions_shape = op.inputs[1].get_shape()
+ # If we don't know the rank of partitions, we don't know anything
+ mid = partitions_shape.ndims
+ if mid is None:
+ result_shape = tensor_shape.unknown_shape()
+ else:
+ # data_shape must start with partitions_shape
+ partitions_shape.assert_is_compatible_with(data_shape[:mid])
+ # The partition shape is dynamic in the 0th dimension, and matches
+ # data_shape in the remaining dimensions.
+ result_shape = tensor_shape.TensorShape([None]).concatenate(
+ data_shape[mid:])
+ return [result_shape] * op.get_attr("num_partitions")
+
+
+@ops.RegisterShape("DynamicStitch")
+def _DynamicStitchShape(op):
+ """Shape function for data_flow_ops.dynamic_stitch."""
+ num_partitions = op.get_attr("N")
+ indices_shapes = [t.get_shape() for t in op.inputs[0:num_partitions]]
+ data_shapes = [t.get_shape() for t in op.inputs[num_partitions:]]
+ output_shape = tensor_shape.unknown_shape()
+ extra_shape = tensor_shape.TensorShape(None)
+ for indices_shape, data_shape in zip(indices_shapes, data_shapes):
+ indices_ndims = indices_shape.ndims
+ if indices_ndims is not None:
+ # Assert that data_shape starts with indices_shape
+ indices_shape.merge_with(data_shape[:indices_ndims])
+ # The rest belongs to output
+ extra_shape = extra_shape.merge_with(data_shape[indices_ndims:])
+ return [tensor_shape.TensorShape([None]).concatenate(extra_shape)]
+
+
+@ops.RegisterShape("LookupTableFind")
+def _LookupTableFindShape(op):
+ """Shape function for data_flow_ops._lookup_table_find."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ shape_in = op.inputs[1].get_shape()
+ return [shape_in]
+
+
+@ops.RegisterShape("LookupTableSize")
+def _LookupTableSizeShape(op):
+ """Shape function for data_flow_ops._lookup_table_find."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("HashTable")
+def _HashTableShape(unused_op):
+ """Shape function for data_flow_ops._hash_table."""
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("InitializeTable")
+def _InitializeLookupTableShape(op):
+ """Shape function for data_flow_ops._initialize_table."""
+ unused_table_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ keys_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_values_shape = op.inputs[2].get_shape().merge_with(keys_shape)
+ return []
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
new file mode 100644
index 0000000000..bc64593d23
--- /dev/null
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -0,0 +1,197 @@
+"""Operations for embeddings."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def embedding_lookup(params, ids, name=None):
+ """Return a tensor of embedding values by looking up "ids" in "params".
+
+ Args:
+ params: List of tensors of the same shape. A single tensor is
+ treated as a singleton list.
+ ids: Tensor of integers containing the ids to be looked up in
+ 'params'. Let P be len(params). If P > 1, then the ids are
+ partitioned by id % P, and we do separate lookups in params[p]
+ for 0 <= p < P, and then stitch the results back together into
+ a single result tensor.
+ name: Optional name for the op.
+
+ Returns:
+ A tensor of shape ids.shape + params[0].shape[1:] containing the
+ values params[i % P][i // P] for each i in ids.
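+
+ For example (a minimal sketch; the shard values are made up and the
+ public `tf.constant` / `tf.nn.embedding_lookup` names are assumed):
+
+ shard0 = tf.constant([[1., 1.], [3., 3.]])  # rows for even ids 0 and 2
+ shard1 = tf.constant([[2., 2.], [4., 4.]])  # rows for odd ids 1 and 3
+ emb = tf.nn.embedding_lookup([shard0, shard1], tf.constant([0, 1, 3]))
+ # emb evaluates to [[1., 1.], [2., 2.], [4., 4.]], i.e. shape [3, 2].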
+
+ Raises:
+ ValueError: if some parameters are invalid.
+ """
+ if not isinstance(params, list):
+ params = [params]
+ with ops.op_scope(params + [ids], name, "embedding_lookup") as name:
+ if not params:
+ raise ValueError("Need at least one param")
+ np = len(params) # Number of partitions
+ params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
+ if np == 1:
+ with ops.device(params[0].device):
+ return array_ops.gather(params[0], ids, name=name)
+ else:
+ ids = ops.convert_to_tensor(ids, name="ids")
+ flat_ids = array_ops.reshape(ids, [-1])
+ original_indices = math_ops.range(0, array_ops.size(flat_ids))
+ # Compute flat_ids % partitions for each id
+ ids_mod_p = flat_ids % np
+ if ids_mod_p.dtype != types.int32:
+ ids_mod_p = math_ops.cast(ids_mod_p, types.int32)
+ # Partition single list of ids based on ids % np into np separate lists
+ plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np)
+ # Similarly, partition the original indices.
+ pindices = data_flow_ops.dynamic_partition(original_indices, ids_mod_p,
+ np)
+ # Do np separate lookups, finding embeddings for plist[p] in params[p]
+ partitioned_result = []
+ for p in range(np):
+ # TODO(agarwal): handle device allocations here and later in the
+ # colocate code.
+ gather_ids = plist[p] / np
+ with ops.device(params[p].device):
+ partitioned_result.append(array_ops.gather(params[p], gather_ids))
+ # Stitch these back together
+ ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
+ name=name)
+ # Reshape to reverse the flattening of ids.
+ # It's important that we compute params[0].shape on the right device
+ # to avoid data motion.
+ with ops.device(params[0].device):
+ params_shape = array_ops.shape(params[0])
+ ret = array_ops.reshape(ret, array_ops.concat(0, [
+ array_ops.shape(ids), array_ops.slice(params_shape, [1], [-1])]))
+ # output shape = ids.shape + params[*].shape[1:]
+ # Normally the reshape is sufficient, but setting shape explicitly
+ # teaches shape inference that params[1:].get_shape() matters.
+ element_shape = params[0].get_shape()[1:]
+ for p in params[1:]:
+ element_shape = element_shape.merge_with(p.get_shape()[1:])
+ ret.set_shape(ids.get_shape().concatenate(element_shape))
+ return ret
+
+
+# TODO(lif): Add support for higher-rank SparseTensors
+def embedding_lookup_sparse(params, sp_ids, sp_weights,
+ name=None,
+ combiner="mean"):
+ """Computes embeddings for the given ids and weights.
+
+ This op assumes that there is at least one id for each row in the dense tensor
+ represented by sp_ids (i.e. there are no rows with empty features), and that
+ all the indices of sp_ids are in canonical row-major order.
+
+ It also assumes that all id values lie in the range [0, p0), where p0
+ is the sum of the size of params along dimension 0.
+
+ Args:
+ params: A single tensor representing the complete embedding tensor,
+ or a list of P tensors all of same shape except for the first dimension,
+ representing sharded embedding tensors. In the latter case, the ids are
+ partitioned by id % P, and we do separate lookups in params[p] for
+ 0 <= p < P, and then stitch the results back together into a single
+ result tensor. The first dimension is allowed to vary as the vocab
+ size is not necessarily a multiple of P.
+ sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+ where N is typically batch size and M is arbitrary.
+ sp_weights: either a SparseTensor of float / double weights, or None to
+ indicate all weights should be taken to be 1. If specified, sp_weights
+ must have exactly the same shape and indices as sp_ids.
+ name: Optional name for the op.
+ combiner: A string specifying the reduction op. Currently "mean" and "sum"
+ are supported.
+ "sum" computes the weighted sum of the embedding results for each row.
+ "mean" is the weighted sum divided by the total weight.
+
+ Returns:
+ A dense tensor representing the combined embeddings for the
+ sparse ids. For each row in the dense tensor represented by sp_ids, the op
+ looks up the embeddings for all ids in that row, multiplies them by the
+ corresponding weight, and combines these embeddings as specified.
+
+ In other words, if
+ shape(combined params) = [p0, p1, ..., pm]
+ and
+ shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
+ then
+ shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
+
+ For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
+
+ [0, 0]: id 1, weight 2.0
+ [0, 1]: id 3, weight 0.5
+ [1, 0]: id 0, weight 1.0
+ [2, 3]: id 1, weight 3.0
+
+ with combiner="mean", then the output will be a 3x20 matrix where
+ output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
+ output[1, :] = params[0, :] * 1.0
+ output[2, :] = params[1, :] * 3.0
+
+ Raises:
+ TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
+ None nor SparseTensor.
+ ValueError: If combiner is not one of {"mean", "sum"}.
+ """
+ if combiner not in ("mean", "sum"):
+ raise ValueError("combiner must be one of 'mean' or 'sum'")
+ if not isinstance(params, list):
+ params = [params]
+ if not isinstance(sp_ids, ops.SparseTensor):
+ raise TypeError("sp_ids must be SparseTensor")
+ ignore_weights = sp_weights is None
+ if not ignore_weights and not isinstance(sp_weights, ops.SparseTensor):
+ raise TypeError("sp_weights must be either None or SparseTensor")
+
+ with ops.op_scope(params + [sp_ids], name, "embedding_lookup_sparse") as name:
+ segment_ids = sp_ids.indices[:, 0]
+ if segment_ids.dtype != types.int32:
+ segment_ids = math_ops.cast(segment_ids, types.int32)
+
+ ids = sp_ids.values
+ if ignore_weights:
+ ids, idx = array_ops.unique(ids)
+ else:
+ idx = None
+
+ embeddings = embedding_lookup(params, ids)
+ if not ignore_weights:
+ weights = sp_weights.values
+ if weights.dtype != embeddings.dtype:
+ weights = math_ops.cast(weights, embeddings.dtype)
+
+ # Reshape weights to allow broadcast
+ ones = array_ops.fill(
+ array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
+ bcast_weights_shape = array_ops.concat(0, [
+ array_ops.shape(weights), ones])
+ weights = array_ops.reshape(weights, bcast_weights_shape)
+ embeddings *= weights
+
+ if combiner == "sum":
+ embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
+ elif combiner == "mean":
+ embeddings = math_ops.segment_sum(embeddings, segment_ids)
+ weight_sum = math_ops.segment_sum(weights, segment_ids)
+ embeddings = math_ops.div(embeddings, weight_sum, name=name)
+ else:
+ assert False, "Unrecognized combiner"
+ else:
+ assert idx is not None
+ if combiner == "sum":
+ embeddings = math_ops.sparse_segment_sum(embeddings, idx, segment_ids,
+ name=name)
+ elif combiner == "mean":
+ embeddings = math_ops.sparse_segment_mean(embeddings, idx, segment_ids,
+ name=name)
+ else:
+ assert False, "Unrecognized combiner"
+
+ return embeddings
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
new file mode 100644
index 0000000000..ffa7828c04
--- /dev/null
+++ b/tensorflow/python/ops/gradients.py
@@ -0,0 +1,661 @@
+"""Implements the graph generation for computation of gradients."""
+
+import collections
+import warnings
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+# pylint: disable=unused-import
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_grad
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import linalg_grad
+from tensorflow.python.ops import math_grad
+# pylint: enable=unused-import
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.platform import logging
+
+
+# Warn the user if we convert a sparse representation to dense with at
+# least this number of elements.
+_LARGE_SPARSE_NUM_ELEMENTS = 100000000
+
+
+def _IndexedSlicesToTensor(value, dtype=None, name=None):
+ """Converts an IndexedSlices object `value` to a Tensor.
+
+ NOTE(mrry): This function is potentially expensive.
+
+ Args:
+ value: An ops.IndexedSlices object.
+ dtype: The dtype of the Tensor to be returned.
+ name: Optional name to use for the returned Tensor.
+
+ Returns:
+ A dense Tensor representing the values in the given IndexedSlices.
+
+ Raises:
+ ValueError: If the IndexedSlices does not have the same dtype.
+ """
+ if dtype and not dtype.is_compatible_with(value.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for IndexedSlices with dtype %s"
+ % (dtype.name, value.dtype.name))
+ if value.dense_shape is None:
+ raise ValueError(
+ "Tensor conversion requested for IndexedSlices without dense_shape: %s"
+ % str(value))
+ # TODO(mrry): Consider adding static shape information to
+ # IndexedSlices, to avoid using numpy here.
+ dense_shape_value = tensor_util.ConstantValue(value.dense_shape)
+ if dense_shape_value is not None:
+ num_elements = np.prod(dense_shape_value)
+ if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor with %d elements. "
+ "This may consume a large amount of memory." % num_elements)
+ else:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
+ "This may consume a large amount of memory.")
+ return math_ops.unsorted_segment_sum(
+ value.values, value.indices, value.dense_shape[0], name=name)
+
+
+ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor)
+
+
+def _MarkReachedOps(from_ops, reached_ops):
+ """Mark all ops reached from "from_ops".
+
+ Args:
+ from_ops: list of Operations.
+ reached_ops: list of booleans, indexed by operation id.
+ """
+ queue = collections.deque()
+ queue.extend(from_ops)
+ while queue:
+ op = queue.popleft()
+ if not reached_ops[op._id]:
+ reached_ops[op._id] = True
+ for output in op.outputs:
+ queue.extend(output.consumers())
+
+
+def _GatherInputs(to_ops, reached_ops):
+ """List all inputs of to_ops that are in reached_ops.
+
+ Args:
+ to_ops: list of Operations.
+ reached_ops: list of booleans, indexed by operation id.
+
+ Returns:
+ The list of all inputs of to_ops that are in reached_ops.
+ That list includes all elements of to_ops.
+ """
+ inputs = []
+ queue = collections.deque()
+ queue.extend(to_ops)
+ while queue:
+ op = queue.popleft()
+ # We are interested in this op.
+ if reached_ops[op._id]:
+ inputs.append(op)
+ # Clear the boolean so we won't add the inputs again.
+ reached_ops[op._id] = False
+ for inp in op.inputs:
+ queue.append(inp.op)
+ return inputs
+
+
+def _GetGradsDevice(op, colocate_gradients_with_ops):
+ """Gets the device to which to assign gradients of "op".
+
+ Args:
+ op: an Operation.
+ colocate_gradients_with_ops: If True, try colocating gradients with the
+ corresponding op.
+
+ Returns:
+ A device string.
+ """
+ if colocate_gradients_with_ops and op.device:
+ return op.device
+ else:
+ return op.graph.get_default_device()
+
+
+def _PendingCount(graph, to_ops, from_ops):
+ """Initialize the pending count for ops between two lists of Operations.
+
+ 'pending_count[op._id]' indicates the number of backprop inputs
+ to this operation.
+
+ Args:
+ graph: a Graph.
+ to_ops: list of Operations.
+ from_ops: list of Operations.
+
+ Returns:
+ A tuple containing: (1) a list of integers indexed by operation id,
+ indicating the number of backprop inputs to this operation, and (2)
+ a boolean which is True if any of the ops in between from_ops and to_ops
+ contain control flow loops.
+ """
+ # Mark reachable ops from from_ops.
+ reached_ops = [False] * (graph._last_id + 1)
+ for op in to_ops:
+ reached_ops[op._id] = True
+ _MarkReachedOps(from_ops, reached_ops)
+
+ # Mark between ops.
+ between_ops = [False] * (graph._last_id + 1)
+ between_op_list = []
+ queue = collections.deque()
+ queue.extend(to_ops)
+ while queue:
+ op = queue.popleft()
+ # We are interested in this op.
+ if reached_ops[op._id]:
+ between_ops[op._id] = True
+ between_op_list.append(op)
+ # Clear the boolean so we won't add the inputs again.
+ reached_ops[op._id] = False
+ for inp in op.inputs:
+ queue.append(inp.op)
+
+ # Initialize pending count for between ops.
+ pending_count = [0] * (graph._last_id + 1)
+ has_control_flow = False
+ for op in between_op_list:
+ for x in op.inputs:
+ if between_ops[x.op._id]:
+ pending_count[x.op._id] += 1
+ for x in op.control_inputs:
+ if between_ops[x._id]:
+ pending_count[x._id] += 1
+ if op.type == "Exit":
+ has_control_flow = True
+ return pending_count, has_control_flow
+
+
+def _AsList(x):
+ return x if isinstance(x, (list, tuple)) else [x]
+
+
+def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
+ """Fill in default values for grad_ys.
+
+ Args:
+ grad_ys: List of gradients, can contain None.
+ ys: List of tensors.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+
+ Returns:
+ A list of gradients to use, without None.
+
+ Raises:
+ ValueError: If one of the grad_ys is invalid.
+ """
+ if len(grad_ys) != len(ys):
+ raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
+ grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
+ for i in xrange(len(grad_ys)):
+ grad_y = grad_ys[i]
+ y = ys[i]
+ if grad_y is None:
+ with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)):
+ grad_ys[i] = array_ops.fill(array_ops.shape(y),
+ constant_op.constant(1, dtype=y.dtype))
+ else:
+ if grad_y.dtype != y.dtype:
+ raise ValueError("Y and ys_grad must be of the same type, "
+ "not y: %s, ys_grad: %s " %
+ (types.as_dtype(y.dtype).name,
+ types.as_dtype(grad_y.dtype).name))
+ return grad_ys
+
+
+def _VerifyGeneratedGradients(grads, op):
+ """Verify that gradients are valid in number and type.
+
+ Args:
+ grads: List of generated gradients.
+ op: Operation for which the gradients where generated.
+
+ Raises:
+ ValueError: if the gradients are invalid.
+ """
+ if len(grads) != len(op.inputs):
+ raise ValueError("Num gradients %d generated for op %s do not match num "
+ "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
+ for i in xrange(len(grads)):
+ grad = grads[i]
+ inp = op.inputs[i]
+ if grad is not None:
+ if not grad.dtype.is_compatible_with(inp.dtype):
+ raise ValueError(
+ "Gradient type %s generated for op %s does "
+ "not match input type %s" %
+ (types.as_dtype(grad.dtype).name, op.node_def,
+ types.as_dtype(inp.dtype).name))
+
+
+def _StopOps(from_ops, pending_count):
+ """The set of ops that terminate the gradient computation.
+
+ This computes the frontier of the forward graph *before* which backprop
+ should stop. Operations in the returned set will not be differentiated.
+ This set is defined as the subset of `from_ops` containing ops that have
+ no predecessor in `from_ops`. `pending_count` is the result of
+ `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
+ iff pending_count[op._id] > 0.
+
+ Args:
+ from_ops: list of Operations.
+ pending_count: List of integers, indexed by operation id.
+
+ Returns:
+ The set of operations.
+ """
+ stop_ops = set()
+ for op in from_ops:
+ is_stop_op = True
+ for inp in op.inputs:
+ if pending_count[inp.op._id] > 0:
+ is_stop_op = False
+ break
+ if is_stop_op:
+ stop_ops.add(op._id)
+ return stop_ops
+
+
+def gradients(ys, xs, grad_ys=None, name="gradients",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None):
+ """Constructs symbolic partial derivatives of `ys` w.r.t. x in `xs`.
+
+ `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
+ is a list of `Tensor`, holding the gradients received by the
+ `ys`. The list must be the same length as `ys`.
+
+ `gradients()` adds ops to the graph to output the partial
+ derivatives of `ys` with respect to `xs`. It returns a list of
+ `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
+ for y in `ys`.
+
+ `grad_ys` is a list of tensors of the same length as `ys` that holds
+ the initial gradients for each y in `ys`. When `grad_ys` is None,
+ we fill in a tensor of '1's of the shape of y for each y in `ys`. A
+ user can provide their own initial `grad_ys` to compute the
+ derivatives using a different initial gradient for each y (e.g., if
+ one wanted to weight the gradient differently for each value in
+ each y).
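+
+ For example (a minimal sketch; the public `tf.constant` and `tf.gradients`
+ names are assumed):
+
+ x = tf.constant(3.0)
+ y = x * x    # dy/dx = 2x = 6
+ z = 2.0 * y  # dz/dx = 4x = 12
+ dx = tf.gradients([y, z], [x])[0]  # sums both contributions: 6 + 12 = 18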
+
+ Args:
+ ys: A `Tensor` or list of tensors to be differentiated.
+ xs: A `Tensor` or list of tensors to be used for differentiation.
+ grad_ys: Optional. A `Tensor` or list of tensors the same size as
+ `ys` and holding the gradients computed for each y in `ys`.
+ name: Optional name to use for grouping all the gradient ops together.
+ Defaults to 'gradients'.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ gate_gradients: If True, add a tuple around the gradients returned
+ for an operation. This avoids some race conditions.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+
+ Returns:
+ A list of `sum(dy/dx)` for each x in `xs`.
+
+ Raises:
+ LookupError: if one of the operations between `x` and `y` does not
+ have a registered gradient function.
+ ValueError: if the arguments are invalid.
+
+ """
+ ys = _AsList(ys)
+ xs = _AsList(xs)
+ if grad_ys is None:
+ grad_ys = [None] * len(ys)
+ else:
+ grad_ys = _AsList(grad_ys)
+ with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
+ ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
+ xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
+
+ # The approach we take here is as follows: Create a list of all ops in the
+ # subgraph between the ys and xs. Visit these ops in reverse order of ids
+ # to ensure that when we visit an op the gradients w.r.t its outputs have
+ # been collected. Then aggregate these gradients if needed, call the op's
+ # gradient function, and add the generated gradients to the gradients for
+ # its input.
+
+ # Initialize the pending count for ops in the connected subgraph from ys
+ # to the xs.
+ to_ops = [t.op for t in ys]
+ from_ops = [t.op for t in xs]
+ pending_count, has_control_flow = _PendingCount(
+ ops.get_default_graph(), to_ops, from_ops)
+
+ # Iterate over the collected ops.
+ #
+ # grads: op => list of gradients received on each output endpoint of the
+ # op. The gradients for each endpoint are initially collected as a list.
+ # When it is time to call the op's gradient function, for each endpoint we
+ # aggregate the list of received gradients into a Add() Operation if there
+ # is more than one.
+ grads = {}
+
+ # Add the initial gradients for the ys.
+ for y, grad_y in zip(ys, grad_ys):
+ _SetGrad(grads, y, grad_y)
+
+ # Initialize queue with to_ops.
+ queue = collections.deque()
+ # Add the ops in 'to_ops' into the queue.
+ to_ops_set = set()
+ for op in to_ops:
+ if op._id not in to_ops_set:
+ to_ops_set.add(op._id)
+ queue.append(op)
+ # The set of 'from_ops'.
+ stop_ops = _StopOps(from_ops, pending_count)
+ while queue:
+ # generate gradient subgraph for op.
+ op = queue.popleft()
+ with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
+ if has_control_flow:
+ control_flow_ops.EnterGradWhileContext(op)
+ out_grads = _AggregatedGrads(grads, op, has_control_flow,
+ aggregation_method)
+ grad_fn = None
+ if any(out_grads) and op._id not in stop_ops:
+ # A grad_fn must be defined, either as a function or as None
+ # for ops that do not have gradients.
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ raise LookupError(
+ "No gradient defined for operation '%s' (op type: %s)" %
+ (op.name, op.type))
+ if grad_fn and any(out_grads):
+ # NOTE: If _AggregatedGrads didn't compute a value for the i'th
+ # output, it means that the cost does not depend on output[i],
+ # therefore dC/doutput[i] is 0.
+ for i, out_grad in enumerate(out_grads):
+ if (not out_grad
+ and types.as_dtype(op.outputs[i].dtype).base_dtype in (
+ types.float32, types.float64)):
+ # Only floating-point outputs get a zero gradient. Gradient
+ # functions should ignore the gradient for other outputs.
+ out_grads[i] = array_ops.zeros_like(op.outputs[i])
+ with ops.name_scope(op.name + "_grad"):
+ # pylint: disable=protected-access
+ with ops.get_default_graph()._original_op(op):
+ # pylint: enable=protected-access
+ op_wrapper = op
+ if has_control_flow:
+ op_wrapper = control_flow_ops.MakeWrapper(op)
+ in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
+ _VerifyGeneratedGradients(in_grads, op)
+ if gate_gradients and len(in_grads) > 1:
+ in_grads = control_flow_ops.tuple(in_grads)
+ logging.vlog(1, "Gradient for '" + op.name + "'")
+ logging.vlog(1, " in --> %s",
+ ", ".join([x.name for x in out_grads if x]))
+ logging.vlog(1, " out --> %s",
+ ", ".join([x.name for x in in_grads if x]))
+ else:
+ # If no grad_fn is defined or none of out_grads is available,
+ # just propagates a list of None backwards.
+ in_grads = [None] * len(op.inputs)
+ for t_in, in_grad in zip(op.inputs, in_grads):
+ if in_grad:
+ _SetGrad(grads, t_in, in_grad)
+ if has_control_flow:
+ control_flow_ops.ExitGradWhileContext(op)
+
+ # update pending count for the inputs of op.
+ for x in op.inputs:
+ pending_count[x.op._id] -= 1
+ ready = (pending_count[x.op._id] == 0)
+ if has_control_flow and not ready:
+ ready = (pending_count[x.op._id] > 0 and
+ control_flow_ops.IsLoopSwitch(x.op))
+ if ready:
+ queue.append(x.op)
+ for x in op.control_inputs:
+ pending_count[x._id] -= 1
+ if pending_count[x._id] == 0:
+ queue.append(x)
+ return [_GetGrad(grads, x) for x in xs]
+
+
+def _SetGrad(grads, t, grad):
+ """Sets gradient "grad" in "grads" for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads:
+ op_grads = [[] for _ in xrange(len(op.outputs))]
+ grads[op] = op_grads
+ t_grads = op_grads[t.value_index]
+ if isinstance(t_grads, list):
+ t_grads.append(grad)
+ else:
+ assert op.type == "Switch"
+ op_grads[t.value_index] = grad
+
+
+def _GetGrad(grads, t):
+ """Gets gradient for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads: return None
+ t_grad = op_grads[t.value_index]
+ assert not isinstance(t_grad, list), (
+ "gradients list should have been aggregated by now.")
+ return t_grad
+
+
+def _GetGrads(grads, op):
+ """Gets all gradients for op."""
+ if op in grads:
+ return grads[op]
+ else:
+ return [[] for _ in xrange(len(op.outputs))]
+
+
+def _HandleNestedIndexedSlices(grad):
+ assert isinstance(grad, ops.IndexedSlices)
+ if isinstance(grad.values, ops.Tensor):
+ return grad
+ else:
+ assert isinstance(grad.values, ops.IndexedSlices)
+ g = _HandleNestedIndexedSlices(grad.values)
+ return ops.IndexedSlices(
+ g.values, array_ops.gather(grad.indices, g.indices), g.dense_shape)
+
+
+def _AccumulatorShape(inputs):
+ shape = tensor_shape.unknown_shape()
+ for i in inputs:
+ if isinstance(i, ops.Tensor):
+ shape = shape.merge_with(i.get_shape())
+ return shape
+
+
+class AggregationMethod(object):
+ """A class listing aggregation methods used to combine gradients.
+
+ Computing partial derivatives can require aggregating gradient
+ contributions. This class lists the various methods that can
+ be used to combine gradients in the graph:
+
+ * `ADD_N`: All of the gradient terms are summed as part of one
+ operation using the "AddN" op. It has the property that all
+ gradients must be ready before any aggregation is performed.
+ * `DEFAULT`: The system-chosen default aggregation method.
+ """
+ ADD_N = 0
+ DEFAULT = ADD_N
+ # The following are experimental and may not be supported in future releases.
+ EXPERIMENTAL_TREE = 1
+ EXPERIMENTAL_ACCUMULATE_N = 2
+
+
+def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None):
+ """Get the aggregated gradients for op.
+
+ Args:
+ grads: The map of memoized gradients.
+ op: The op to get gradients for.
+ has_control_flow: True iff the graph contains control flow ops.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+
+ Returns:
+ A list of gradients, one for each output of `op`. If the gradients
+ for a particular output are a list, this function aggregates them
+ before returning.
+
+ Raises:
+ TypeError: if the incoming grads are not Tensors or IndexedSlices.
+ ValueError: if the arguments are invalid.
+
+ """
+ if aggregation_method is None:
+ aggregation_method = AggregationMethod.DEFAULT
+ if aggregation_method not in [AggregationMethod.ADD_N,
+ AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N]:
+ raise ValueError("Invalid aggregation_method specified.")
+ out_grads = _GetGrads(grads, op)
+ for i, out_grad in enumerate(out_grads):
+ if has_control_flow:
+ if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
+ assert op.type == "Switch"
+ continue
+ # Grads have to be Tensors or IndexedSlices
+ if not all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in out_grad if g]):
+ raise TypeError("gradients have to be either all Tensors "
+ "or all IndexedSlices")
+ # Aggregate multiple gradients, and convert [] to None.
+ if out_grad:
+ if all([isinstance(g, ops.Tensor) for g in out_grad if g]):
+ tensor_shape = _AccumulatorShape(out_grad)
+ if len(out_grad) < 2:
+ used = "nop"
+ out_grads[i] = out_grad[0]
+ elif (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
+ # The benefit of using AccumulateN is that its inputs can be combined
+ # in any order and this can allow the expression to be evaluated with
+ # a smaller memory footprint. When used with gpu_allocator_retry,
+ # it is possible to compute a sum of terms which are much larger than
+ # total GPU memory.
+ # AccumulateN can currently only be used if we know the shape for
+ # an accumulator variable. If this is not known, or if we only have
+ # 2 grads then we fall through to the "tree" case below.
+ used = "accumulate_n"
+ out_grads[i] = math_ops.accumulate_n(out_grad)
+ elif aggregation_method in [AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ ]:
+ # Aggregate all gradients by doing pairwise sums: this may
+ # reduce performance, but it can improve memory because the
+ # gradients can be released earlier.
+ #
+ # TODO(vrv): Consider replacing this with a version of
+ # tf.AddN() that eagerly frees its inputs as soon as they are
+ # ready, so the order of this tree does not become a problem.
+ used = "tree"
+ with ops.name_scope(op.name + "_gradient_sum"):
+ running_sum = out_grad[0]
+ for grad in out_grad[1:]:
+ running_sum = math_ops.add_n([running_sum, grad])
+ out_grads[i] = running_sum
+ else:
+ used = "add_n"
+ out_grads[i] = math_ops.add_n(out_grad)
+ logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
+ tensor_shape, used)
+ else:
+ out_grad = math_ops._as_indexed_slices_list([g for g in out_grad if g])
+ out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
+ # Form IndexedSlices out of the concatenated values and
+ # indices.
+ out_grads[i] = ops.IndexedSlices(
+ array_ops.concat(0, [x.values for x in out_grad]),
+ array_ops.concat(0, [x.indices for x in out_grad]),
+ out_grad[0].dense_shape)
+ else:
+ out_grads[i] = []
+ return out_grads
+
+
+# TODO(vrv): Make this available when we want to make it public.
+def _hessian_vector_product(ys, xs, v):
+ """Multiply the Hessian of `ys` wrt `xs` by `v`.
+
+ This is an efficient construction that uses a backprop-like approach
+ to compute the product between the Hessian and another vector. The
+ Hessian is usually too large to be explicitly computed or even
+ represented, but this method allows us to at least multiply by it
+ for the same big-O cost as backprop.
+
+ Implicit Hessian-vector products are the main practical, scalable way
+ of using second derivatives with neural networks. They allow us to
+ do things like construct Krylov subspaces and approximate conjugate
+ gradient descent.
+
+ Example: if `y` = 1/2 `x`^T A `x`, then `hessian_vector_product(y,
+ x, v)` will return an expression that evaluates to the same values
+ as (A + A.T) `v`.
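+
+ A minimal sketch using names local to this module (the values are made up):
+
+ x = constant_op.constant([[1.], [2.]])
+ v = [constant_op.constant([[0.5], [0.5]])]
+ y = math_ops.reduce_sum(x * x * x)  # Hessian of y w.r.t. x is diag(6 * x)
+ hvp = _hessian_vector_product(y, [x], v)[0]  # evaluates to [[3.], [6.]]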
+
+ Args:
+ ys: A scalar value, or a tensor or list of tensors to be summed to
+ yield a scalar.
+ xs: A list of tensors that we should construct the Hessian over.
+ v: A list of tensors, with the same shapes as xs, that we want to
+ multiply by the Hessian.
+
+ Returns:
+ A list of tensors (or if the list would be length 1, a single tensor)
+ containing the product between the Hessian and `v`.
+
+ Raises:
+ ValueError: `xs` and `v` have different length.
+
+ """
+
+ # Validate the input
+ length = len(xs)
+ if len(v) != length:
+ raise ValueError("xs and v must have the same length.")
+
+ # First backprop
+ grads = gradients(ys, xs)
+
+ assert len(grads) == length
+ elemwise_products = [math_ops.mul(grad_elem, array_ops.stop_gradient(v_elem))
+ for grad_elem, v_elem in zip(grads, v)
+ if grad_elem is not None]
+
+ # Second backprop
+ return gradients(elemwise_products, xs)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
new file mode 100644
index 0000000000..dac0ebbb60
--- /dev/null
+++ b/tensorflow/python/ops/gradients_test.py
@@ -0,0 +1,337 @@
+"""Tests for tensorflow.ops.gradients."""
+import warnings
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+# pylint: disable=unused-import
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.ops import state_grad
+# pylint: enable=unused-import
+from tensorflow.python.ops.constant_op import constant
+from tensorflow.python.ops.nn_ops import bias_add
+from tensorflow.python.platform import googletest
+
+
+def _OpsBetween(graph, to_ops, from_ops):
+ """Build the list of operations between two lists of Operations.
+
+ Args:
+ graph: a Graph.
+ to_ops: list of Operations.
+ from_ops: list of Operations.
+
+ Returns:
+ The list of operations between "from_ops" and "to_ops", sorted by
+ decreasing operation id. This list contains all elements of to_ops.
+
+ TODO(mdevin): Think about returning an empty list if from_ops are not
+ reachable from to_ops. Presently it returns to_ops in that case.
+ """
+ # List of booleans, indexed by operation id, indicating if
+ # an op is reached from the output of "from_ops".
+ reached_ops = [False] * (graph._last_id + 1)
+ # We only care to reach up to "to_ops" so we mark the
+ # ops in "to_ops" as reached to avoid recursing past them.
+ for op in to_ops:
+ reached_ops[op._id] = True
+ gradients._MarkReachedOps(from_ops, reached_ops)
+ between_ops = gradients._GatherInputs(to_ops, reached_ops)
+ between_ops.sort(lambda x, y: y._id - x._id)
+ return between_ops
+
+
+class GradientsTest(test_util.TensorFlowTestCase):
+
+ def _OpNames(self, op_list):
+ return ["%s/%d" % (str(op.name), op._id) for op in op_list]
+
+ def _assertOpListEqual(self, ops1, ops2):
+ self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
+
+ def testOpsBetweenSimple(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ # Full graph
+ self._assertOpListEqual([t3.op, t2.op, t1.op],
+ _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+ # Only t1, t3.
+ self._assertOpListEqual([t3.op, t1.op],
+ _OpsBetween(g, [t3.op], [t1.op]))
+
+ def testOpsBetweenUnreachable(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ _ = array_ops.pack([t1, t2])
+ t4 = constant(1.0)
+ t5 = constant(2.0)
+ t6 = array_ops.pack([t4, t5])
+ # Elements of to_ops are always listed.
+ self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+
+ def testOpsBetweenCut(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ t4 = constant([1.0])
+ t5 = array_ops.concat(0, [t4, t3])
+ t6 = constant([2.0])
+ t7 = array_ops.concat(0, [t5, t6])
+ self._assertOpListEqual([t7.op, t5.op, t4.op],
+ _OpsBetween(g, [t7.op], [t4.op]))
+
+ def testOpsBetweenCycle(self):
+ with ops.Graph().as_default() as g:
+ t1 = constant(1.0)
+ t2 = constant(2.0)
+ t3 = array_ops.pack([t1, t2])
+ t4 = array_ops.concat(0, [t3, t3, t3])
+ t5 = constant([1.0])
+ t6 = array_ops.concat(0, [t4, t5])
+ t7 = array_ops.concat(0, [t6, t3])
+ self._assertOpListEqual([t6.op, t4.op, t3.op],
+ _OpsBetween(g, [t6.op], [t3.op]))
+ self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
+ _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+ self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
+ _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+
+ def testGradients(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[32, 100], name="in")
+ w = constant(1.0, shape=[100, 10], name="w")
+ b = constant(1.0, shape=[10], name="b")
+ xw = math_ops.matmul(inp, w, name="xw")
+ h = bias_add(xw, b, name="h")
+ w_grad = gradients.gradients(h, w)[0]
+ self.assertEquals("MatMul", w_grad.op.type)
+ self.assertEquals(w_grad.op._original_op, xw.op)
+ self.assertTrue(w_grad.op.get_attr("transpose_a"))
+ self.assertFalse(w_grad.op.get_attr("transpose_b"))
+
+ def testUnusedOutput(self):
+ with ops.Graph().as_default():
+ w = constant(1.0, shape=[2, 2])
+ x = constant(1.0, shape=[2, 2])
+ wx = math_ops.matmul(w, x)
+ split_wx = array_ops.split(0, 2, wx)
+ c = math_ops.reduce_sum(split_wx[1])
+ gw = gradients.gradients(c, [w])[0]
+ self.assertEquals("MatMul", gw.op.type)
+
+ def testColocateGradients(self):
+ with ops.Graph().as_default() as g:
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ with g.device("/gpu:0"):
+ wx = math_ops.matmul(w, x)
+ gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEquals("/gpu:0", gw.device)
+
+ def testColocateGradientsWithAggregation(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/gpu:1"):
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ y = constant(1.0, shape=[1, 2])
+ wx = math_ops.matmul(w, x)
+ wy = math_ops.matmul(w, y)
+ with g.device("/gpu:0"):
+ z = wx + wy
+ gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEquals("/gpu:1", gw1.device)
+ gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ self.assertEquals(None, gw2.device)
+
+ def testBoundaryStop(self):
+ # Test that we don't differentiate 'x'. The gradient function for 'x' is
+ # set explicitly to None so we will get an exception if the gradient code
+ # tries to differentiate 'x'.
+ with ops.Graph().as_default() as g:
+ c = constant(1.0)
+ x = array_ops.identity(c)
+ y = x + 1.0
+ z = y + 1
+ grads = gradients.gradients(z, [x])
+ self.assertTrue(all([x for x in grads]))
+
+ def testBoundaryContinue(self):
+ # Test that we differentiate both 'x' and 'y' correctly when x is a
+ # predecessor of y.
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y * 3.0
+ grads = gradients.gradients(z, [x, y])
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(6.0, grads[0].eval())
+
+ def testAggregationMethodAccumulateN(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=
+ gradients.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testAggregationMethodAddN(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=gradients.AggregationMethod.ADD_N)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testAggregationMethodTree(self):
+ with self.test_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z,
+ [x, y],
+ aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
+ self.assertTrue(all([x for x in grads]))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testNoGradientForStringOutputs(self):
+ with ops.Graph().as_default() as g:
+ @ops.RegisterGradient("TestOp")
+ def _TestOpGrad(op, float_grad, string_grad):
+ """Gradient function for TestOp."""
+ self.assertEquals(float_grad.dtype, types.float32)
+ self.assertFalse(string_grad)
+ return float_grad
+ ops.RegisterShape("TestOp")(None)
+
+ c = constant(1.0)
+ x, y = g.create_op("TestOp", [c], [types.float32, types.string]).outputs
+ z = x * 2.0
+ w = z * 3.0
+ grads = gradients.gradients(z, [c])
+ self.assertTrue(isinstance(grads[0], ops.Tensor))
+
+
+class StopGradientTest(test_util.TensorFlowTestCase):
+
+ def testStopGradient(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[100, 32], name="in")
+ out = array_ops.stop_gradient(inp)
+ igrad = gradients.gradients(out, inp)[0]
+ assert igrad is None
+
+
+class HessianVectorProductTest(test_util.TensorFlowTestCase):
+
+ def testHessianVectorProduct(self):
+ # Manually compute the Hessian explicitly for a low-dimensional problem
+ # and check that HessianVectorProduct matches multiplication by the
+ # explicit Hessian.
+ # Specifically, the Hessian of f(x) = x^T A x is
+ # H = A + A^T.
+ # We expect HessianVectorProduct(f(x), x, v) to be H v.
+ m = 4
+ rng = np.random.RandomState([1, 2, 3])
+ mat_value = rng.randn(m, m).astype("float32")
+ v_value = rng.randn(m, 1).astype("float32")
+ x_value = rng.randn(m, 1).astype("float32")
+ hess_value = mat_value + mat_value.T
+ hess_v_value = np.dot(hess_value, v_value)
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ mat = constant_op.constant(mat_value)
+ v = constant_op.constant(v_value)
+ x = constant_op.constant(x_value)
+ mat_x = math_ops.matmul(mat, x, name="Ax")
+ x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
+ hess_v = gradients._hessian_vector_product(x_mat_x, [x], [v])[0]
+ hess_v_actual = hess_v.eval()
+ self.assertAllClose(hess_v_value, hess_v_actual)
+
+
+class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
+
+ def testIndexedSlicesToTensor(self):
+ with self.test_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.mul(c_sparse, 1.0)
+ self.assertAllClose(np_val, c_dense.eval())
+
+ def testInt64Indices(self):
+ with self.test_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ c_sparse = ops.IndexedSlices(
+ c_sparse.values, math_ops.cast(c_sparse.indices, types.int64),
+ c_sparse.dense_shape)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.mul(c_sparse, 1.0)
+ self.assertAllClose(np_val, c_dense.eval())
+
+ def testWarnings(self):
+ # Smaller than the threshold: no warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ constant([4, 4, 4, 4]))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(0, len(w))
+
+ # Greater than or equal to the threshold: warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ constant([100, 100, 100, 100]))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "with 100000000 elements. This may consume a large amount of memory."
+ in str(w[0].message))
+
+ # Unknown dense shape: warning.
+ c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32),
+ array_ops.placeholder(types.int32),
+ array_ops.placeholder(types.int32))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.mul(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "of unknown shape. This may consume a large amount of memory."
+ in str(w[0].message))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
new file mode 100644
index 0000000000..1b4f4aef22
--- /dev/null
+++ b/tensorflow/python/ops/image_ops.py
@@ -0,0 +1,786 @@
+"""## Encoding and Decoding.
+
+TensorFlow provides Ops to decode and encode JPEG and PNG formats. Encoded
+images are represented by scalar string Tensors, decoded images by 3-D uint8
+tensors of shape `[height, width, channels]`.
+
+The encode and decode Ops apply to one image at a time. Their input and output
+are all of variable size. If you need fixed size images, pass the output of
+the decode Ops to one of the cropping and resizing Ops.
+
+Note: The PNG encode and decode Ops support RGBA, but the conversion Ops
+presently only support RGB, HSV, and GrayScale.
+
+@@decode_jpeg
+@@encode_jpeg
+
+@@decode_png
+@@encode_png
+
+## Resizing.
+
+The resizing Ops accept input images as tensors of several types. They always
+output resized images as float32 tensors.
+
+The convenience function [resize_images()](#resize_images) supports both 4-D
+and 3-D tensors as input and output. 4-D tensors are for batches of images,
+3-D tensors for individual images.
+
+Other resizing Ops only support 3-D individual images as input:
+[resize_area](#resize_area), [resize_bicubic](#resize_bicubic),
+[resize_bilinear](#resize_bilinear),
+[resize_nearest_neighbor](#resize_nearest_neighbor).
+
+Example:
+
+```python
+# Decode a JPG image and resize it to 299 by 299.
+image = tf.image.decode_jpeg(...)
+resized_image = tf.image.resize_bilinear(image, [299, 299])
+```
+
+<i>Maybe refer to the Queue examples that show how to add images to a Queue
+after resizing them to a fixed size, and how to dequeue batches of resized
+images from the Queue.</i>
+
+@@resize_images
+
+@@resize_area
+@@resize_bicubic
+@@resize_bilinear
+@@resize_nearest_neighbor
+
+
+## Cropping.
+
+@@resize_image_with_crop_or_pad
+
+@@pad_to_bounding_box
+@@crop_to_bounding_box
+@@random_crop
+@@extract_glimpse
+
+## Flipping and Transposing.
+
+@@flip_up_down
+@@random_flip_up_down
+
+@@flip_left_right
+@@random_flip_left_right
+
+@@transpose_image
+
+## Image Adjustments.
+
+TensorFlow provides functions to adjust images in various ways: brightness,
+contrast, hue, and saturation. Each adjustment can be done with predefined
+parameters or with random parameters picked from predefined intervals. Random
+adjustments are often useful to expand a training set and reduce overfitting.
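+
+For example (a sketch; the delta and contrast bounds are arbitrary, and the
+argument names `max_delta`, `lower`, and `upper` are assumed for the random
+ops listed below):
+
+```python
+# Randomly perturb brightness and contrast while building a training batch.
+image = tf.image.random_brightness(image, max_delta=0.3)
+image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
+```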
+
+@@adjust_brightness
+@@random_brightness
+
+@@adjust_contrast
+@@random_contrast
+
+@@per_image_whitening
+"""
+import math
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_image_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_image_ops import *
+from tensorflow.python.ops.gen_attention_ops import *
+# pylint: enable=wildcard-import
+
+ops.NoGradient('ResizeBilinear')
+ops.NoGradient('RandomCrop')
+
+
+def _ImageDimensions(images):
+ """Returns the dimensions of an image tensor.
+
+ Args:
+ images: 4-D Tensor of shape [batch, height, width, channels]
+
+ Returns:
+ list of integers [batch, height, width, channels]
+ """
+ # A simple abstraction to provide names for each dimension. This abstraction
+ # should make it simpler to switch dimensions in the future (e.g. if we ever
+ # want to switch height and width.)
+ return images.get_shape().as_list()
+
+
+def _Check3DImage(image):
+ """Assert that we are working with properly shaped image.
+
+ Args:
+ image: 3-D Tensor of shape [height, width, channels]
+
+ Raises:
+ ValueError: if image.shape is not a [3] vector.
+ """
+ if not image.get_shape().is_fully_defined():
+ raise ValueError('\'image\' must be fully defined.')
+ if image.get_shape().ndims != 3:
+ raise ValueError('\'image\' must be three-dimensional.')
+ if not all(x > 0 for x in image.get_shape()):
+ raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
+ image.get_shape())
+
+
+def _CheckAtLeast3DImage(image):
+ """Assert that we are working with properly shaped image.
+
+ Args:
+ image: >= 3-D Tensor of size [*, height, width, depth]
+
+ Raises:
+ ValueError: if image.shape is not a [>= 3] vector.
+ """
+ if not image.get_shape().is_fully_defined():
+ raise ValueError('\'image\' must be fully defined.')
+ if image.get_shape().ndims < 3:
+ raise ValueError('\'image\' must be at least three-dimensional.')
+ if not all(x > 0 for x in image.get_shape()):
+ raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
+ image.get_shape())
+
+
+def random_flip_up_down(image, seed=None):
+ """Randomly flips an image vertically (upside down).
+
+ With a 1 in 2 chance, outputs the contents of `image` flipped along the first
+ dimension, which is `height`. Otherwise output the image as-is.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror = math_ops.less(array_ops.pack([uniform_random, 1.0, 1.0]), 0.5)
+ return array_ops.reverse(image, mirror)
+
+
+def random_flip_left_right(image, seed=None):
+ """Randomly flip an image horizontally (left to right).
+
+ With a 1 in 2 chance, outputs the contents of `image` flipped along the
+ second dimension, which is `width`. Otherwise output the image as-is.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror = math_ops.less(array_ops.pack([1.0, uniform_random, 1.0]), 0.5)
+ return array_ops.reverse(image, mirror)
+
+
+def flip_left_right(image):
+ """Flip an image horizontally (left to right).
+
+ Outputs the contents of `image` flipped along the second dimension, which is
+ `width`.
+
+ See also `reverse()`.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.reverse(image, [False, True, False])
+
+
+def flip_up_down(image):
+ """Flip an image horizontally (upside down).
+
+ Outputs the contents of `image` flipped along the first dimension, which is
+ `height`.
+
+ See also `reverse()`.
+
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels].`
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.reverse(image, [True, False, False])
+
+
+def transpose_image(image):
+ """Transpose an image by swapping the first and second dimension.
+
+ See also `transpose()`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`
+
+ Returns:
+ A 3-D tensor of shape `[width, height, channels]`
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ _Check3DImage(image)
+ return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
+
+
+def pad_to_bounding_box(image, offset_height, offset_width, target_height,
+ target_width):
+ """Pad `image` with zeros to the specified `height` and `width`.
+
+ Adds `offset_height` rows of zeros on top, `offset_width` columns of
+ zeros on the left, and then pads the image on the bottom and right
+ with zeros until it has dimensions `target_height`, `target_width`.
+
+ This op does nothing if `offset_*` is zero and the image already has size
+ `target_height` by `target_width`.
+
+ Args:
+ image: 3-D tensor with shape `[height, width, channels]`
+ offset_height: Number of rows of zeros to add on top.
+ offset_width: Number of columns of zeros to add on the left.
+ target_height: Height of output image.
+ target_width: Width of output image.
+
+ Returns:
+ 3-D tensor of shape `[target_height, target_width, channels]`
+
+ Raises:
+ ValueError: If the shape of `image` is incompatible with the `offset_*` or
+ `target_*` arguments
+ """
+ _Check3DImage(image)
+ height, width, depth = _ImageDimensions(image)
+
+ if target_width < width:
+ raise ValueError('target_width must be >= width')
+ if target_height < height:
+ raise ValueError('target_height must be >= height')
+
+ after_padding_width = target_width - offset_width - width
+ after_padding_height = target_height - offset_height - height
+
+ if after_padding_width < 0:
+ raise ValueError('target_width not possible given '
+ 'offset_width and image width')
+ if after_padding_height < 0:
+ raise ValueError('target_height not possible given '
+ 'offset_height and image height')
+
+ # Do not pad on the depth dimensions.
+ if (offset_width or offset_height or after_padding_width or
+ after_padding_height):
+ paddings = [[offset_height, after_padding_height],
+ [offset_width, after_padding_width], [0, 0]]
+ padded = array_ops.pad(image, paddings)
+ padded.set_shape([target_height, target_width, depth])
+ else:
+ padded = image
+
+ return padded
+
+
+def crop_to_bounding_box(image, offset_height, offset_width, target_height,
+ target_width):
+ """Crops an image to a specified bounding box.
+
+ This op cuts a rectangular part out of `image`. The top-left corner of the
+ returned image is at `offset_height, offset_width` in `image`, and its
+ lower-right corner is at
+ `offset_height + target_height, offset_width + target_width`.
+
+ Args:
+ image: 3-D tensor with shape `[height, width, channels]`
+ offset_height: Vertical coordinate of the top-left corner of the result in
+ the input.
+ offset_width: Horizontal coordinate of the top-left corner of the result in
+ the input.
+ target_height: Height of the result.
+ target_width: Width of the result.
+
+ Returns:
+ 3-D tensor of image with shape `[target_height, target_width, channels]`
+
+ Raises:
+ ValueError: If the shape of `image` is incompatible with the `offset_*` or
+ `target_*` arguments
+ """
+ _Check3DImage(image)
+ height, width, _ = _ImageDimensions(image)
+
+ if offset_width < 0:
+ raise ValueError('offset_width must be >= 0.')
+ if offset_height < 0:
+ raise ValueError('offset_height must be >= 0.')
+
+ if width < (target_width + offset_width):
+ raise ValueError('width must be >= target + offset.')
+ if height < (target_height + offset_height):
+ raise ValueError('height must be >= target + offset.')
+
+ cropped = array_ops.slice(image, [offset_height, offset_width, 0],
+ [target_height, target_width, -1])
+
+ return cropped
+
+
+def resize_image_with_crop_or_pad(image, target_height, target_width):
+ """Crops and/or pads an image to a target width and height.
+
+ Resizes an image to a target width and height by either centrally
+ cropping the image or padding it evenly with zeros.
+
+ If `width` or `height` is greater than the specified `target_width` or
+ `target_height` respectively, this op centrally crops along that dimension.
+ If `width` or `height` is smaller than the specified `target_width` or
+ `target_height` respectively, this op centrally pads with 0 along that
+ dimension.
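+
+ For example (a minimal sketch with made-up sizes, assuming this function is
+ exported as `tf.image.resize_image_with_crop_or_pad`): a 100 x 150 x 3 image
+ resized to a 60 x 200 target is centrally cropped to 60 rows and zero-padded
+ to 200 columns:
+
+ out = tf.image.resize_image_with_crop_or_pad(image, 60, 200)
+ # out has static shape [60, 200, 3]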
+
+ Args:
+ image: 3-D tensor of shape [height, width, channels]
+ target_height: Target height.
+ target_width: Target width.
+
+ Raises:
+ ValueError: if `target_height` or `target_width` are zero or negative.
+
+ Returns:
+ Cropped and/or padded image of shape
+ `[target_height, target_width, channels]`
+ """
+ _Check3DImage(image)
+ original_height, original_width, _ = _ImageDimensions(image)
+
+ if target_width <= 0:
+ raise ValueError('target_width must be > 0.')
+ if target_height <= 0:
+ raise ValueError('target_height must be > 0.')
+
+ offset_crop_width = 0
+ offset_pad_width = 0
+ if target_width < original_width:
+ offset_crop_width = int((original_width - target_width) / 2)
+ elif target_width > original_width:
+ offset_pad_width = int((target_width - original_width) / 2)
+
+ offset_crop_height = 0
+ offset_pad_height = 0
+ if target_height < original_height:
+ offset_crop_height = int((original_height - target_height) / 2)
+ elif target_height > original_height:
+ offset_pad_height = int((target_height - original_height) / 2)
+
+ # Maybe crop if needed.
+ cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
+ min(target_height, original_height),
+ min(target_width, original_width))
+
+ # Maybe pad if needed.
+ resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
+ target_height, target_width)
+
+ if resized.get_shape().ndims is None:
+ raise ValueError('resized contains no shape.')
+ if not resized.get_shape()[0].is_compatible_with(target_height):
+ raise ValueError('resized height is not correct.')
+ if not resized.get_shape()[1].is_compatible_with(target_width):
+ raise ValueError('resized width is not correct.')
+ return resized
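+
+# Illustrative sketch only: for a hypothetical [5, 3, 1] image, resizing to
+# height 3 and width 5 crops one row from the top and bottom and pads one
+# column of zeros on each side, so cropping and padding can happen in a
+# single call:
+#
+#   squared = resize_image_with_crop_or_pad(image, target_height=3,
+#                                           target_width=5)
+#
+# The result always has static shape [target_height, target_width, channels].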
+
+
+class ResizeMethod(object):
+ BILINEAR = 0
+ NEAREST_NEIGHBOR = 1
+ BICUBIC = 2
+ AREA = 3
+
+
+def resize_images(images, new_height, new_width, method=ResizeMethod.BILINEAR):
+ """Resize `images` to `new_width`, `new_height` using the specified `method`.
+
+ Resized images will be distorted if their original aspect ratio differs
+ from the target ratio `new_width / new_height`. To avoid distortions see
+ [resize_image_with_crop_or_pad](#resize_image_with_crop_or_pad).
+
+ `method` can be one of:
+
+ * <b>ResizeMethod.BILINEAR</b>: [Bilinear interpolation.]
+ (https://en.wikipedia.org/wiki/Bilinear_interpolation)
+ * <b>ResizeMethod.NEAREST_NEIGHBOR</b>: [Nearest neighbor interpolation.]
+ (https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation)
+ * <b>ResizeMethod.BICUBIC</b>: [Bicubic interpolation.]
+ (https://en.wikipedia.org/wiki/Bicubic_interpolation)
+ * <b>ResizeMethod.AREA</b>: Area interpolation.
+
+ Args:
+ images: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ new_height: integer.
+ new_width: integer.
+ method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`.
+
+ Raises:
+ ValueError: if the shape of `images` is incompatible with the
+ shape arguments to this function
+ ValueError: if an unsupported resize method is specified.
+
+ Returns:
+ If `images` was 4-D, a 4-D float Tensor of shape
+ `[batch, new_height, new_width, channels]`.
+ If `images` was 3-D, a 3-D float Tensor of shape
+ `[new_height, new_width, channels]`.
+ """
+ if images.get_shape().ndims is None:
+ raise ValueError('\'images\' contains no shape.')
+ # TODO(shlens): Migrate this functionality to the underlying Op's.
+ is_batch = True
+ if len(images.get_shape()) == 3:
+ is_batch = False
+ images = array_ops.expand_dims(images, 0)
+
+ _, height, width, depth = _ImageDimensions(images)
+
+ if width == new_width and height == new_height:
+ if not is_batch:
+ # Undo the earlier expand_dims so a 3-D input yields a 3-D output.
+ images = array_ops.reshape(images, [new_height, new_width, depth])
+ return images
+
+ if method == ResizeMethod.BILINEAR:
+ images = gen_image_ops.resize_bilinear(images, [new_height, new_width])
+ elif method == ResizeMethod.NEAREST_NEIGHBOR:
+ images = gen_image_ops.resize_nearest_neighbor(images, [new_height,
+ new_width])
+ elif method == ResizeMethod.BICUBIC:
+ images = gen_image_ops.resize_bicubic(images, [new_height, new_width])
+ elif method == ResizeMethod.AREA:
+ images = gen_image_ops.resize_area(images, [new_height, new_width])
+ else:
+ raise ValueError('Resize method is not implemented.')
+
+ if not is_batch:
+ images = array_ops.reshape(images, [new_height, new_width, depth])
+ return images
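+
+# Illustrative sketch only (the sizes are made up): downscale a batch of
+# images to 224x224 with area interpolation, which tends to behave better
+# than bilinear when shrinking:
+#
+#   small = resize_images(images, 224, 224, method=ResizeMethod.AREA)
+#
+# Passing a 3-D tensor instead returns a 3-D tensor of shape
+# [224, 224, channels].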
+
+
+def per_image_whitening(image):
+ """Linearly scales `image` to have zero mean and unit norm.
+
+ This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
+ of all values in image, and
+ `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`.
+
+ `stddev` is the standard deviation of all values in `image`. It is capped
+ away from zero to protect against division by 0 when handling uniform images.
+
+ Note that this implementation is limited:
+ * It only whitens based on the statistics of an individual image.
+ * It does not take into account the covariance structure.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+
+ Returns:
+ The whitened image with same shape as `image`.
+
+ Raises:
+ ValueError: if the shape of 'image' is incompatible with this function.
+ """
+ _Check3DImage(image)
+ height, width, depth = _ImageDimensions(image)
+ num_pixels = height * width * depth
+
+ image = math_ops.cast(image, dtype=types.float32)
+ image_mean = math_ops.reduce_mean(image)
+
+ variance = (math_ops.reduce_mean(math_ops.square(image)) -
+ math_ops.square(image_mean))
+ stddev = math_ops.sqrt(variance)
+
+ # Apply a minimum normalization that protects us against uniform images.
+ min_stddev = constant_op.constant(1.0 / math.sqrt(num_pixels))
+ pixel_value_scale = math_ops.maximum(stddev, min_stddev)
+ pixel_value_offset = image_mean
+
+ image = math_ops.sub(image, pixel_value_offset)
+ image = math_ops.div(image, pixel_value_scale)
+ return image
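+
+# For reference, a NumPy sketch of the same computation (assuming NumPy is
+# available as np and `x` is an array holding one image); it only documents
+# the math that the graph ops above implement:
+#
+#   mean = x.mean()
+#   stddev = max(x.std(), 1.0 / math.sqrt(x.size))
+#   whitened = (x.astype(np.float32) - mean) / stddev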
+
+
+def random_brightness(image, max_delta, seed=None):
+ """Adjust the brightness of images by a random factor.
+
+ Equivalent to `adjust_brightness()` using a `delta` randomly picked in the
+ interval `[-max_delta, max_delta)`.
+
+ Note that `delta` is picked as a float. Because the brightness-adjusted
+ result is rounded before casting for integer-type images, integer images may
+ end up modified anywhere in the range `[-max_delta, max_delta]`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+ max_delta: float, must be non-negative.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ 3-D tensor of images of shape `[height, width, channels]`
+
+ Raises:
+ ValueError: if max_delta is negative.
+ """
+ _Check3DImage(image)
+
+ if max_delta < 0:
+ raise ValueError('max_delta must be non-negative.')
+
+ delta = random_ops.random_uniform([], -max_delta, max_delta, seed=seed)
+ return adjust_brightness(image, delta)
+
+
+def random_contrast(image, lower, upper, seed=None):
+ """Adjust the contrase of an image by a random factor.
+
+ Equivalent to `adjust_constrast()` but uses a `contrast_factor` randomly
+ picked in the interval `[lower, upper]`.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`.
+ lower: float. Lower bound for the random contrast factor.
+ upper: float. Upper bound for the random contrast factor.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ 3-D tensor of shape `[height, width, channels]`.
+
+ Raises:
+ ValueError: if `upper <= lower` or if `lower < 0`.
+ """
+ _Check3DImage(image)
+
+ if upper <= lower:
+ raise ValueError('upper must be > lower.')
+
+ if lower < 0:
+ raise ValueError('lower must be non-negative.')
+
+ # Generate a float in [lower, upper].
+ contrast_factor = random_ops.random_uniform([], lower, upper, seed=seed)
+ return adjust_contrast(image, contrast_factor)
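+
+# A possible data-augmentation sketch (illustrative; the bounds are made up):
+# chaining the two random adjustments above perturbs each training image a
+# little differently every time the subgraph runs:
+#
+#   distorted = random_brightness(image, max_delta=32)
+#   distorted = random_contrast(distorted, lower=0.5, upper=1.5)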
+
+
+def adjust_brightness(image, delta, min_value=None, max_value=None):
+ """Adjust the brightness of RGB or Grayscale images.
+
+ The value `delta` is added to all components of the tensor `image`. `image`
+ and `delta` are cast to `float` before adding, and the resulting values are
+ clamped to `[min_value, max_value]`. Finally, the result is cast back to
+ `image.dtype`.
+
+ If `min_value` or `max_value` are not given, they are set to the minimum and
+ maximum allowed values for `image.dtype` respectively.
+
+ Args:
+ image: A tensor.
+ delta: A scalar. Amount to add to the pixel values.
+ min_value: Minimum value for output.
+ max_value: Maximum value for output.
+
+ Returns:
+ A tensor of the same shape and type as `image`.
+ """
+ if min_value is None:
+ min_value = image.dtype.min
+ if max_value is None:
+ max_value = image.dtype.max
+
+ with ops.op_scope([image, delta, min_value, max_value], None,
+ 'adjust_brightness') as name:
+ adjusted = math_ops.add(
+ math_ops.cast(image, types.float32),
+ math_ops.cast(delta, types.float32),
+ name=name)
+ if image.dtype.is_integer:
+ rounded = math_ops.round(adjusted)
+ else:
+ rounded = adjusted
+ clipped = clip_ops.clip_by_value(rounded, float(min_value),
+ float(max_value))
+ output = math_ops.cast(clipped, image.dtype)
+ return output
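+
+# Illustrative sketch: for a uint8 image, a positive `delta` brightens the
+# image and the result saturates at the dtype maximum rather than wrapping
+# around, e.g. a pixel value of 250 with delta=10 clips to 255:
+#
+#   brighter = adjust_brightness(image, delta=10.0)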
+
+
+def adjust_contrast(images, contrast_factor, min_value=None, max_value=None):
+ """Adjust contrast of RGB or grayscale images.
+
+ `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+ interpreted as `[height, width, channels]`. The other dimensions only
+ represent a collection of images, such as `[batch, height, width, channels]`.
+
+ Contrast is adjusted independently for each channel of each image.
+
+ For each channel, this Op first computes the mean of the image pixels in the
+ channel and then adjusts each component `x` of each pixel to
+ `(x - mean) * contrast_factor + mean`.
+
+ The adjusted values are then clipped to fit in the `[min_value, max_value]`
+ interval. If `min_value` or `max_value` is not given, it is replaced with the
+ minimum and maximum values for the data type of `images` respectively.
+
+ The contrast-adjusted image is always computed as `float`, and it is
+ cast back to its original type after clipping.
+
+ Args:
+ images: Images to adjust. At least 3-D.
+ contrast_factor: A float multiplier for adjusting contrast.
+ min_value: Minimum value for clipping the adjusted pixels.
+ max_value: Maximum value for clipping the adjusted pixels.
+
+ Returns:
+ The contrast-adjusted image or images.
+
+ Raises:
+ ValueError: if the arguments are invalid.
+ """
+ _CheckAtLeast3DImage(images)
+
+ # If these are None, the min/max should be a nop, but still prevent overflows
+ # from the cast back to images.dtype at the end of adjust_contrast.
+ if min_value is None:
+ min_value = images.dtype.min
+ if max_value is None:
+ max_value = images.dtype.max
+
+ with ops.op_scope(
+ [images, contrast_factor, min_value,
+ max_value], None, 'adjust_contrast') as name:
+ adjusted = gen_image_ops.adjust_contrast(images,
+ contrast_factor=contrast_factor,
+ min_value=min_value,
+ max_value=max_value,
+ name=name)
+ if images.dtype.is_integer:
+ return math_ops.cast(math_ops.round(adjusted), images.dtype)
+ else:
+ return math_ops.cast(adjusted, images.dtype)
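+
+# For reference, the per-channel math described in the docstring above,
+# written as a NumPy sketch (assuming NumPy is available as np and `x` holds
+# one channel of one image):
+#
+#   mean = x.mean()
+#   adjusted = (x - mean) * contrast_factor + mean
+#   adjusted = np.clip(adjusted, min_value, max_value)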
+
+
+ops.RegisterShape('AdjustContrast')(
+ common_shapes.unchanged_shape_with_rank_at_least(3))
+
+
+@ops.RegisterShape('ResizeBilinear')
+@ops.RegisterShape('ResizeNearestNeighbor')
+@ops.RegisterShape('ResizeBicubic')
+@ops.RegisterShape('ResizeArea')
+def _ResizeShape(op):
+ """Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ size = tensor_util.ConstantValue(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ return [tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])]
+
+
+@ops.RegisterShape('DecodeJpeg')
+@ops.RegisterShape('DecodePng')
+def _ImageDecodeShape(op):
+ """Shape function for image decoding ops."""
+ unused_input_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ channels = op.get_attr('channels') or None
+ return [tensor_shape.TensorShape([None, None, channels])]
+
+
+@ops.RegisterShape('EncodeJpeg')
+@ops.RegisterShape('EncodePng')
+def _ImageEncodeShape(op):
+ """Shape function for image encoding ops."""
+ unused_input_shape = op.inputs[0].get_shape().with_rank(3)
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape('RandomCrop')
+def _RandomCropShape(op):
+ """Shape function for the random_crop op."""
+ input_shape = op.inputs[0].get_shape().with_rank(3)
+ unused_size_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(2))
+ size = tensor_util.ConstantValue(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ channels = input_shape[2]
+ return [tensor_shape.TensorShape([height, width, channels])]
+
+
+def random_crop(image, size, seed=None, name=None):
+ """Randomly crops `image` to size `[target_height, target_width]`.
+
+ The offset of the output within `image` is uniformly random. `image` always
+ fully contains the result.
+
+ Args:
+ image: 3-D tensor of shape `[height, width, channels]`
+ size: 1-D tensor with two elements, specifying target `[height, width]`
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for this operation (optional).
+
+ Returns:
+ A cropped 3-D tensor of shape `[target_height, target_width, channels]`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_image_ops.random_crop(image, size, seed=seed1, seed2=seed2,
+ name=name)
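+
+
+# Illustrative sketch: cropping a random 227x227 patch out of a larger image,
+# e.g. as a training-time augmentation (the patch size is made up):
+#
+#   patch = random_crop(image, size=[227, 227])
+#
+# The crop offset is drawn uniformly, so every location that keeps the patch
+# fully inside `image` is equally likely.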
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
new file mode 100644
index 0000000000..2c51299198
--- /dev/null
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -0,0 +1,771 @@
+"""Tests for tensorflow.ops.image_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import googletest
+
+
+class FlipTest(test_util.TensorFlowTestCase):
+
+ def testIdempotentLeftRight(self):
+ x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testLeftRight(self):
+ x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testIdempotentUpDown(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testUpDown(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testIdempotentTranspose(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testTranspose(self):
+ x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
+ y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
+
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+
+class RandomFlipTest(test_util.TensorFlowTestCase):
+
+ def testRandomLeftRight(self):
+ x_np = np.array([0, 1], dtype=np.uint8).reshape([1, 2, 1])
+ num_iterations = 500
+
+ hist = [0, 0]
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_left_right(x_tf)
+ for _ in xrange(num_iterations):
+ y_np = y.eval().flatten()[0]
+ hist[y_np] += 1
+
+ # Ensure that each entry is observed within 4 standard deviations.
+ four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
+ self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
+
+ def testRandomUpDown(self):
+ x_np = np.array([0, 1], dtype=np.uint8).reshape([2, 1, 1])
+ num_iterations = 500
+
+ hist = [0, 0]
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_up_down(x_tf)
+ for _ in xrange(num_iterations):
+ y_np = y.eval().flatten()[0]
+ hist[y_np] += 1
+
+ # Ensure that each entry is observed within 4 standard deviations.
+ four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
+ self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
+
+
+class AdjustContrastTest(test_util.TensorFlowTestCase):
+
+ def _testContrast(self, x_np, y_np, contrast_factor, min_value, max_value):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.adjust_contrast(x,
+ contrast_factor,
+ min_value=min_value,
+ max_value=max_value)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testDoubleContrastUint8(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [0, 0, 0, 63, 169, 255, 29, 0, 255, 135, 255, 0]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=None,
+ max_value=None)
+
+ def testDoubleContrastFloat(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.float).reshape(x_shape)
+
+ y_data = [0, 0, 0, 62.75, 169.25, 255, 28.75, 0, 255, 134.75, 255, 0]
+ y_np = np.array(y_data, dtype=np.float).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=0,
+ max_value=255)
+
+ def testHalfContrastUint8(self):
+ x_shape = [1, 2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [23, 53, 66, 50, 118, 172, 41, 54, 176, 68, 178, 60]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=0.5,
+ min_value=None,
+ max_value=None)
+
+ def testBatchDoubleContrast(self):
+ x_shape = [2, 1, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [0, 0, 0, 81, 200, 255, 11, 0, 255, 117, 255, 0]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testContrast(x_np,
+ y_np,
+ contrast_factor=2.0,
+ min_value=None,
+ max_value=None)
+
+
+class AdjustBrightnessTest(test_util.TensorFlowTestCase):
+
+ def _testBrightness(self, x_np, y_np, delta, min_value, max_value):
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.adjust_brightness(x,
+ delta,
+ min_value=min_value,
+ max_value=max_value)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+ def testPositiveDeltaUint8(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 255, 11]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
+
+ def testPositiveDeltaFloat(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.float32).reshape(x_shape)
+
+ y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 265, 11]
+ y_np = np.array(y_data, dtype=np.float32).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
+
+ def testNegativeDelta(self):
+ x_shape = [2, 2, 3]
+ x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+ x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+
+ y_data = [5, 5, 5, 44, 125, 216, 27, 5, 224, 80, 245, 5]
+ y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
+
+ self._testBrightness(x_np, y_np, delta=-10.0, min_value=5, max_value=None)
+
+
+class RandomCropTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ # No random cropping is performed since the target width and height
+ # match the image dimensions.
+ height = 4
+ width = 5
+ x_shape = [height, width, 3]
+ x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
+ target_shape_np = np.array([height, width], dtype=np.int64)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ target_shape = constant_op.constant(target_shape_np, shape=[2])
+ y = image_ops.random_crop(x, target_shape)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testRandomization(self):
+ # Run 1x1 crop num_samples times in an image and ensure that one finds each
+ # pixel 1/num_pixels of the time.
+ num_samples = 1000
+ height = 5
+ width = 4
+
+ num_pixels = height * width
+ data = np.arange(num_pixels).reshape([height, width, 1])
+ x_np = np.array(data).astype(np.int32)
+
+ target_shape_np = np.array([1, 1], dtype=np.int64)
+
+ y = []
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ target_shape = constant_op.constant(target_shape_np, shape=[2])
+ y_tf = image_ops.random_crop(x, target_shape)
+ for _ in xrange(num_samples):
+ y_np = y_tf.eval()
+ self.assertAllEqual(y_np.shape, [1, 1, 1])
+ y.extend(y_np.flatten())
+
+ # Calculate the mean and 4 * standard deviation.
+ mean = [num_samples / num_pixels] * num_pixels
+ four_stddev = 4.0 * np.sqrt(mean)
+
+ # Ensure that each entry is observed in 1/num_pixels of the samples
+ # within 4 standard deviations.
+ counts = np.bincount(y)
+ self.assertAllClose(counts, mean, atol=four_stddev)
+
+
+class PerImageWhiteningTest(test_util.TensorFlowTestCase):
+
+ def _NumpyPerImageWhitening(self, x):
+ num_pixels = np.prod(x.shape)
+ x2 = np.square(x).astype(np.float32)
+ mn = np.mean(x)
+ vr = np.mean(x2) - (mn * mn)
+ stddev = max(math.sqrt(vr), 1.0 / math.sqrt(num_pixels))
+
+ y = x.astype(np.float32)
+ y -= mn
+ y /= stddev
+ return y
+
+ def testBasic(self):
+ x_shape = [13, 9, 3]
+ x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
+ y_np = self._NumpyPerImageWhitening(x_np)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.per_image_whitening(x)
+ y_tf = y.eval()
+ self.assertAllClose(y_tf, y_np, atol=1e-4)
+
+
+class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ x_shape = [13, 9, 3]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ target_height = x_shape[0]
+ target_width = x_shape[1]
+ y = image_ops.crop_to_bounding_box(x, 0, 0, target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testCropping(self):
+ x_np = np.arange(0, 30, dtype=np.int32).reshape([6, 5, 1])
+
+ offset_height = 1
+ after_height = 2
+
+ offset_width = 0
+ after_width = 3
+
+ target_height = x_np.shape[0] - offset_height - after_height
+ target_width = x_np.shape[1] - offset_width - after_width
+
+ y_np = x_np[offset_height:offset_height + target_height,
+ offset_width:offset_width + target_width, :]
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.crop_to_bounding_box(x, offset_height, offset_width,
+ target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf.flatten(), y_np.flatten())
+
+
+class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
+
+ def testNoOp(self):
+ x_shape = [13, 9, 3]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ target_height = x_shape[0]
+ target_width = x_shape[1]
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.pad_to_bounding_box(x, 0, 0, target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
+ def testPadding(self):
+ x_shape = [3, 4, 1]
+ x_np = np.ones(x_shape, dtype=np.float32)
+
+ offset_height = 2
+ after_height = 3
+
+ offset_width = 1
+ after_width = 4
+
+ target_height = x_shape[0] + offset_height + after_height
+ target_width = x_shape[1] + offset_width + after_width
+
+ # Note the paddings are along height, width and depth.
+ paddings = ((offset_height, after_height),
+ (offset_width, after_width),
+ (0, 0))
+
+ y_np = np.pad(x_np, paddings, 'constant')
+
+ with self.test_session():
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.pad_to_bounding_box(x, offset_height, offset_width,
+ target_height, target_width)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
+
+class ResizeImagesTest(test_util.TensorFlowTestCase):
+
+ OPTIONS = [image_ops.ResizeMethod.BILINEAR,
+ image_ops.ResizeMethod.NEAREST_NEIGHBOR,
+ image_ops.ResizeMethod.BICUBIC,
+ image_ops.ResizeMethod.AREA]
+
+ def testNoOp(self):
+ img_shape = [1, 6, 4, 1]
+ data = [128, 128, 64, 64,
+ 128, 128, 64, 64,
+ 64, 64, 128, 128,
+ 64, 64, 128, 128,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 6
+ target_width = 4
+
+ for opt in self.OPTIONS:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ resized = y.eval()
+ self.assertAllClose(resized, img_np, atol=1e-5)
+
+ def testResizeDown(self):
+
+ data = [128, 128, 64, 64,
+ 128, 128, 64, 64,
+ 64, 64, 128, 128,
+ 64, 64, 128, 128,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ expected_data = [128, 64,
+ 64, 128,
+ 50, 100]
+ target_height = 3
+ target_width = 2
+
+ # Test out 3-D and 4-D image shapes.
+ img_shapes = [[1, 6, 4, 1], [6, 4, 1]]
+ target_shapes = [[1, target_height, target_width, 1],
+ [target_height, target_width, 1]]
+
+ for target_shape, img_shape in zip(target_shapes, img_shapes):
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ for opt in self.OPTIONS:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ expected = np.array(expected_data).reshape(target_shape)
+ resized = y.eval()
+ self.assertAllClose(resized, expected, atol=1e-5)
+
+ def testResizeUp(self):
+ img_shape = [1, 3, 2, 1]
+ data = [128, 64,
+ 64, 128,
+ 50, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 6
+ target_width = 4
+ expected_data = {}
+ expected_data[image_ops.ResizeMethod.BILINEAR] = [
+ 128.0, 96.0, 64.0, 64.0,
+ 96.0, 96.0, 96.0, 96.0,
+ 64.0, 96.0, 128.0, 128.0,
+ 57.0, 85.5, 114.0, 114.0,
+ 50.0, 75.0, 100.0, 100.0,
+ 50.0, 75.0, 100.0, 100.0]
+ expected_data[image_ops.ResizeMethod.NEAREST_NEIGHBOR] = [
+ 128.0, 128.0, 64.0, 64.0,
+ 128.0, 128.0, 64.0, 64.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 50.0, 50.0, 100.0, 100.0,
+ 50.0, 50.0, 100.0, 100.0]
+ expected_data[image_ops.ResizeMethod.AREA] = [
+ 128.0, 128.0, 64.0, 64.0,
+ 128.0, 128.0, 64.0, 64.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 128.0, 128.0,
+ 50.0, 50.0, 100.0, 100.0,
+ 50.0, 50.0, 100.0, 100.0]
+
+ for opt in [
+ image_ops.ResizeMethod.BILINEAR,
+ image_ops.ResizeMethod.NEAREST_NEIGHBOR,
+ image_ops.ResizeMethod.AREA]:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ resized = y.eval()
+ expected = np.array(expected_data[opt]).reshape(
+ [1, target_height, target_width, 1])
+ self.assertAllClose(resized, expected, atol=1e-05)
+
+ def testResizeUpBicubic(self):
+ img_shape = [1, 6, 6, 1]
+ data = [128, 128, 64, 64, 128, 128, 64, 64,
+ 64, 64, 128, 128, 64, 64, 128, 128,
+ 50, 50, 100, 100, 50, 50, 100, 100,
+ 50, 50, 100, 100, 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 8
+ target_width = 8
+ expected_data = [128, 135, 96, 55, 64, 114, 134, 128,
+ 78, 81, 68, 52, 57, 118, 144, 136,
+ 55, 49, 79, 109, 103, 89, 83, 84,
+ 74, 70, 95, 122, 115, 69, 49, 55,
+ 100, 105, 75, 43, 50, 89, 105, 100,
+ 57, 54, 74, 96, 91, 65, 55, 58,
+ 70, 69, 75, 81, 80, 72, 69, 70,
+ 105, 112, 75, 36, 45, 92, 111, 105]
+
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ image_ops.ResizeMethod.BICUBIC)
+ resized = y.eval()
+ expected = np.array(expected_data).reshape(
+ [1, target_height, target_width, 1])
+ self.assertAllClose(resized, expected, atol=1)
+
+ def testResizeDownArea(self):
+ img_shape = [1, 6, 6, 1]
+ data = [128, 64, 32, 16, 8, 4,
+ 4, 8, 16, 32, 64, 128,
+ 128, 64, 32, 16, 8, 4,
+ 5, 10, 15, 20, 25, 30,
+ 30, 25, 20, 15, 10, 5,
+ 5, 10, 15, 20, 25, 30]
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ target_height = 4
+ target_width = 4
+ expected_data = [73, 33, 23, 39,
+ 73, 33, 23, 39,
+ 14, 16, 19, 21,
+ 14, 16, 19, 21]
+
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ image_ops.ResizeMethod.AREA)
+ expected = np.array(expected_data).reshape(
+ [1, target_height, target_width, 1])
+ resized = y.eval()
+ self.assertAllClose(resized, expected, atol=1)
+
+
+class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
+
+ def _ResizeImageWithCropOrPad(self, original, original_shape,
+ expected, expected_shape):
+ x_np = np.array(original, dtype=np.uint8).reshape(original_shape)
+ y_np = np.array(expected).reshape(expected_shape)
+
+ target_height = expected_shape[0]
+ target_width = expected_shape[1]
+
+ with self.test_session():
+ image = constant_op.constant(x_np, shape=original_shape)
+ y = image_ops.resize_image_with_crop_or_pad(image,
+ target_height,
+ target_width)
+ resized = y.eval()
+ self.assertAllClose(resized, y_np, atol=1e-5)
+
+ def testBasic(self):
+ # Basic no-op.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ original, [2, 4, 1])
+
+ def testPad(self):
+ # Pad even along col.
+ original = [1, 2, 3, 4, 5, 6, 7, 8]
+ expected = [0, 1, 2, 3, 4, 0,
+ 0, 5, 6, 7, 8, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 6, 1])
+ # Pad odd along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 1, 2, 3, 4, 0, 0,
+ 0, 5, 6, 7, 8, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 7, 1])
+
+ # Pad even along row.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [4, 4, 1])
+ # Pad odd along row.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [5, 4, 1])
+
+ def testCrop(self):
+ # Crop even along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [2, 3,
+ 6, 7]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [2, 2, 1])
+ # Crop odd along col.
+
+ original = [1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12]
+ expected = [2, 3, 4,
+ 8, 9, 10]
+ self._ResizeImageWithCropOrPad(original, [2, 6, 1],
+ expected, [2, 3, 1])
+
+ # Crop even along row.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8]
+ expected = [3, 4,
+ 5, 6]
+ self._ResizeImageWithCropOrPad(original, [4, 2, 1],
+ expected, [2, 2, 1])
+
+ # Crop odd along row.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8,
+ 9, 10,
+ 11, 12,
+ 13, 14,
+ 15, 16]
+ expected = [3, 4,
+ 5, 6,
+ 7, 8,
+ 9, 10,
+ 11, 12]
+ self._ResizeImageWithCropOrPad(original, [8, 2, 1],
+ expected, [5, 2, 1])
+
+ def testCropAndPad(self):
+ # Pad along row but crop along col.
+ original = [1, 2, 3, 4,
+ 5, 6, 7, 8]
+ expected = [0, 0,
+ 2, 3,
+ 6, 7,
+ 0, 0]
+ self._ResizeImageWithCropOrPad(original, [2, 4, 1],
+ expected, [4, 2, 1])
+
+ # Crop along row but pad along col.
+ original = [1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8]
+ expected = [0, 3, 4, 0,
+ 0, 5, 6, 0]
+ self._ResizeImageWithCropOrPad(original, [4, 2, 1],
+ expected, [2, 4, 1])
+
+
+def _SimpleColorRamp():
+ """Build a simple color ramp RGB image."""
+ w, h = 256, 200
+ i = np.arange(h)[:, None]
+ j = np.arange(w)
+ image = np.empty((h, w, 3), dtype=np.uint8)
+ image[:, :, 0] = i
+ image[:, :, 1] = j
+ image[:, :, 2] = (i + j) >> 1
+ return image
+
+
+class JpegTest(test_util.TensorFlowTestCase):
+
+ # TODO(irving): Add self.assertAverageLess or similar to test_util
+ def averageError(self, image0, image1):
+ self.assertEqual(image0.shape, image1.shape)
+ image0 = image0.astype(int) # Avoid overflow
+ return np.abs(image0 - image1).sum() / float(np.prod(image0.shape))
+
+ def testExisting(self):
+ # Read a real jpeg and verify shape
+ path = ('tensorflow/core/lib/jpeg/testdata/'
+ 'jpeg_merge_test1.jpg')
+ with self.test_session() as sess:
+ jpeg0 = io_ops.read_file(path)
+ image0 = image_ops.decode_jpeg(jpeg0)
+ image1 = image_ops.decode_jpeg(image_ops.encode_jpeg(image0))
+ jpeg0, image0, image1 = sess.run([jpeg0, image0, image1])
+ self.assertEqual(len(jpeg0), 3771)
+ self.assertEqual(image0.shape, (256, 128, 3))
+ self.assertLess(self.averageError(image0, image1), 0.8)
+
+ def testSynthetic(self):
+ with self.test_session() as sess:
+ # Encode it, then decode it, then encode it
+ image0 = constant_op.constant(_SimpleColorRamp())
+ jpeg0 = image_ops.encode_jpeg(image0)
+ image1 = image_ops.decode_jpeg(jpeg0)
+ image2 = image_ops.decode_jpeg(image_ops.encode_jpeg(image1))
+ jpeg0, image0, image1, image2 = sess.run([jpeg0, image0, image1, image2])
+
+ # The decoded-encoded image should be similar to the input
+ self.assertLess(self.averageError(image0, image1), 0.6)
+
+ # We should be very close to a fixpoint
+ self.assertLess(self.averageError(image1, image2), 0.02)
+
+ # Smooth ramps compress well (input size is 153600)
+ self.assertGreaterEqual(len(jpeg0), 5000)
+ self.assertLessEqual(len(jpeg0), 6000)
+
+ def testShape(self):
+ with self.test_session() as sess:
+ jpeg = constant_op.constant('nonsense')
+ for channels in 0, 1, 3:
+ image = image_ops.decode_jpeg(jpeg, channels=channels)
+ self.assertEqual(image.get_shape().as_list(),
+ [None, None, channels or None])
+
+
+class PngTest(test_util.TensorFlowTestCase):
+
+ def testExisting(self):
+ # Read some real PNGs, converting to different channel numbers
+ prefix = 'tensorflow/core/lib/png/testdata/'
+ inputs = (1, 'lena_gray.png'), (4, 'lena_rgba.png')
+ for channels_in, filename in inputs:
+ for channels in 0, 1, 3, 4:
+ with self.test_session() as sess:
+ png0 = io_ops.read_file(prefix + filename)
+ image0 = image_ops.decode_png(png0, channels=channels)
+ png0, image0 = sess.run([png0, image0])
+ self.assertEqual(image0.shape, (26, 51, channels or channels_in))
+ if channels == channels_in:
+ image1 = image_ops.decode_png(image_ops.encode_png(image0))
+ self.assertAllEqual(image0, image1.eval())
+
+ def testSynthetic(self):
+ with self.test_session() as sess:
+ # Encode it, then decode it
+ image0 = constant_op.constant(_SimpleColorRamp())
+ png0 = image_ops.encode_png(image0, compression=7)
+ image1 = image_ops.decode_png(png0)
+ png0, image0, image1 = sess.run([png0, image0, image1])
+
+ # PNG is lossless
+ self.assertAllEqual(image0, image1)
+
+ # Smooth ramps compress well, but not too well
+ self.assertGreaterEqual(len(png0), 400)
+ self.assertLessEqual(len(png0), 750)
+
+ def testShape(self):
+ with self.test_session() as sess:
+ png = constant_op.constant('nonsense')
+ for channels in 0, 1, 3:
+ image = image_ops.decode_png(png, channels=channels)
+ self.assertEqual(image.get_shape().as_list(),
+ [None, None, channels or None])
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
new file mode 100644
index 0000000000..09c8801e0e
--- /dev/null
+++ b/tensorflow/python/ops/init_ops.py
@@ -0,0 +1,181 @@
+"""Operations often used for initializing tensors."""
+
+import math
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+
+
+# TODO(mrry): PEP8 these.
+def constant_initializer(value=0.0):
+ """Returns an initializer that generates Tensors with a single value.
+
+ Args:
+ value: A Python scalar. All elements of the initialized variable
+ will be set to this value.
+
+ Returns:
+ An initializer that generates Tensors with a single value.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return constant_op.constant(value, dtype=dtype, shape=shape)
+ return _initializer
+
+def random_uniform_initializer(minval=0.0, maxval=1.0, seed=None):
+ """Returns an initializer that generates Tensors with a uniform distribution.
+
+ Args:
+ minval: a python scalar or a scalar tensor. lower bound of the range
+ of random values to generate.
+ maxval: a python scalar or a scalar tensor. upper bound of the range
+ of random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a uniform distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed)
+ return _initializer
+
+def random_normal_initializer(mean=0.0, stddev=1.0, seed=None):
+ """Returns an initializer that generates Tensors with a normal distribution.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values
+ to generate.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the
+ random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a normal distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed)
+ return _initializer
+
+def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None):
+ """Returns an initializer that generates a truncated normal distribution.
+
+ These values are similar to values from a random_normal_initializer
+ except that values more than two standard deviations from the mean
+ are discarded and re-drawn. This is the recommended initializer for
+ neural network weights and filters.
+
+ Args:
+ mean: a python scalar or a scalar tensor. Mean of the random values
+ to generate.
+ stddev: a python scalar or a scalar tensor. Standard deviation of the
+ random values to generate.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates Tensors with a truncated normal
+ distribution.
+ """
+ def _initializer(shape, dtype=types.float32):
+ return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed)
+ return _initializer
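+
+# Usage sketch (illustrative): the functions in this file return closures
+# that map a shape (and optionally a dtype) to a tensor, which is how
+# variable-creation helpers such as get_variable-style APIs consume them:
+#
+#   init = truncated_normal_initializer(stddev=0.1)
+#   initial_value = init([784, 200])   # a [784, 200] float32 tensor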
+
+def uniform_unit_scaling_initializer(factor=1.0, seed=None):
+ """Returns an initializer that generates tensors without scaling variance.
+
+ When initializing a deep network, it is in principle advantageous to keep
+ the scale of the input variance constant, so that it neither explodes nor
+ diminishes by the time it reaches the final layer. If the input is `x` and
+ the operation is `x * W`, and we want to initialize `W` uniformly at random,
+ we need to pick `W` from
+
+ [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]
+
+ to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
+ A similar calculation for convolutional networks gives an analogous result
+ with `dim` equal to the product of the first 3 dimensions. When
+ nonlinearities are present, we need to multiply this by a constant `factor`.
+ See <https://arxiv.org/pdf/1412.6558v3.pdf> for deeper motivation, experiments
+ and the calculation of constants. In section 2.3 there, the constants were
+ numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
+
+ Args:
+ factor: Float. A multiplicative factor by which the values will be scaled.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+
+ Returns:
+ An initializer that generates tensors with unit variance.
+ """
+ def _initializer(shape, dtype=types.float32):
+ input_size = 1.0
+ # Estimating input size is not possible to do perfectly, but we try.
+ # The estimate, obtained by multiplying all dimensions but the last one,
+ # is the right thing for matrix multiply and convolutions (see above).
+ for dim in shape[:-1]:
+ input_size *= float(dim)
+ max_val = math.sqrt(float(3) / float(input_size)) * factor
+ return random_ops.random_uniform(shape, -max_val, max_val,
+ dtype, seed=seed)
+ return _initializer
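+
+# For reference, the bound computed above for a 2-D weight matrix, written as
+# a plain-Python sketch (illustrative; `fan_in` stands for the first
+# dimension of the shape):
+#
+#   max_val = math.sqrt(3.0 / fan_in) * factor
+#   # W ~ Uniform(-max_val, max_val), so Var(W) = factor**2 / fan_in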
+
+# TODO(vrv): Unhide when we are ready to expose this publicly.
+def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None,
+ name="random_walk"):
+ """Create a random tensor such that backprop neither vanishes nor explodes.
+
+ Args:
+ shape: a python array of int or a 1-d tensor. Sizes of the Tensor.
+ nonlinearity: the Python function that implements the
+ nonlinearity in TensorFlow.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: string. Optional name for the op.
+
+ Returns:
+ A Tensor of the specified sizes filled with random values.
+ """
+ assert len(shape) == 2, "Random Walk initialization only supports 2D tensors."
+ num_inputs = shape[0]
+ if nonlinearity == math_ops.tanh:
+ # No real formula for this case yet, but this works well for many
+ # layer widths.
+ rwg = 1.13
+ elif nonlinearity == array_ops.identity:
+ rwg = math.exp(1.0 / float(2.0 * num_inputs))
+ elif nonlinearity == nn_ops.relu:
+ rwg = math.sqrt(2.0) * math.exp(1.2 / float(max(num_inputs, 6) - 2.4))
+ else:
+ assert False, "Unsupported nonlinearity for Random Walk initialization."
+
+ mean = 0.0
+ stddev = rwg / math.sqrt(float(num_inputs))
+
+ return random_ops.random_normal(shape, mean=mean, stddev=stddev, dtype=dtype,
+ seed=seed, name=name)
+
+
+# TODO(vrv): Unhide when we are ready to expose this publicly.
+class _RandomWalkInitializer(object):
+ """An Initializer that generates a tensor for Random Walk Initialization."""
+
+ def __init__(self, nonlinearity, seed=None):
+ """Construct a RandomWalkInitializer.
+
+ Args:
+ nonlinearity: the python tensorflow function that computes a nonlinearity
+ in the graph, typically after a Wx+b type operation.
+ seed: A Python integer. Used to create random seeds.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ """
+ self._nonlinearity = nonlinearity
+ self._seed = seed
+
+ def __call__(self, shape, dtype=types.float32):
+ """Generate a tensor used to initialize a variable."""
+ return _random_walk(shape, self._nonlinearity, dtype, seed=self._seed)
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
new file mode 100644
index 0000000000..9eb3bdfae4
--- /dev/null
+++ b/tensorflow/python/ops/io_ops.py
@@ -0,0 +1,541 @@
+"""## Placeholders
+
+TensorFlow provides a placeholder operation that must be fed with data
+on execution. For more info, see the section on [Feeding
+data](../../how_tos/reading_data/index.md#feeding).
+
+@@placeholder
+
+## Readers
+
+TensorFlow provides a set of Reader classes for reading data formats.
+For more information on inputs and readers, see [Reading
+data](../../how_tos/reading_data/index.md).
+
+@@ReaderBase
+@@TextLineReader
+@@WholeFileReader
+@@IdentityReader
+@@TFRecordReader
+@@FixedLengthRecordReader
+
+## Converting
+
+TensorFlow provides several operations that you can use to convert various data
+formats into tensors.
+
+@@decode_csv
+@@decode_raw
+@@parse_example
+@@parse_single_example
+
+## Queues
+
+TensorFlow provides several implementations of 'Queues', which are
+structures within the TensorFlow computation graph to stage pipelines
+of tensors together. The following describe the basic Queue interface
+and some implementations. To see an example use, see [Threading and
+Queues](../../how_tos/threading_and_queues/index.md).
+
+@@QueueBase
+@@FIFOQueue
+@@RandomShuffleQueue
+
+## Dealing with the filesystem
+
+@@matching_files
+@@read_file
+
+## Input pipeline
+
+TensorFlow functions for setting up an input-prefetching pipeline.
+Please see the [reading data how-to](../../how_tos/reading_data.md)
+for context.
+
+### Beginning of an input pipeline
+
+The "producer" functions add a queue to the graph and a corresponding
+`QueueRunner` for running the subgraph that fills that queue.
+
+@@match_filenames_once
+@@limit_epochs
+@@range_input_producer
+@@slice_input_producer
+@@string_input_producer
+
+### Batching at the end of an input pipeline
+
+These functions add a queue to the graph to assemble a batch of examples, with
+possible shuffling. They also add a `QueueRunner` for running the subgraph
+that fills that queue.
+
+Use [batch](#batch) or [batch_join](#batch_join) for batching examples that have
+already been well shuffled. Use [shuffle_batch](#shuffle_batch) or
+[shuffle_batch_join](#shuffle_batch_join) for examples that
+would benefit from additional shuffling.
+
+Use [batch](#batch) or [shuffle_batch](#shuffle_batch) if you want a
+single thread producing examples to batch, or if you have a
+single subgraph producing examples but you want to run it in N threads
+(where you increase N until it can keep the queue full). Use
+[batch_join](#batch_join) or [shuffle_batch_join](#shuffle_batch_join)
+if you have N different subgraphs producing examples to batch and you
+want them run by N threads.
+
+@@batch
+@@batch_join
+@@shuffle_batch
+@@shuffle_batch_join
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_io_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_io_ops import *
+# pylint: enable=wildcard-import
+
+
+# pylint: disable=protected-access
+def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"):
+ """Save a list of tensors to a file with given names.
+
+ Example usage without slice info:
+ Save("/foo/bar", ["w", "b"], [w, b])
+
+ Example usage with slices:
+ Save("/foo/bar", ["w", "w"], [slice0, slice1],
+ tensor_slices=["4 10 0,2:-", "4 10 2,2:-"])
+
+ Args:
+ filename: the file name of the sstable.
+ tensor_names: a list of strings.
+ tensors: the list of tensors to be saved.
+ tensor_slices: Optional list of strings to specify the shape and slices of
+ a larger virtual tensor that each tensor is a part of. If not specified
+ each tensor is saved as a full slice.
+ name: string. Optional name for the op.
+
+ Requires:
+ The length of tensors should match the size of tensor_names and of
+ tensor_slices.
+
+ Returns:
+ An Operation that saves the tensors.
+ """
+ if tensor_slices is None:
+ return gen_io_ops._save(filename, tensor_names, tensors, name=name)
+ else:
+ return gen_io_ops._save_slices(filename, tensor_names, tensor_slices,
+ tensors, name=name)
+
+
+def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
+ name="restore_slice", preferred_shard=-1):
+ """Restore a tensor slice from a set of files with a given pattern.
+
+ Example usage:
+ RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)
+
+ Args:
+ file_pattern: the file pattern used to match a set of checkpoint files.
+ tensor_name: the name of the tensor to restore.
+ shape_and_slice: the shape-and-slice spec of the slice.
+ tensor_type: the type of the tensor to restore.
+ name: string. Optional name for the op.
+ preferred_shard: Int. Optional shard to open first in the checkpoint file.
+
+ Returns:
+ A tensor of type "tensor_type".
+ """
+ base_type = types.as_dtype(tensor_type).base_dtype
+ return gen_io_ops._restore_slice(
+ file_pattern, tensor_name, shape_and_slice, base_type,
+ preferred_shard, name=name)
+
+
+@ops.RegisterShape("Restore")
+def _RestoreShape(op):
+ """Shape function for Restore op."""
+ # Validate input shapes.
+ unused_file_pattern = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_tensor_name = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("RestoreSlice")
+def _RestoreSliceShape(op):
+ """Shape function for RestoreSlice op."""
+ # Validate input shapes.
+ unused_file_pattern = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_tensor_name = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.scalar())
+ # TODO(mrry): Attempt to parse the shape_and_slice value and use it
+ # to form the shape of the output.
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("Save")
+def _SaveShape(op):
+ """Shape function for Save op."""
+ # Validate input shapes.
+ unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ data_count = len(op.inputs) - 2
+ unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ return []
+
+
+@ops.RegisterShape("SaveSlices")
+def _SaveSlicesShape(op):
+ """Shape function for SaveSlices op."""
+ # Validate input shapes.
+ unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
+ data_count = len(op.inputs) - 3
+ unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.vector(data_count))
+ # TODO(mrry): Attempt to parse the shapes_and_slices values and use
+ # them to constrain the shape of the remaining inputs.
+ return []
+
+
+@ops.RegisterShape("ShardedFilename")
+def _ShardedFilenameShape(op):
+ """Shape function for ShardedFilename op."""
+ # Validate input shapes.
+ unused_basename_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_shard_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ShardedFilespec")
+def _ShardedFilespecShape(op):
+ """Shape function for ShardedFilespec op."""
+ # Validate input shapes.
+ unused_basename_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+class ReaderBase(object):
+ """Base class for different Reader types, that produce a record every step.
+
+ Conceptually, Readers convert string 'work units' into records (key,
+ value pairs). Typically the 'work units' are filenames and the
+ records are extracted from the contents of those files. We want a
+ single record produced per step, but a work unit can correspond to
+ many records.
+
+ Therefore we introduce some decoupling using a queue. The queue
+ contains the work units and the Reader dequeues from the queue when
+ it is asked to produce a record (via Read()) and has exhausted its
+ previous work unit.
+ """
+
+ def __init__(self, reader_ref, supports_serialize=False):
+ """Creates a new ReaderBase.
+
+ Args:
+ reader_ref: The operation that implements the reader.
+ supports_serialize: True if the reader implementation can
+ serialize its state.
+ """
+ self._reader_ref = reader_ref
+ self._supports_serialize = supports_serialize
+
+ @property
+ def reader_ref(self):
+ """Op that implements the reader."""
+ return self._reader_ref
+
+ def read(self, queue, name=None):
+ """Returns the next record (key, value pair) produced by a reader.
+
+ Will dequeue a work unit from queue if necessary (e.g. when the
+ Reader needs to start reading from a new file since it has
+ finished with the previous file).
+
+ Args:
+ queue: A Queue or a mutable string Tensor representing a handle
+ to a Queue, with string work items.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of Tensors (key, value).
+ key: A string scalar Tensor.
+ value: A string scalar Tensor.
+ """
+ if isinstance(queue, ops.Tensor):
+ queue_ref = queue
+ else:
+ queue_ref = queue.queue_ref
+ return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name)
+
+ def num_records_produced(self, name=None):
+ """Returns the number of records this reader has produced.
+
+ This is the same as the number of Read executions that have
+ succeeded.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ An int64 Tensor.
+
+ """
+ return gen_io_ops._reader_num_records_produced(self._reader_ref, name=name)
+
+ def num_work_units_completed(self, name=None):
+ """Returns the number of work units this reader has finished processing.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ An int64 Tensor.
+ """
+ return gen_io_ops._reader_num_work_units_completed(self._reader_ref,
+ name=name)
+
+ def serialize_state(self, name=None):
+ """Produce a string tensor that encodes the state of a reader.
+
+ Not all Readers support being serialized, so this can produce an
+ Unimplemented error.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A string Tensor.
+ """
+ return gen_io_ops._reader_serialize_state(self._reader_ref, name=name)
+
+ def restore_state(self, state, name=None):
+ """Restore a reader to a previously saved state.
+
+ Not all Readers support being restored, so this can produce an
+ Unimplemented error.
+
+ Args:
+ state: A string Tensor.
+ Result of a SerializeState of a Reader with matching type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_io_ops._reader_restore_state(self._reader_ref, state, name=name)
+
+ @property
+ def supports_serialize(self):
+ """Whether the Reader implementation can serialize its state."""
+ return self._supports_serialize
+
+ def reset(self, name=None):
+ """Restore a reader to its initial clean state.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_io_ops._reader_reset(self._reader_ref, name=name)
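+
+# Typical usage sketch (illustrative; the filename queue would normally come
+# from an input-producer helper such as string_input_producer):
+#
+#   reader = TextLineReader()
+#   key, value = reader.read(filename_queue)
+#
+# Each call to read() emits one record; a new filename is dequeued only when
+# the current file is exhausted.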
+
+
+ops.NoGradient("ReaderRead")
+ops.NoGradient("ReaderNumRecordsProduced")
+ops.NoGradient("ReaderNumWorkUnitsCompleted")
+ops.NoGradient("ReaderSerializeState")
+ops.NoGradient("ReaderRestoreState")
+ops.NoGradient("ReaderReset")
+
+
+class WholeFileReader(ReaderBase):
+ """A Reader that outputs the entire contents of a file as a value.
+
+ To use, enqueue filenames in a Queue. The output of Read will
+ be a filename (key) and the contents of that file (value).
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self, name=None):
+ """Create a WholeFileReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._whole_file_reader(name=name)
+ super(WholeFileReader, self).__init__(rr, supports_serialize=True)
+
+
+ops.NoGradient("WholeFileReader")
+
+
+class TextLineReader(ReaderBase):
+ """A Reader that outputs the lines of a file delimited by newlines.
+
+ Newlines are stripped from the output.
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, skip_header_lines=None, name=None):
+ """Create a TextLineReader.
+
+ Args:
+ skip_header_lines: An optional int. Defaults to 0. Number of lines
+ to skip from the beginning of every file.
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._text_line_reader(skip_header_lines=skip_header_lines,
+ name=name)
+ super(TextLineReader, self).__init__(rr)
+
+
+ops.NoGradient("TextLineReader")
+
+
+class FixedLengthRecordReader(ReaderBase):
+ """A Reader that outputs fixed-length records from a file.
+
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, record_bytes, header_bytes=None, footer_bytes=None,
+ name=None):
+ """Create a FixedLengthRecordReader.
+
+ Args:
+ record_bytes: An int.
+ header_bytes: An optional int. Defaults to 0.
+ footer_bytes: An optional int. Defaults to 0.
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._fixed_length_record_reader(
+ record_bytes=record_bytes, header_bytes=header_bytes,
+ footer_bytes=footer_bytes, name=name)
+ super(FixedLengthRecordReader, self).__init__(rr)
+
+
+ops.NoGradient("FixedLengthRecordReader")
+
+
+class TFRecordReader(ReaderBase):
+ """A Reader that outputs the records from a TFRecords file.
+
+ See ReaderBase for supported methods.
+ """
+ # TODO(josh11b): Support serializing and restoring state.
+
+ def __init__(self, name=None):
+ """Create a TFRecordReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._tf_record_reader(name=name)
+ super(TFRecordReader, self).__init__(rr)
+
+
+ops.NoGradient("TFRecordReader")
+
+
+class IdentityReader(ReaderBase):
+ """A Reader that outputs the queued work as both the key and value.
+
+ To use, enqueue strings in a Queue. Read will take the front
+ work string and output (work, work).
+
+ See ReaderBase for supported methods.
+ """
+
+ def __init__(self, name=None):
+ """Create a IdentityReader.
+
+ Args:
+ name: A name for the operation (optional).
+ """
+ rr = gen_io_ops._identity_reader(name=name)
+ super(IdentityReader, self).__init__(rr, supports_serialize=True)
+
+
+ops.NoGradient("IdentityReader")
+
+
+ops.RegisterShape("FixedLengthRecordReader")(common_shapes.scalar_shape)
+ops.RegisterShape("IdentityReader")(common_shapes.scalar_shape)
+ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
+ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
+ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)
+
+
+@ops.RegisterShape("ReaderNumRecordsProduced")
+@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
+@ops.RegisterShape("ReaderSerializeState")
+def _ReaderScalarShape(op):
+ """Shape function for ops that transform a reader to a scalar."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ReaderRead")
+def _ReaderReadShape(op):
+ """Shape function for the ReaderBase.Read op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_queue_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.scalar(), tensor_shape.scalar()]
+
+
+@ops.RegisterShape("ReaderReset")
+def _ReaderResetShape(op):
+ """Shape function for the ReaderBase.Reset op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("ReaderRestoreState")
+def _ReaderRestoreStateShape(op):
+ """Shape function for the ReaderBase.Restore op."""
+ unused_handle_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ unused_state_shape = op.inputs[1].get_shape().merge_with(
+ tensor_shape.scalar())
+ return []
+
+
+@ops.RegisterShape("ReadFile")
+def _ReadFileShape(op):
+ """Shape function for the ReadFile op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("MatchingFiles")
+def _MatchingFilesShape(op):
+ """Shape function for the MatchingFiles op."""
+  unused_pattern_shape = op.inputs[0].get_shape().merge_with(
+ tensor_shape.scalar())
+ return [tensor_shape.unknown_shape(ndims=1)]
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
new file mode 100644
index 0000000000..893618c9dd
--- /dev/null
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -0,0 +1,25 @@
+"""Gradients for operators defined in linalg_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+@ops.RegisterGradient("MatrixInverse")
+def _MatrixInverseGrad(op, grad):
+ """Gradient for MatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.matmul(
+ ainv,
+ math_ops.matmul(grad, ainv, transpose_b=True),
+ transpose_a=True)
+
+@ops.RegisterGradient("BatchMatrixInverse")
+def _BatchMatrixInverseGrad(op, grad):
+ """Gradient for BatchMatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.batch_matmul(
+ ainv,
+ math_ops.batch_matmul(grad, ainv, adj_y=True),
+ adj_x=True)
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
new file mode 100644
index 0000000000..76fd83fb3d
--- /dev/null
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -0,0 +1,62 @@
+"""Operations for linear algebra."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_linalg_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_linalg_ops import *
+# pylint: enable=wildcard-import
+
+
+@ops.RegisterShape("Cholesky")
+def _CholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchCholesky")
+def _BatchCholeskyShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]
+
+
+@ops.RegisterShape("MatrixDeterminant")
+def _MatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ if input_shape.ndims is not None:
+ return [tensor_shape.scalar()]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("BatchMatrixDeterminant")
+def _BatchMatrixDeterminantShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ if input_shape.ndims is not None:
+ return [input_shape[:-2]]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+
+@ops.RegisterShape("MatrixInverse")
+def _MatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ # The matrix must be square.
+ input_shape[0].assert_is_compatible_with(input_shape[1])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchMatrixInverse")
+def _BatchMatrixInverseShape(op):
+ input_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ # The matrices in the batch must be square.
+ input_shape[-1].assert_is_compatible_with(input_shape[-2])
+ return [input_shape]
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
new file mode 100644
index 0000000000..0fad4a2dde
--- /dev/null
+++ b/tensorflow/python/ops/logging_ops.py
@@ -0,0 +1,58 @@
+"""Logging Operations."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_logging_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_logging_ops import *
+# pylint: enable=wildcard-import
+
+
+# Assert and Print are special symbols in python, so we must
+# use an upper-case version of them.
+def Assert(condition, data, summarize=None, name=None):
+ """Asserts that the given condition is true.
+
+ If `condition` evaluates to false, print the list of tensors in `data`.
+ `summarize` determines how many entries of the tensors to print.
+
+ Args:
+ condition: The condition to evaluate.
+ data: The tensors to print out when condition is false.
+ summarize: Print this many entries of each tensor.
+ name: A name for this operation (optional).
+ """
+ return gen_logging_ops._assert(condition, data, summarize, name)
+
+
+def Print(input_, data, message=None, first_n=None, summarize=None,
+ name=None):
+ """Prints a list of tensors.
+
+ This is an identity op with the side effect of printing `data` when
+ evaluating.
+
+ Args:
+ input_: A tensor passed through this op.
+ data: A list of tensors to print out when op is evaluated.
+    message: A string, prefix of the printed message.
+    first_n: Only log `first_n` times. Negative numbers log always;
+ this is the default.
+ summarize: Only print this many entries of each tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same tensor as `input_`.
+ """
+ return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+
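+# Usage sketch (illustrative comment only, assuming the usual `tf.Assert`,
+# `tf.Print`, and `tf.control_dependencies` aliases for these ops):
+#   x = tf.constant([1.0, 2.0])
+#   assert_op = tf.Assert(tf.reduce_all(x > 0), [x])
+#   with tf.control_dependencies([assert_op]):
+#     y = tf.Print(x, [x], message="x is: ")
+# Evaluating `y` checks the condition and prints `x` as a side effect.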
+
+@ops.RegisterGradient("Print")
+def _PrintGrad(op, *grad):
+ return list(grad) + [None] * (len(op.inputs) - 1)
+
+
+# NOTE(mrry): Assert and Print produce an empty output, which is
+# presumably never read.
+ops.RegisterShape("Assert")(common_shapes.unknown_shape)
+ops.RegisterShape("Print")(common_shapes.unknown_shape)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
new file mode 100644
index 0000000000..cb808ff5b8
--- /dev/null
+++ b/tensorflow/python/ops/math_grad.py
@@ -0,0 +1,506 @@
+"""Gradients for operators defined in math_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+
+
+def _ReductionGradAssist(op):
+ """Reduction grads have much in common, so factor the commonality out."""
+ inp = op.inputs[0] # Example:
+ input_shape = array_ops.shape(inp) # [2, 3, 5, 7]
+ input_rank = array_ops.rank(inp) # 4
+ indices = op.inputs[1] # [1, 2]
+ indices_shape = array_ops.shape(indices) # [2]
+ new_output_shape = data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
+ [math_ops.range(0, input_rank), # [0, 1, 2, 3]
+ indices], # [1, 2]
+ [input_shape, # [2, 3, 5, 7]
+ array_ops.fill(indices_shape, 1)]) # [1, 1]
+ return inp, new_output_shape, input_shape
+
+
+@ops.RegisterGradient("Sum")
+def _SumGrad(op, grad):
+ """Gradient for Sum."""
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad, new_output_shape)
+ return [array_ops.tile(grad, tile_scaling), None]
+
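+# Worked example for _SumGrad (illustrative comment, reusing the shapes from
+# _ReductionGradAssist): for an input of shape [2, 3, 5, 7] reduced over
+# indices [1, 2], the incoming grad has shape [2, 7]; it is reshaped to
+# [2, 1, 1, 7] and tiled by tile_scaling = [1, 3, 5, 1], so every input
+# element receives the gradient of the sum it contributed to.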
+
+def _MinOrMaxGrad(op, grad):
+ """Gradient for Max or Max. Amazingly it's precisely the same code."""
+ inp, new_output_shape, _ = _ReductionGradAssist(op)
+ y = op.outputs[0]
+ y = array_ops.reshape(y, new_output_shape)
+ grad = array_ops.reshape(grad, new_output_shape)
+ indicators = math_ops.cast(math_ops.equal(y, inp), grad.dtype)
+ return [indicators * grad, None]
+
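+# Illustrative behaviour (comment only): for a max over [2., 5., 5.] the
+# output y is 5., so indicators is [0., 1., 1.] and an incoming gradient g
+# produces [0., g, g]; tied maxima each receive the full gradient.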
+
+@ops.RegisterGradient("Max")
+def _MaxGrad(op, grad):
+ """Gradient for Max."""
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Min")
+def _MinGrad(op, grad):
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Mean")
+def _MeanGrad(op, grad):
+ """Gradient for Mean."""
+ sum_grad = _SumGrad(op, grad)[0]
+ input_shape = array_ops.shape(op.inputs[0])
+ output_shape = array_ops.shape(op.outputs[0])
+ factor = (math_ops.reduce_prod(input_shape) /
+ math_ops.reduce_prod(output_shape))
+ return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
+
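+# Illustrative example (comment only): for an input of shape [2, 3, 5, 7]
+# reduced over indices [1, 2], the output has shape [2, 7], so factor is
+# 210 / 14 = 15, the number of elements averaged into each output, and the
+# summed gradient is divided by 15.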
+
+@ops.RegisterGradient("Prod")
+def _ProdGrad(op, grad):
+ """Gradient for Prod."""
+ # TODO(kearnes): this gives NaNs for 0s in the input tensor
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad * op.outputs[0], new_output_shape)
+ grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
+ return grad, None
+
+
+@ops.RegisterGradient("SegmentSum")
+def _SegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None
+
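+# Illustrative example (comment only): with segment_ids = [0, 0, 1], rows 0
+# and 1 feed segment 0 and row 2 feeds segment 1, so gather(grad, [0, 0, 1])
+# hands segment 0's gradient to rows 0 and 1 and segment 1's gradient to
+# row 2.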
+
+@ops.RegisterGradient("SegmentMean")
+def _SegmentMeanGrad(op, grad):
+ """Gradient for SegmentMean."""
+ input_rank = array_ops.rank(op.inputs[0])
+ ones_shape = array_ops.concat(
+ 0, [array_ops.shape(op.inputs[1]),
+ array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)])
+ ones = array_ops.fill(ones_shape,
+ constant_op.constant(1, dtype=grad.dtype))
+ scaled_grad = grad * math_ops.inv(math_ops.segment_sum(ones, op.inputs[1]))
+ return array_ops.gather(scaled_grad, op.inputs[1]), None
+
+
+@ops.RegisterGradient("SparseSegmentSum")
+def _SparseSegmentSumGrad(op, grad):
+ """Gradient for SparseSegmentSum."""
+ input_rows = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]),
+ op.inputs[1], input_rows), None, None)
+
+
+@ops.RegisterGradient("SparseSegmentMean")
+def _SparseSegmentMeanGrad(op, grad):
+ """Gradient for SparseSegmentMean."""
+ dim0 = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.sparse_segment_mean_grad(grad,
+ op.inputs[1],
+ op.inputs[2],
+ dim0),
+ None, None)
+
+
+@ops.RegisterGradient("SegmentMin")
+def _SegmentMinGrad(op, grad):
+ """Gradient for SegmentMin."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.greater(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("SegmentMax")
+def _SegmentMaxGrad(op, grad):
+ """Gradient for SegmentMax."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.less(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("UnsortedSegmentSum")
+def _UnsortedSegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None, None
+
+
+@ops.RegisterGradient("Abs")
+def _AbsGrad(op, grad):
+ x = op.inputs[0]
+ return grad * math_ops.sign(x)
+
+
+@ops.RegisterGradient("Neg")
+def _NegGrad(_, grad):
+ """Returns -grad."""
+  return -grad
+
+
+@ops.RegisterGradient("Inv")
+def _InvGrad(op, grad):
+ """Returns -grad * (1 / x^2)."""
+ y = op.outputs[0] # y = 1 / x
+ return grad * (- math_ops.square(y))
+
+
+@ops.RegisterGradient("Square")
+def _SquareGrad(op, grad):
+ x = op.inputs[0]
+ return grad * (2.0 * x)
+
+
+@ops.RegisterGradient("Sqrt")
+def _SqrtGrad(op, grad):
+ y = op.outputs[0] # y = x^(1/2)
+ return grad * (.5 * math_ops.inv(y))
+
+
+@ops.RegisterGradient("Rsqrt")
+def _RsqrtGrad(op, grad):
+ x = op.inputs[0]
+ y = op.outputs[0] # y = x^(-1/2)
+ return grad * ((-0.5) * math_ops.inv(x) * y)
+
+
+@ops.RegisterGradient("Exp")
+def _ExpGrad(op, grad):
+ """Returns grad * exp(x)."""
+ y = op.outputs[0] # y = e^x
+ return grad * y
+
+
+@ops.RegisterGradient("Log")
+def _LogGrad(op, grad):
+ """Returns grad * (1/x)."""
+ x = op.inputs[0]
+ return grad * math_ops.inv(x)
+
+
+@ops.RegisterGradient("Tanh")
+def _TanhGrad(op, grad):
+ """Returns grad * (1 - tanh(x) * tanh(x))."""
+ y = op.outputs[0] # y = tanh(x)
+ return grad * (1 - math_ops.square(y))
+
+
+@ops.RegisterGradient("Sigmoid")
+def _SigmoidGrad(op, grad):
+ """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
+ y = op.outputs[0] # y = sigmoid(x)
+ return grad * (y * (1 - y))
+
+
+@ops.RegisterGradient("Sign")
+def _SignGrad(op, _):
+ """Returns 0."""
+ x = op.inputs[0]
+ return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
+
+
+@ops.RegisterGradient("Sin")
+def _SinGrad(op, grad):
+ """Returns grad * cos(x)."""
+ x = op.inputs[0]
+ return grad * math_ops.cos(x)
+
+
+@ops.RegisterGradient("Cos")
+def _CosGrad(op, grad):
+ """Returns grad * -sin(x)."""
+ x = op.inputs[0]
+ return -grad * math_ops.sin(x)
+
+
+@ops.RegisterGradient("AddN")
+def _AddNGrad(op, grad):
+ """Copies the gradient to all inputs."""
+ # Not broadcasting.
+ return [grad] * len(op.inputs)
+
+
+@ops.RegisterGradient("Add")
+def _AddGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))
+
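+# Broadcasting example (comment only): if x has shape [2, 3] and y has shape
+# [3], _broadcast_gradient_args gives rx = [] and ry = [0]; x's gradient
+# passes through unchanged, while y's gradient is grad summed over axis 0
+# and reshaped to [3], undoing the implicit broadcast of y. The same pattern
+# is used by the other broadcasting binary-op gradients below.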
+
+@ops.RegisterGradient("Sub")
+def _SubGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+
+
+@ops.RegisterGradient("Mul")
+def _MulGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ if x.dtype.base_dtype == types.complex64:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy))
+ else:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy))
+
+
+@ops.RegisterGradient("Div")
+def _DivGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad *
+ (-x / math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("Pow")
+def _PowGrad(op, grad):
+ """Returns grad * (y*x^(y-1), z*log(x))."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ z = op.outputs[0]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ gx = array_ops.reshape(math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx),
+ sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(grad * z * math_ops.log(x), ry), sy)
+ return gx, gy
+
+
+def _MaximumMinimumGrad(op, grad, selector_op):
+ """Factor out the code for the gradient of Maximum or Minimum."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ gdtype = grad.dtype
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ gradshape = array_ops.shape(grad)
+ zeros = array_ops.zeros(gradshape, gdtype)
+ xmask = selector_op(x, y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ xgrad = math_ops.select(xmask, grad, zeros)
+ ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros)
+ gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
+ return (gx, gy)
+
+
+@ops.RegisterGradient("Maximum")
+def _MaximumGrad(op, grad):
+ """Returns grad*(x > y, x <= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
+
+
+@ops.RegisterGradient("Minimum")
+def _MinimumGrad(op, grad):
+ """Returns grad*(x < y, x >= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
+
+
+# Logical operations have no gradients.
+ops.NoGradient("Less")
+ops.NoGradient("LessEqual")
+ops.NoGradient("Greater")
+ops.NoGradient("GreaterEqual")
+ops.NoGradient("Equal")
+ops.NoGradient("NotEqual")
+ops.NoGradient("LogicalAnd")
+ops.NoGradient("LogicalOr")
+ops.NoGradient("LogicalNot")
+
+
+@ops.RegisterGradient("Select")
+def _SelectGrad(op, grad):
+ c = op.inputs[0]
+ x = op.inputs[1]
+ zeros = array_ops.zeros(array_ops.shape(c), dtype=x.dtype)
+ return (None, math_ops.select(c, grad, zeros),
+ math_ops.select(c, zeros, grad))
+
+
+@ops.RegisterGradient("MatMul")
+def _MatMulGrad(op, grad):
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ if not t_a and not t_b:
+ return (math_ops.matmul(grad, op.inputs[1], transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (math_ops.matmul(grad, op.inputs[1]),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_a=True,
+ transpose_b=True),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True,
+ transpose_b=True))
+
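+# Derivation sketch (comment only): for C = A * B with no transposes,
+# dL/dA = dL/dC * B^T and dL/dB = A^T * dL/dC, which is the first branch
+# above; the remaining branches are the same identities rewritten so that
+# each returned gradient has the shape of the corresponding input.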
+
+@ops.RegisterGradient("SparseMatMul")
+def _SparseMatMulGrad(op, grad):
+ """Gradient for SparseMatMul."""
+
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ is_sparse = {
+ op.inputs[0]: op.get_attr("a_is_sparse"),
+ op.inputs[1]: op.get_attr("b_is_sparse"),
+ # Use heuristic to figure out if grad might be sparse
+ grad: (grad.op.type == "ReluGrad")
+ }
+ def _SparseMatMul(t1, t2, transpose_a=False, transpose_b=False):
+ """Helper function to create SparseMatMul op."""
+
+ assert t1 in is_sparse and t2 in is_sparse
+ t1_sparse = is_sparse[t1]
+ t2_sparse = is_sparse[t2]
+ if not t1_sparse and not t2_sparse:
+ return math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b)
+ transpose_out = False
+ if not t1_sparse:
+ transpose_out = True
+ t1, t2 = t2, t1
+ t1_sparse, t2_sparse = t2_sparse, t1_sparse
+ assert t1_sparse
+ transpose_a, transpose_b = not transpose_b, not transpose_a
+
+ if transpose_b:
+ t2 = array_ops.transpose(t2)
+ transpose_b = False
+ m = math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=t1_sparse,
+ b_is_sparse=t2_sparse)
+ if transpose_out:
+ m = array_ops.transpose(m)
+ return m
+
+ if not t_a and not t_b:
+ return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (_SparseMatMul(grad, op.inputs[1]),
+ _SparseMatMul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (_SparseMatMul(op.inputs[1], grad, transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (_SparseMatMul(op.inputs[1], grad,
+ transpose_a=True, transpose_b=True),
+ _SparseMatMul(grad, op.inputs[0],
+ transpose_a=True, transpose_b=True))
+
+
+@ops.RegisterGradient("Floor")
+def _FloorGrad(_, grad):
+ return grad
+
+
+@ops.RegisterGradient("BatchMatMul")
+def _BatchMatMul(op, grad):
+ """Returns the gradient of x and y given the gradient of x * y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ adj_x = op.get_attr("adj_x")
+ adj_y = op.get_attr("adj_y")
+
+ if not adj_x:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(grad, y, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, True, False)
+ else:
+ grad_x = math_ops.batch_matmul(grad, y, False, False)
+ grad_y = math_ops.batch_matmul(grad, x, True, False)
+ else:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(y, grad, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, False, False)
+ else:
+ grad_x = math_ops.batch_matmul(y, grad, True, True)
+ grad_y = math_ops.batch_matmul(grad, x, True, True)
+
+ return grad_x, grad_y
+
+
+ops.NoGradient("Range")
+ops.NoGradient("LinSpace")
+
+
+@ops.RegisterGradient("Complex")
+def _ComplexGrad(_, grad):
+ """Returns the real and imaginary components of 'grad', respectively."""
+ return math_ops.real(grad), math_ops.imag(grad)
+
+
+@ops.RegisterGradient("Real")
+def _RealGrad(_, grad):
+ """Returns 'grad' as the real part and set the imaginary part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(grad, zero)
+
+
+@ops.RegisterGradient("Imag")
+def _ImagGrad(_, grad):
+ """Returns 'grad' as the imaginary part and set the real part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(zero, grad)
+
+
+@ops.RegisterGradient("Conj")
+def _ConjGrad(_, grad):
+ """Returns the complex conjugate of grad."""
+ return math_ops.conj(grad)
+
+
+@ops.RegisterGradient("Cast")
+def _CastGrad(op, grad):
+ t = [types.float32, types.float64, types.bfloat16]
+ src_type = op.inputs[0].dtype.base_dtype
+ dst_type = grad.dtype.base_dtype
+ if src_type in t and dst_type in t:
+ return math_ops.cast(grad, src_type)
+ else:
+ return None
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
new file mode 100644
index 0000000000..d96320e96e
--- /dev/null
+++ b/tensorflow/python/ops/math_ops.py
@@ -0,0 +1,1201 @@
+"""## Arithmetic Operators
+
+TensorFlow provides several operations that you can use to add basic arithmetic
+operators to your graph.
+
+@@add
+@@sub
+@@mul
+@@div
+@@mod
+
+## Basic Math Functions
+
+TensorFlow provides several operations that you can use to add basic
+mathematical functions to your graph.
+
+@@add_n
+@@abs
+@@neg
+@@sign
+@@inv
+@@square
+@@round
+@@sqrt
+@@rsqrt
+@@pow
+@@exp
+@@log
+@@ceil
+@@floor
+@@maximum
+@@minimum
+@@cos
+@@sin
+
+## Matrix Math Functions
+
+TensorFlow provides several operations that you can use to add basic
+mathematical functions for matrices to your graph.
+
+@@diag
+@@transpose
+
+@@matmul
+@@batch_matmul
+
+@@matrix_determinant
+@@batch_matrix_determinant
+
+@@matrix_inverse
+@@batch_matrix_inverse
+
+@@cholesky
+@@batch_cholesky
+
+## Complex Number Functions
+
+TensorFlow provides several operations that you can use to add complex number
+functions to your graph.
+
+@@complex
+@@complex_abs
+@@conj
+@@imag
+@@real
+
+## Reduction
+
+TensorFlow provides several operations that you can use to perform
+common math computations that reduce various dimensions of a tensor.
+
+@@reduce_sum
+@@reduce_prod
+@@reduce_min
+@@reduce_max
+@@reduce_mean
+@@reduce_all
+@@reduce_any
+
+@@accumulate_n
+
+## Segmentation
+
+TensorFlow provides several operations that you can use to perform common
+math computations on tensor segments.
+Here a segmentation is a partitioning of a tensor along
+the first dimension, i.e. it defines a mapping from the first dimension onto
+`segment_ids`. The `segment_ids` tensor should be the size of
+the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
+where `k<d0`.
+In particular, a segmentation of a matrix tensor is a mapping of rows to
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+ ==> [[0 0 0 0]
+ [5 6 7 8]]
+```
+
+@@segment_sum
+@@segment_prod
+@@segment_min
+@@segment_max
+@@segment_mean
+
+@@unsorted_segment_sum
+
+@@sparse_segment_sum
+@@sparse_segment_mean
+
+
+## Sequence Comparison and Indexing
+
+TensorFlow provides several operations that you can use to add sequence
+comparison and index extraction to your graph. You can use these operations to
+determine sequence differences and determine the indexes of specific values in
+a tensor.
+
+@@argmin
+@@argmax
+
+@@listdiff
+@@where
+@@unique
+
+@@edit_distance
+
+@@invert_permutation
+"""
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import gen_state_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_math_ops import *
+
+
+# Aliases for some automatically-generated names.
+argmax = gen_math_ops.arg_max
+argmin = gen_math_ops.arg_min
+linspace = gen_math_ops.lin_space
+
+
+# pylint: disable=anomalous-backslash-in-string,protected-access
+def abs(x, name=None):
+ """Computes the absolute value of a tensor.
+
+ Given a tensor of real numbers `x`, this operation returns a tensor
+ containing the absolute value of each element in `x`. For example, if x is
+ an input element and y is an output element, this operation computes
+ \\\\(y = |x|\\\\).
+
+ See [`tf.complex_abs()`](#tf_complex_abs) to compute the absolute value of a complex
+ number.
+
+ Args:
+ x: A `Tensor` of type `float`, `double`, `int32`, or `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` the same size and type as `x` with absolute values.
+ """
+ with ops.op_scope([x], name, "Abs") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype == types.complex64:
+ return gen_math_ops.complex_abs(x, name=name)
+ return gen_math_ops._abs(x, name=name)
+
+
+
+def pow(x, y, name=None):
+ """Computes the power of one value to another.
+
+ Given a tensor `x` and a tensor `y`, this operation computes \\\\(x^y\\\\) for
+ corresponding elements in `x` and `y`. For example:
+
+ ```
+  # tensor 'x' is [[2, 2], [3, 3]]
+ # tensor 'y' is [[8, 16], [2, 3]]
+ tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+ ```
+
+ Args:
+ x: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
+ y: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`.
+ """
+ with ops.op_scope([x], name, "Pow") as name:
+ return gen_math_ops._pow(x, y, name=name)
+
+
+def complex(real, imag, name=None):
+ """Converts two real numbers to a complex number.
+
+ Given a tensor `real` representing the real part of a complex number, and a
+ tensor `imag` representing the imaginary part of a complex number, this
+ operation computes complex numbers elementwise of the form \\\\(a + bj\\\\),
+ where *a* represents the `real` part and *b* represents the `imag` part.
+
+ The input tensors `real` and `imag` must be the same shape.
+
+ For example:
+
+ ```
+ # tensor 'real' is [2.25, 3.25]
+ # tensor `imag` is [4.75, 5.75]
+  tf.complex(real, imag) ==> [2.25 + 4.75j, 3.25 + 5.75j]
+ ```
+
+ Args:
+ real: A `Tensor` of type `float`.
+ imag: A `Tensor` of type `float`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `complex64`.
+ """
+ with ops.op_scope([real, imag], name, "Complex") as name:
+ return gen_math_ops._complex(real, imag, name=name)
+
+
+def round(x, name=None):
+ """Rounds the values of a tensor to the nearest integer, element-wise.
+
+ For example:
+
+ ```python
+ # 'a' is [0.9, 2.5, 2.3, -4.4]
+ tf.round(a) ==> [ 1.0, 3.0, 2.0, -4.0 ]
+ ```
+
+ Args:
+ x: A `Tensor` of type `float` or `double`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of same shape and type as `x`.
+ """
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype.is_integer:
+ return x
+ else:
+ return floor(x + 0.5, name=name)
+
+
+def cast(x, dtype, name=None):
+ """Casts a tensor to a new type.
+
+ The operation casts `x` (in case of `Tensor`) or `x.values`
+ (in case of `SparseTensor`) to `dtype`.
+
+ For example:
+
+ ```python
+  # tensor `a` is [1.8, 2.2], dtype=tf.float32
+ tf.cast(a, tf.int32) ==> [1, 2] # dtype=tf.int32
+ ```
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ dtype: The destination type.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `dtype`.
+ """
+ with ops.op_scope([x], name, "Cast") as name:
+ if isinstance(x, ops.SparseTensor):
+ values_cast = cast(x.values, dtype, name=name)
+ return ops.SparseTensor(x.indices, values_cast, x.shape)
+ else:
+ # TODO(mdevin): Handle what Josh said.
+ #
+ # Could return ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
+ # allows some conversions that cast() can't do, e.g. casting numbers to
+ # strings.
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype.base_dtype == dtype:
+ return x
+ return gen_math_ops.cast(x, dtype, name=name)
+
+
+def to_float(x, name="ToFloat"):
+ """Casts a tensor to type `float32`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `float32`.
+ """
+ return cast(x, types.float32, name=name)
+
+
+def to_double(x, name="ToDouble"):
+ """Casts a tensor to type `float64`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `float64`.
+ """
+ return cast(x, types.float64, name=name)
+
+
+def to_int32(x, name="ToInt32"):
+ """Casts a tensor to type `int32`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `int32`.
+ """
+ return cast(x, types.int32, name=name)
+
+
+def to_int64(x, name="ToInt64"):
+ """Casts a tensor to type `int64`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `int64`.
+ """
+ return cast(x, types.int64, name=name)
+
+
+def to_bfloat16(x, name="ToBFloat16"):
+ """Casts a tensor to type `bfloat16`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
+
+ Raises:
+ TypeError: If `x` cannot be cast to the `bfloat16`.
+ """
+ return cast(x, types.bfloat16, name=name)
+
+
+ops.Tensor._override_operator("__neg__", neg)
+ops.Tensor._override_operator("__abs__", abs)
+# __invert__ corresponds to the ~ operator. Here we follow the numpy convention
+# ~ marks an elementwise bit-wise inverse. This is only implemented for boolean
+# tensors and will throw a TypeError if used on nonboolean arrays
+ops.Tensor._override_operator("__invert__", logical_not)
+
+
+def _OverrideBinaryOperatorHelper(func, op_name):
+ """Register operators with different tensor and scalar versions.
+
+ Args:
+ func: the operator
+ op_name: name of the operator being overridden
+ """
+
+ def binary_op_wrapper(x, y):
+ with ops.op_scope([x, y], None, op_name) as name:
+ assert isinstance(x, ops.Tensor)
+ y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
+ return func(x, y, name=name)
+
+ ops.Tensor._override_operator("__%s__" % op_name, binary_op_wrapper)
+ del binary_op_wrapper
+
+ def r_binary_op_wrapper(y, x):
+ with ops.op_scope([x, y], None, op_name) as name:
+ assert isinstance(y, ops.Tensor)
+ x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
+ return func(x, y, name=name)
+
+ ops.Tensor._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
+ del r_binary_op_wrapper
+
+
+_OverrideBinaryOperatorHelper(add, "add")
+_OverrideBinaryOperatorHelper(sub, "sub")
+_OverrideBinaryOperatorHelper(mul, "mul")
+_OverrideBinaryOperatorHelper(div, "div")
+_OverrideBinaryOperatorHelper(mod, "mod")
+
+
+def logical_xor(x, y, name="LogicalXor"):
+ """x ^ y = (x | y) & ~(x & y)."""
+ # TODO(alemi) Make this a cwise op if people end up relying on it.
+ return logical_and(logical_or(x, y), logical_not(logical_and(x, y)),
+ name=name)
+
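+# Truth table for logical_xor (illustrative comment only):
+#   (True, True) -> False, (True, False) -> True,
+#   (False, True) -> True, (False, False) -> False.
+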
+_OverrideBinaryOperatorHelper(logical_and, "and")
+_OverrideBinaryOperatorHelper(logical_or, "or")
+_OverrideBinaryOperatorHelper(logical_xor, "xor")
+
+ops.Tensor._override_operator("__lt__", less)
+ops.Tensor._override_operator("__le__", less_equal)
+ops.Tensor._override_operator("__gt__", greater)
+ops.Tensor._override_operator("__ge__", greater_equal)
+
+
+def range(start, limit, delta=1, name="range"):
+ """Creates a sequence of integers.
+
+ This operation creates a sequence of integers that begins at `start` and
+ extends by increments of `delta` up to but not including `limit`.
+
+ For example:
+
+ ```
+ # 'start' is 3
+ # 'limit' is 18
+ # 'delta' is 3
+ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+ ```
+
+ Args:
+ start: A 0-D (scalar) of type `int32`. First entry in sequence.
+ limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
+ exclusive.
+ delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
+ Number that increments `start`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A 1-D `int32` `Tensor`.
+ """
+ return gen_math_ops._range(start, limit, delta, name=name)
+
+
+@ops.RegisterShape("Range")
+def _RangeShape(op):
+ start_value = tensor_util.ConstantValue(op.inputs[0])
+ limit_value = tensor_util.ConstantValue(op.inputs[1])
+ delta_value = tensor_util.ConstantValue(op.inputs[2])
+ if start_value is None or limit_value is None or delta_value is None:
+ return [tensor_shape.vector(None)]
+ else:
+ return [tensor_shape.vector(
+ (limit_value - start_value + delta_value - 1) / delta_value)]
+
+
+# Reduction operations
+def _ReductionDims(x, reduction_indices):
+ """Returns range(0, rank(x)) if reduction_indices is None."""
+ if reduction_indices is not None:
+ return reduction_indices
+ else:
+ return range(0, array_ops.rank(x))
+
+
+def reduce_sum(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the sum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[1, 1, 1],
+  #         [1, 1, 1]]
+ tf.reduce_sum(x) ==> 6
+ tf.reduce_sum(x, 0) ==> [2, 2, 2]
+ tf.reduce_sum(x, 1) ==> [3, 3]
+ tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
+ tf.reduce_sum(x, [0, 1]) ==> 6
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._sum(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_mean(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the mean of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[1., 1.],
+  #         [2., 2.]]
+ tf.reduce_mean(x) ==> 1.5
+ tf.reduce_mean(x, 0) ==> [1.5, 1.5]
+ tf.reduce_mean(x, 1) ==> [1., 2.]
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._mean(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_prod(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the product of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._prod(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_min(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the minimum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._min(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_max(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the maximum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._max(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_all(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the "logical and" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[True, True],
+  #         [False, False]]
+ tf.reduce_all(x) ==> False
+ tf.reduce_all(x, 0) ==> [False, False]
+ tf.reduce_all(x, 1) ==> [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._all(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def reduce_any(input_tensor, reduction_indices=None, keep_dims=False,
+ name=None):
+ """Computes the "logical or" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `reduction_indices`.
+ Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `reduction_indices` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+  # 'x' is [[True, True],
+  #         [False, False]]
+ tf.reduce_any(x) ==> True
+ tf.reduce_any(x, 0) ==> [True, True]
+ tf.reduce_any(x, 1) ==> [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+    reduction_indices: The dimensions to reduce. If `None` (the default),
+ reduces all dimensions.
+ keep_dims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ return gen_math_ops._any(input_tensor, _ReductionDims(input_tensor,
+ reduction_indices),
+ keep_dims, name=name)
+
+
+def matmul(a, b,
+ transpose_a=False, transpose_b=False,
+ a_is_sparse=False, b_is_sparse=False,
+ name=None):
+ """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
+
+ The inputs must be two-dimensional matrices, with matching inner dimensions,
+ possibly after transposition.
+
+ Both matrices must be of the same type. The supported types are:
+ `float`, `double`, `int32`, `complex64`.
+
+ Either matrix can be transposed on the fly by setting the corresponding flag
+ to `True`. This is `False` by default.
+
+ If one or both of the matrices contain a lot of zeros, a more efficient
+ multiplication algorithm can be used by setting the corresponding
+ `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
+
+ For example:
+
+ ```python
+ # 2-D tensor `a`
+ a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) => [[1. 2. 3.]
+ [4. 5. 6.]]
+ # 2-D tensor `b`
+ b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[7. 8.]
+ [9. 10.]
+ [11. 12.]]
+ c = tf.matmul(a, b) => [[58 64]
+ [139 154]]
+ ```
+
+ Args:
+ a: `Tensor` of type `float`, `double`, `int32` or `complex64`.
+ b: `Tensor` with same type as `a`.
+ transpose_a: If `True`, `a` is transposed before multiplication.
+ transpose_b: If `True`, `b` is transposed before multiplication.
+ a_is_sparse: If `True`, `a` is treated as a sparse matrix.
+ b_is_sparse: If `True`, `b` is treated as a sparse matrix.
+ name: Name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same type as `a`.
+ """
+ with ops.op_scope([a, b], name, "MatMul") as name:
+ a = ops.convert_to_tensor(a, name="a")
+ b = ops.convert_to_tensor(b, name="b")
+ if a.dtype == types.float32 and (a_is_sparse or b_is_sparse):
+ return sparse_matmul(a, b,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=a_is_sparse,
+ b_is_sparse=b_is_sparse,
+ name=name)
+ else:
+ return gen_math_ops._mat_mul(a, b,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ name=name)
+
+sparse_matmul = gen_math_ops._sparse_mat_mul
+batch_matmul = gen_math_ops._batch_mat_mul
+
+ops.RegisterShape("MatMul")(common_shapes.matmul_shape)
+ops.RegisterShape("SparseMatMul")(common_shapes.matmul_shape)
+
+
+def _as_indexed_slices(x):
+ """Convert 'x' to IndexedSlices.
+
+ Convert a dense Tensor to a block-sparse IndexedSlices.
+
+ Args:
+ x: Either a Tensor object, or an IndexedSlices object.
+
+ Returns:
+ An IndexedSlices object.
+
+ Raises:
+ TypeError: If 'x' is not a Tensor or an IndexedSlices object.
+ """
+ # TODO(mdevin): op_scope
+ if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
+ if isinstance(x, ops.IndexedSlices):
+ return x
+ x_shape = array_ops.shape(x)
+ return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
+
+
+def _as_indexed_slices_list(inputs):
+ """Convert all elements of 'inputs' to IndexedSlices.
+
+ Additionally, homogenize the types of all the indices to
+ either int32 or int64.
+
+ Args:
+ inputs: List containing either Tensor or IndexedSlices objects.
+
+ Returns:
+ A list of IndexedSlices objects.
+
+ Raises:
+ TypeError: If 'inputs' is not a list or a tuple.
+ """
+ if not isinstance(inputs, (list, tuple)):
+ raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
+ outputs = [_as_indexed_slices(i) for i in inputs]
+ with_int32_index = [o.indices for o in outputs
+ if o.indices.dtype == types.int32]
+ if not with_int32_index or len(with_int32_index) == len(outputs):
+ return outputs
+ casted_outputs = []
+ for o in outputs:
+ if o.indices.dtype == types.int32:
+ casted_outputs.append(
+ ops.IndexedSlices(o.values, cast(o.indices, types.int64),
+ o.dense_shape))
+ else:
+ casted_outputs.append(o)
+ return casted_outputs
+
+
+def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
+ """Returns the element-wise sum of a list of tensors.
+
+  Optionally, pass `shape` and `tensor_dtype` for shape and type checking;
+ otherwise, these are inferred.
+
+ For example:
+
+ ```python
+  # tensor 'a' is [[1, 2], [3, 4]]
+ # tensor `b` is [[5, 0], [0, 6]]
+ tf.accumulate_n([a, b, a]) ==> [[7, 4], [6, 14]]
+
+ # Explicitly pass shape and type
+ tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
+ ==> [[7, 4], [6, 14]]
+ ```
+
+ Args:
+ inputs: A list of `Tensor` objects, each with same shape and type.
+ shape: Shape of elements of `inputs`.
+ tensor_dtype: The type of `inputs`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of same shape and type as the elements of `inputs`.
+
+ Raises:
+ ValueError: If `inputs` don't all have same shape and dtype or the shape
+ cannot be inferred.
+ """
+ if tensor_dtype is None:
+ if not inputs or not isinstance(inputs, (list, tuple)):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
+ if not all(isinstance(x, ops.Tensor) for x in inputs):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ if not all(x.dtype == inputs[0].dtype for x in inputs):
+ raise ValueError("inputs must be a list of at least one Tensor with the "
+ "same dtype and shape")
+ tensor_dtype = inputs[0].dtype
+ if shape is not None:
+ shape = tensor_shape.as_shape(shape)
+ else:
+ shape = tensor_shape.unknown_shape()
+ for input_tensor in inputs:
+ if isinstance(input_tensor, ops.Tensor):
+ shape = shape.merge_with(input_tensor.get_shape())
+ if not shape.is_fully_defined():
+ # TODO(pbar): Make a version of assign_add that accepts an uninitialized
+ # lvalue, and takes its shape from that? This would allow accumulate_n to
+ # work in all situations that add_n currently works.
+ raise ValueError("Cannot infer the shape of the accumulator for "
+ "accumulate_n. Pass the shape argument, or set the shape "
+ "of at least one of the inputs.")
+ with ops.op_scope(inputs, name, "AccumulateN") as name:
+ var = gen_state_ops._temporary_variable(shape=shape, dtype=tensor_dtype)
+ var_name = var.op.name
+ var = state_ops.assign(var, array_ops.zeros_like(inputs[0]))
+ update_ops = []
+ for input_tensor in inputs:
+ op = state_ops.assign_add(var, input_tensor, use_locking=True)
+ update_ops.append(op)
+ with ops.control_dependencies(update_ops):
+ return gen_state_ops._destroy_temporary_variable(var,
+ var_name=var_name,
+ name=name)
+
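+# Implementation note (comment only): accumulate_n allocates a temporary
+# variable, zeros it, applies one locked assign_add per input, and the
+# control_dependencies block above guarantees every add has run before the
+# temporary is destroyed and its final value returned; unlike add_n, the
+# inputs therefore need not all be materialized at the same time.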
+
+@ops.RegisterShape("BatchMatMul")
+def _BatchMatMulShape(op):
+ """Shape function for BatchMatMul op."""
+ a_shape = op.inputs[0].get_shape()
+ adj_a = op.get_attr("adj_x")
+ b_shape = op.inputs[1].get_shape()
+ adj_b = op.get_attr("adj_y")
+ if not a_shape.is_fully_defined() or not b_shape.is_fully_defined():
+ return [tensor_shape.unknown_shape()]
+ batch_dims = a_shape[:-2].merge_with(b_shape[:-2])
+ output_rows = a_shape[-1] if adj_a else a_shape[-2]
+ output_cols = b_shape[-2] if adj_b else b_shape[-1]
+ inner_a = a_shape[-2] if adj_a else a_shape[-1]
+ inner_b = b_shape[-1] if adj_b else b_shape[-2]
+ inner_a.assert_is_compatible_with(inner_b)
+ return [batch_dims.concatenate([output_rows, output_cols])]
+
+
+def sigmoid(x, name=None):
+ """Computes sigmoid of `x` element-wise.
+
+ Specifically, `y = 1 / (1 + exp(-x))`.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`,
+ otherwise the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Sigmoid") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._sigmoid(x, name=name)
+
+
+def tanh(x, name=None):
+ """Computes hyperbolic tangent of `x` element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`, otherwise
+ the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Tanh") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._tanh(x, name=name)
+
+
+ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
+ops.RegisterShape("Ceil")(common_shapes.unchanged_shape)
+ops.RegisterShape("Conj")(common_shapes.unchanged_shape)
+ops.RegisterShape("Cos")(common_shapes.unchanged_shape)
+ops.RegisterShape("Exp")(common_shapes.unchanged_shape)
+ops.RegisterShape("Floor")(common_shapes.unchanged_shape)
+ops.RegisterShape("Imag")(common_shapes.unchanged_shape)
+ops.RegisterShape("Inv")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsFinite")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsInf")(common_shapes.unchanged_shape)
+ops.RegisterShape("IsNan")(common_shapes.unchanged_shape)
+ops.RegisterShape("Log")(common_shapes.unchanged_shape)
+ops.RegisterShape("LogicalNot")(common_shapes.unchanged_shape)
+ops.RegisterShape("Neg")(common_shapes.unchanged_shape)
+ops.RegisterShape("Real")(common_shapes.unchanged_shape)
+ops.RegisterShape("Rsqrt")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sign")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sin")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sqrt")(common_shapes.unchanged_shape)
+ops.RegisterShape("Square")(common_shapes.unchanged_shape)
+ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape)
+ops.RegisterShape("Tanh")(common_shapes.unchanged_shape)
+ops.RegisterShape("Cast")(common_shapes.unchanged_shape)
+ops.RegisterShape("ComplexAbs")(common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("Add")
+@ops.RegisterShape("Complex")
+@ops.RegisterShape("Div")
+@ops.RegisterShape("Equal")
+@ops.RegisterShape("Greater")
+@ops.RegisterShape("GreaterEqual")
+@ops.RegisterShape("Less")
+@ops.RegisterShape("LessEqual")
+@ops.RegisterShape("LogicalAnd")
+@ops.RegisterShape("LogicalOr")
+@ops.RegisterShape("Maximum")
+@ops.RegisterShape("Minimum")
+@ops.RegisterShape("Mod")
+@ops.RegisterShape("Mul")
+@ops.RegisterShape("NotEqual")
+@ops.RegisterShape("Pow")
+@ops.RegisterShape("Sub")
+def _BroadcastShape(op):
+ """Common shape function for binary operators that broadcast their inputs."""
+ shape_x = op.inputs[0].get_shape()
+ shape_y = op.inputs[1].get_shape()
+ if shape_x.ndims is None or shape_y.ndims is None:
+ return [tensor_shape.unknown_shape()]
+
+ # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
+ # and pad with 1 to make them the same length.
+ broadcasted_dims = reversed(list(itertools.izip_longest(
+ reversed(shape_x.dims), reversed(shape_y.dims),
+ fillvalue=tensor_shape.Dimension(1))))
+ # Next we combine the dimensions according to the numpy broadcasting rules.
+ # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ return_dims = []
+ for (dim_x, dim_y) in broadcasted_dims:
+ if dim_x.value is None or dim_y.value is None:
+ # One or both dimensions is unknown. If either dimension is greater than
+ # 1, we assume that the program is correct, and the other dimension will
+ # be broadcast to match it.
+ # TODO(mrry): If we eliminate the shape checks in C++, we must still
+ # assert that the unknown dim is either 1 or the same as the known dim.
+ if dim_x.value is not None and dim_x.value > 1:
+ return_dims.append(dim_x)
+ elif dim_y.value is not None and dim_y.value > 1:
+ return_dims.append(dim_y)
+ else:
+ return_dims.append(None)
+ elif dim_x.value == 1:
+ # We will broadcast dim_x to dim_y.
+ return_dims.append(dim_y)
+ elif dim_y.value == 1:
+ # We will broadcast dim_y to dim_x.
+ return_dims.append(dim_x)
+ elif dim_x.value == dim_y.value:
+ # The dimensions are compatible, so output is the same size in that
+ # dimension.
+ return_dims.append(dim_x.merge_with(dim_y))
+ else:
+ raise ValueError("Incompatible shapes for broadcasting: %s and %s"
+ % (shape_x, shape_y))
+ return [tensor_shape.TensorShape(return_dims)]
+
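+# Worked example (comment only): for input shapes [5, 1, 3] and [2, 1], the
+# right-aligned dimension pairs are (3, 1), (1, 2) and (5, padded 1), so the
+# broadcast output shape is [5, 2, 3].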
+
+@ops.RegisterShape("AddN")
+def _AddNShape(op):
+ merged_shape = tensor_shape.unknown_shape()
+ for input_ in op.inputs:
+ merged_shape = merged_shape.merge_with(input_.get_shape())
+ return [merged_shape]
+
+
+@ops.RegisterShape("Select")
+def _SelectShape(op):
+ # All three inputs must have the same shape.
+ return [op.inputs[0].get_shape()
+ .merge_with(op.inputs[1].get_shape())
+ .merge_with(op.inputs[2].get_shape())]
+
+
+@ops.RegisterShape("ArgMax")
+@ops.RegisterShape("ArgMin")
+def _ArgOpShape(op):
+ """Common shape function for arg-reduction ops."""
+ dimension_shape = op.inputs[1].get_shape()
+ dimension_shape.assert_is_compatible_with(tensor_shape.scalar())
+ input_shape = op.inputs[0].get_shape()
+ if input_shape.ndims is None:
+ return [tensor_shape.unknown_shape()]
+ elif input_shape.ndims <= 1:
+ return [tensor_shape.scalar()]
+
+ dimension = tensor_util.ConstantValue(op.inputs[1])
+ if dimension is None:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims - 1)]
+ elif 0 <= dimension and dimension < input_shape.ndims:
+ returned_shape = []
+ for i, dim in enumerate(input_shape.dims):
+ if i != dimension:
+ returned_shape.append(dim)
+ return [tensor_shape.TensorShape(returned_shape)]
+ else:
+ raise ValueError(
+ "dimension (%d) must be in the range [0, %d), where %d is the number "
+ "of dimensions in the input"
+ % (dimension, input_shape.ndims, input_shape.ndims))
+
+
+@ops.RegisterShape("All")
+@ops.RegisterShape("Any")
+@ops.RegisterShape("Max")
+@ops.RegisterShape("Mean")
+@ops.RegisterShape("Min")
+@ops.RegisterShape("Prod")
+@ops.RegisterShape("Sum")
+def _ReductionShape(op):
+ """Common shape function for reduction ops."""
+ input_shape = op.inputs[0].get_shape()
+ reduction_indices = tensor_util.ConstantValue(op.inputs[1])
+ keep_dims = op.get_attr("keep_dims")
+ if reduction_indices is None or input_shape.ndims is None:
+ if keep_dims:
+ return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
+ else:
+ return [tensor_shape.unknown_shape()]
+
+ # Turn reduction_indices from scalar to vector if necessary
+ reduction_indices = np.ravel(reduction_indices)
+
+ for reduction_index in reduction_indices:
+ if reduction_index < 0 or reduction_index >= input_shape.ndims:
+ raise ValueError("Invalid reduction dimension %d for input with %d "
+ "dimensions" % (reduction_index, input_shape.ndims))
+
+ returned_dims = []
+ if keep_dims:
+ for i, dim in enumerate(input_shape.dims):
+ if i in reduction_indices:
+ returned_dims.append(1)
+ else:
+ returned_dims.append(dim)
+ else:
+ for i, dim in enumerate(input_shape.dims):
+ if i not in reduction_indices:
+ returned_dims.append(dim)
+ return [tensor_shape.TensorShape(returned_dims)]
+
+
+@ops.RegisterShape("SegmentMax")
+@ops.RegisterShape("SegmentMean")
+@ops.RegisterShape("SegmentMin")
+@ops.RegisterShape("SegmentProd")
+@ops.RegisterShape("SegmentSum")
+def _SegmentReductionShape(op):
+ """Common shape function for segment reduction ops."""
+ data_shape = op.inputs[0].get_shape()
+ segment_ids_shape = op.inputs[1].get_shape()
+ segment_ids_shape.assert_has_rank(1)
+ return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]
+
+
+@ops.RegisterShape("SparseSegmentMean")
+@ops.RegisterShape("SparseSegmentSum")
+def _SparseSegmentReductionShape(op):
+ """Common shape function for sparse segment reduction ops."""
+ data_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ indices_shape.assert_has_rank(1)
+ segment_ids_shape = op.inputs[2].get_shape()
+ segment_ids_shape.assert_has_rank(1)
+ indices_shape.assert_is_compatible_with(segment_ids_shape)
+ return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]
+
+
+@ops.RegisterShape("SparseSegmentMeanGrad")
+def _SparseSegmentMeanGradShape(op):
+ """Shape function for the SparseSegmentMeanGrad op."""
+ input_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_segment_ids_shape = op.inputs[2].get_shape().merge_with(indices_shape)
+ unused_output_dim0_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.scalar())
+ output_dim0 = tensor_util.ConstantValue(op.inputs[3])
+ if output_dim0 is not None:
+ dim0 = output_dim0[0]
+ else:
+ dim0 = None
+ return [tensor_shape.TensorShape([dim0]).concatenate(input_shape[1:])]
+
+
+@ops.RegisterShape("UnsortedSegmentSum")
+def _UnsortedSegmentSumShape(op):
+ """Shape function for UnsortedSegmentSum."""
+ data_shape = op.inputs[0].get_shape()
+ segment_ids_shape = op.inputs[1].get_shape()
+ mid = segment_ids_shape.ndims
+ if mid is None:
+ return [tensor_shape.unknown_shape()]
+ else:
+ num_segments = tensor_util.ConstantValue(op.inputs[2])
+ return [tensor_shape.TensorShape([num_segments]).concatenate(
+ data_shape[mid:])]
+
+
+@ops.RegisterShape("LinSpace")
+def _LinspaceShape(op):
+ num = tensor_util.ConstantValue(op.inputs[2])
+ return [tensor_shape.vector(num)]
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
new file mode 100644
index 0000000000..86ea04f54d
--- /dev/null
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -0,0 +1,68 @@
+"""Tests for tensorflow.ops.math_ops."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+exp = math.exp
+log = math.log
+
+class ReduceTest(test_util.TensorFlowTestCase):
+
+ def testReduceAllDims(self):
+ x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
+ with self.test_session():
+ y_tf = math_ops.reduce_sum(x).eval()
+ self.assertEqual(y_tf, 21)
+
+class RoundTest(test_util.TensorFlowTestCase):
+
+ def testRounding(self):
+ x = [0.49, 0.7, -0.3, -0.8]
+ for dtype in [np.float32, np.double]:
+ x_np = np.array(x, dtype=dtype)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.round(x_tf)
+ y_tf_np = y_tf.eval()
+ y_np = np.round(x_np)
+ self.assertAllClose(y_tf_np, y_np, atol=1e-2)
+
+
+class ModTest(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ x = [0.5, 0.7, 0.3]
+ for dtype in [np.float32, np.double]:
+ # Test scalar and vector versions.
+ for denom in [x[0], [x[0]] * 3]:
+ x_np = np.array(x, dtype=dtype)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.mod(x_tf, denom)
+ y_tf_np = y_tf.eval()
+ y_np = np.fmod(x_np, denom)
+ self.assertAllClose(y_tf_np, y_np, atol=1e-2)
+
+ def testFixed(self):
+ x = [5, 10, 23]
+ for dtype in [np.int32, np.int64]:
+ # Test scalar and vector versions.
+ for denom in [x[0], x]:
+ x_np = np.array(x, dtype=dtype)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y_tf = math_ops.mod(x_tf, denom)
+ y_tf_np = y_tf.eval()
+ y_np = np.mod(x_np, denom)
+ self.assertAllClose(y_tf_np, y_np)
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
new file mode 100644
index 0000000000..7a4dc25e8b
--- /dev/null
+++ b/tensorflow/python/ops/nn.py
@@ -0,0 +1,816 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order
+"""## Activation Functions
+
+The activation ops provide different types of nonlinearities for use in
+neural networks. These include smooth nonlinearities (`sigmoid`,
+`tanh`, and `softplus`), continuous but not everywhere differentiable
+functions (`relu`, `relu6`, and `relu_x`), and random regularization
+(`dropout`).
+
+All activation ops apply componentwise, and produce a tensor of the same
+shape as the input tensor.
+
+@@relu
+@@relu6
+@@softplus
+@@dropout
+@@bias_add
+@@sigmoid
+@@tanh
+
+## Convolution
+
+The convolution ops sweep a 2-D filter over a batch of images, applying the
+filter to each window of each image of the appropriate size. The different
+ops trade off between generic vs. specific filters:
+
+* `conv2d`: Arbitrary filters that can mix channels together.
+* `depthwise_conv2d`: Filters that operate on each channel independently.
+* `separable_conv2d`: A depthwise spatial filter followed by a pointwise filter.
+
+Note that although these ops are called "convolution", they are strictly
+speaking "cross-correlation" since the filter is combined with an input window
+without reversing the filter. For details, see [the properties of
+cross-correlation](https://en.wikipedia.org/wiki/Cross-correlation#Properties).
+
+The filter is applied to image patches of the same size as the filter and
+strided according to the `strides` argument. `strides = [1, 1, 1, 1]` applies
+the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
+filter to every other image patch in each dimension, etc.
+
+Ignoring channels for the moment, the spatial semantics of the convolution ops
+are as follows. If the 4-D `input` has shape
+`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
+`[filter_height, filter_width, ...]`, then
+
+ output.shape = [batch,
+ (in_height - filter_height + 1) / strides[1],
+ (in_width - filter_width + 1) / strides[2],
+ ...]
+
+ output[b, i, j, :] =
+ sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
+ filter[di, dj, ...]
+
+Since `input` is 4-D, each `input[b, i, j, :]` is a vector. For `conv2d`, these
+vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
+vectors. For `depthwise_conv2d`, each scalar component `input[b, i, j, k]`
+is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
+concatenated.
+
+In the formula for `output.shape`, the rounding direction depends on padding:
+
+* `padding = 'VALID'`: Round up (only full size windows are considered).
+* `padding = 'SAME'`: The input is zero-padded so that the output has
+ `ceil(in_height / strides[1])` rows and `ceil(in_width / strides[2])` columns
+ (partial windows are included).
+
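+For example (illustrative sizes): with `padding = 'VALID'`, a
+`[batch, 5, 5, ...]` input convolved with a 3x3 filter at
+`strides = [1, 1, 1, 1]` gives a `[batch, 3, 3, ...]` output, since
+5 - 3 + 1 = 3.
+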
+@@conv2d
+@@depthwise_conv2d
+@@separable_conv2d
+
+## Pooling
+
+The pooling ops sweep a rectangular window over the input tensor, computing a
+reduction operation for each window (average, max, or max with argmax). Each
+pooling op uses rectangular windows of size `ksize` separated by offset
+`strides`. For example, if `strides` is all ones every window is used, if
+`strides` is all twos every other window is used in each dimension, etc.
+
+In detail, the output is
+
+ output[i] = reduce(value[strides * i:strides * i + ksize])
+
+for each tuple of indices `i`. The output shape is
+
+ output.shape = (value.shape - ksize + 1) / strides
+
+where the rounding direction depends on padding:
+
+* `padding = 'VALID'`: Round up (only full size windows are considered).
+* `padding = 'SAME'`: The input is zero-padded so that the output shape is
+ `ceil(value.shape / strides)` (partial windows are included).
+
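+For example (illustrative sizes): max pooling a `[batch, 4, 4, channels]` input
+with `ksize = [1, 2, 2, 1]`, `strides = [1, 2, 2, 1]` and `padding = 'VALID'`
+produces a `[batch, 2, 2, channels]` output.
+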
+@@avg_pool
+@@max_pool
+@@max_pool_with_argmax
+
+## Normalization
+
+Normalization is useful to prevent neurons from saturating when inputs may
+have varying scale, and to aid generalization.
+
+@@l2_normalize
+@@local_response_normalization
+@@moments
+
+## Losses
+
+The loss ops measure error between two tensors, or between a tensor and zero.
+These can be used for measuring accuracy of a network in a regression task
+or for regularization purposes (weight decay).
+
+@@l2_loss
+
+## Classification
+
+TensorFlow provides several operations that help you perform classification.
+
+@@sigmoid_cross_entropy_with_logits
+@@softmax
+@@softmax_cross_entropy_with_logits
+
+## Embeddings
+
+TensorFlow provides several operations that help you compute embeddings.
+
+@@embedding_lookup
+@@embedding_lookup_sparse
+
+## Evaluation
+
+The evaluation ops are useful for measuring the performance of a network.
+Since they are nondifferentiable, they are typically used at evaluation time.
+
+@@top_k
+@@in_top_k
+
+## Candidate Sampling
+
+Do you want to train a multiclass or multilabel model with thousands
+or millions of output classes (for example, a language model with a
+large vocabulary)? Training with a full Softmax is slow in this case,
+since all of the classes are evaluated for every training example.
+Candidate Sampling training algorithms can speed up your step times by
+only considering a small randomly-chosen subset of contrastive classes
+(called candidates) for each batch of training examples.
+
+See our [Candidate Sampling Algorithms Reference]
+(http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+### Sampled Loss Functions
+
+TensorFlow provides the following sampled loss functions for faster training.
+
+@@nce_loss
+@@sampled_softmax_loss
+
+### Candidate Samplers
+
+TensorFlow provides the following samplers for randomly sampling candidate
+classes when using one of the sampled loss functions above.
+
+@@uniform_candidate_sampler
+@@log_uniform_candidate_sampler
+@@learned_unigram_candidate_sampler
+@@fixed_unigram_candidate_sampler
+
+### Miscellaneous candidate sampling utilities
+
+@@compute_accidental_hits
+
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import candidate_sampling_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import numerics
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops.math_ops import sigmoid
+from tensorflow.python.ops.math_ops import tanh
+
+# Bring more nn-associated functionality into this package.
+from tensorflow.python.ops.nn_ops import *
+from tensorflow.python.ops.candidate_sampling_ops import *
+from tensorflow.python.ops.embedding_ops import *
+
+
+def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
+ """Computes sigmoid cross entropy given `logits`.
+
+ Measures the probability error in discrete classification tasks in which each
+ class is independent and not mutually exclusive. For instance, one could
+ perform multilabel classification where a picture can contain both an elephant
+ and a dog at the same time.
+
+ For brevity, let `x = logits`, `z = targets`. The logistic loss is
+
+ x - x * z + log(1 + exp(-x))
+
+ To ensure stability and avoid overflow, the implementation uses
+
+ max(x, 0) - x * z + log(1 + exp(-abs(x)))
+
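+ (The two expressions agree: for `x >= 0` they are identical, and for `x < 0`,
+ substituting `max(x, 0) = 0` and `abs(x) = -x` gives
+ `-x * z + log(1 + exp(x))`, which equals `x - x * z + log(1 + exp(-x))`
+ because `log(1 + exp(x)) = x + log(1 + exp(-x))`.)
+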
+ `logits` and `targets` must have the same type and shape.
+
+ Args:
+ logits: A `Tensor` of type `float32` or `float64`.
+ targets: A `Tensor` of the same type and shape as `logits`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ logistic losses.
+ """
+ with ops.op_scope([logits, targets], name, "logistic_loss") as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ targets = ops.convert_to_tensor(targets, name="targets")
+ # The logistic loss formula from above is
+ # x - x * z + log(1 + exp(-x))
+ # For x < 0, a more numerically stable formula is
+ # -x * z + log(1 + exp(x))
+ # To avoid branching, we use the combined version
+ # max(x, 0) - x * z + log(1 + exp(-abs(x)))
+ return math_ops.add(nn_ops.relu(logits) - logits * targets,
+ math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))),
+ name=name)
+
+
+def xw_plus_b(x, weights, biases, name=None):
+ """Computes matmul(x, weights) + biases.
+
+ Args:
+ x: a 2D tensor. Dimensions typically: batch, in_units
+ weights: a 2D tensor. Dimensions typically: in_units, out_units
+ biases: a 1D tensor. Dimensions: out_units
+ name: A name for the operation (optional). If not specified
+ "xw_plus_b" is used.
+
+ Returns:
+ A 2-D Tensor computing matmul(x, weights) + biases.
+ Dimensions typically: batch, out_units.
+ """
+ with ops.op_scope([x, weights, biases], name, "xw_plus_b") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ weights = ops.convert_to_tensor(weights, name="weights")
+ biases = ops.convert_to_tensor(biases, name="biases")
+ mm = math_ops.matmul(x, weights)
+ return nn_ops.bias_add(mm, biases, name=name)
+
+
+def relu_layer(x, weights, biases, name=None):
+ """Computes Relu(x * weight + biases).
+
+ Args:
+ x: a 2D tensor. Dimensions typically: batch, in_units
+ weights: a 2D tensor. Dimensions typically: in_units, out_units
+ biases: a 1D tensor. Dimensions: out_units
+ name: A name for the operation (optional). If not specified
+ "relu_layer" is used.
+
+ Returns:
+ A 2-D Tensor computing relu(matmul(x, weights) + biases).
+ Dimensions typically: batch, out_units.
+ """
+ with ops.op_scope([x, weights, biases], name, "relu_layer") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ weights = ops.convert_to_tensor(weights, name="weights")
+ biases = ops.convert_to_tensor(biases, name="biases")
+ xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
+ return nn_ops.relu(xw_plus_b, name=name)
+
+
+def l2_normalize(x, dim, epsilon=1e-12, name=None):
+ """Normalizes along dimension `dim` using an L2 norm.
+
+ For a 1-D tensor with `dim = 0`, computes
+
+ output = x / sqrt(max(sum(x**2), epsilon))
+
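+ For example (illustrative values), a 1-D input `[3.0, 4.0]` has L2 norm 5, so
+ the normalized output is `[0.6, 0.8]`.
+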
+ For `x` with more dimensions, independently normalizes each 1-D slice along
+ dimension `dim`.
+
+ Args:
+ x: A `Tensor`.
+ dim: Dimension along which to normalize.
+ epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
+ divisor if `norm < sqrt(epsilon)`.
+ name: A name for this operation (optional).
+
+ Returns:
+ A `Tensor` with the same shape as `x`.
+ """
+ with ops.op_scope([x], name, "l2_normalize") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ square_sum = math_ops.reduce_sum(math_ops.square(x), [dim], keep_dims=True)
+ x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
+ return math_ops.mul(x, x_inv_norm, name=name)
+
+
+def zero_fraction(value, name=None):
+ """Returns the fraction of zeros in `value`.
+
+ If `value` is empty, the result is `nan`.
+
+ This is useful in summaries to measure and report sparsity. For example,
+
+ z = tf.nn.relu(...)
+ summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))
+
+ Args:
+ value: A tensor of numeric type.
+ name: A name for the operation (optional).
+
+ Returns:
+ The fraction of zeros in `value`, with type `float32`.
+ """
+ with ops.op_scope([value], name, "zero_fraction"):
+ value = ops.convert_to_tensor(value, name="value")
+ zero = constant_op.constant(0, dtype=value.dtype, name="zero")
+ return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero),
+ types.float32))
+
+
+def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
+ """Computes dropout.
+
+ With probability `keep_prob`, outputs the input element scaled up by
+ `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected
+ sum is unchanged.
+
+ By default, each element is kept or dropped independently. If `noise_shape`
+ is specified, it must be
+ [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+ to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
+ will make independent decisions. For example, if `x.shape = [b, x, y, c]` and
+ `noise_shape = [b, 1, 1, c]`, each batch and channel component will be
+ kept independently and each row and column will be kept or not kept together.
+
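+ For example, a minimal usage sketch (shapes and probability are illustrative):
+
+ # Keep each element with probability 0.5; kept entries are scaled by
+ # 1 / 0.5 = 2 so the expected sum is unchanged.
+ dropped = tf.nn.dropout(tf.ones([4, 10]), keep_prob=0.5)
+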
+ Args:
+ x: A tensor.
+ keep_prob: Float probability that each element is kept.
+ noise_shape: Shape for randomly generated keep/drop flags.
+ seed: A Python integer. Used to create a random seed.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor of the same shape as `x`.
+
+ Raises:
+ ValueError: If `keep_prob` is not in `(0, 1]`.
+ """
+ if not (0 < keep_prob <= 1):
+ raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
+ with ops.op_scope([x], name, "dropout") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ noise_shape = noise_shape or array_ops.shape(x)
+ # uniform [keep_prob, 1.0 + keep_prob)
+ random_tensor = keep_prob
+ random_tensor += random_ops.random_uniform(
+ noise_shape, seed=seed, dtype=x.dtype)
+ # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
+ binary_tensor = math_ops.floor(random_tensor)
+ return x * (1.0 / keep_prob) * binary_tensor
+
+
+def depthwise_conv2d(input, filter, strides, padding, name=None):
+ """Depthwise 2-D convolution.
+
+ Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+ and a filter tensor of shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`
+ containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
+ applies a different filter to each input channel (expanding from 1 channel
+ to `channel_multiplier` channels for each), then concatenates the results
+ together. The output has `in_channels * channel_multiplier` channels.
+
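+ For example (illustrative shapes): a `[batch, h, w, 3]` input convolved with a
+ `[3, 3, 3, 8]` filter produces a `[batch, out_h, out_w, 24]` output.
+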
+ In detail,
+
+ output[b, i, j, k * channel_multiplier + q] =
+ sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
+ filter[di, dj, k, q]
+
+ Must have `strides[0] = strides[3] = 1`. For the most common case of the
+ same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+ Args:
+ input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+ filter: 4-D with shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`.
+ strides: 1-D of size 4. The stride of the sliding window for each
+ dimension of `input`.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: A name for this operation (optional).
+
+ Returns:
+ A 4-D `Tensor` of shape
+ `[batch, out_height, out_width, in_channels * channel_multiplier].`
+ """
+ with ops.op_scope([input, filter], name, "depthwise") as name:
+ input = ops.convert_to_tensor(input, name="tensor_in")
+ filter = ops.convert_to_tensor(filter, name="filter_in")
+ # A shape is required to statically compute the number of input channels.
+ if filter.get_shape().ndims is not None:
+ assert len(filter.get_shape()) == 4
+ in_channels = filter.get_shape()[2]
+ # Sanity checks, if shape information is available for the inputs.
+ if input.get_shape().ndims is not None:
+ assert len(input.get_shape()) == 4
+ assert input.get_shape()[3] == in_channels, (
+ "Mismatched input depth %d and number of depthwise filters %d." % (
+ input.get_shape()[3].value, in_channels))
+ else:
+ assert input.get_shape().ndims is not None, (
+ "Either tensor must provide static shape information.")
+ assert input.get_shape().ndims == 4
+ in_channels = input.get_shape()[3]
+
+ if in_channels == 1:
+ return nn_ops.conv2d(input, filter, strides, padding, name=name)
+ else:
+ # Create one separate convolution per channel.
+ convs = []
+ for channel in xrange(in_channels):
+ with ops.name_scope("depth%d" % channel) as channel_scope:
+ t_in = array_ops.slice(input, [0, 0, 0, channel], [-1, -1, -1, 1],
+ name="slice_inputs")
+ f_in = array_ops.slice(filter, [0, 0, channel, 0], [-1, -1, 1, -1],
+ name="slice_params")
+ convs.append(nn_ops.conv2d(t_in, f_in,
+ strides, padding, name=channel_scope))
+ # Concatenate the per-channel convolutions along the channel dimension.
+ return array_ops.concat(3, convs, name=name)
+
+
+def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
+ padding,
+ name=None):
+ """2-D convolution with separable filters.
+
+ Performs a depthwise convolution that acts separately on channels followed by
+ a pointwise convolution that mixes channels. Note that this is separability
+ between dimensions `[1, 2]` and `3`, not spatial separability between
+ dimensions `1` and `2`.
+
+ In detail,
+
+ output[b, i, j, k] = sum_{di, dj, q, r}
+ input[b, strides[1] * i + di, strides[2] * j + dj, q] *
+ depthwise_filter[di, dj, q, r] *
+ pointwise_filter[0, 0, q * channel_multiplier + r, k]
+
+ `strides` controls the strides for the depthwise convolution only, since
+ the pointwise convolution has implicit strides of `[1, 1, 1, 1]`. Must have
+ `strides[0] = strides[3] = 1`. For the most common case of the same
+ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
+ Args:
+ input: 4-D `Tensor` with shape `[batch, in_height, in_width, in_channels]`.
+ depthwise_filter: 4-D `Tensor` with shape
+ `[filter_height, filter_width, in_channels, channel_multiplier]`.
+ Contains `in_channels` convolutional filters of depth 1.
+ pointwise_filter: 4-D `Tensor` with shape
+ `[1, 1, channel_multiplier * in_channels, out_channels]`. Pointwise
+ filter to mix channels after `depthwise_filter` has convolved spatially.
+ strides: 1-D of size 4. The strides for the depthwise convolution for
+ each dimension of `input`.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: A name for this operation (optional).
+
+ Returns:
+ A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.
+ """
+ with ops.op_scope([input, depthwise_filter, pointwise_filter],
+ name, "separable_conv2d") as name:
+ input = ops.convert_to_tensor(input, name="tensor_in")
+ depthwise_filter = ops.convert_to_tensor(depthwise_filter,
+ name="depthwise_filter")
+ pointwise_filter = ops.convert_to_tensor(pointwise_filter,
+ name="pointwise_filter")
+
+ if pointwise_filter.get_shape().ndims is not None:
+ assert len(pointwise_filter.get_shape()) == 4
+ assert pointwise_filter.get_shape()[0] == 1
+ assert pointwise_filter.get_shape()[1] == 1
+ if depthwise_filter.get_shape().ndims and input.get_shape().ndims:
+ channel_multiplier = depthwise_filter.get_shape()[3]
+ in_channels = input.get_shape()[3]
+ out_channels = pointwise_filter.get_shape()[3]
+ # This would mean the separable convolution is over-parameterized.
+ assert channel_multiplier * in_channels < out_channels
+ # The layout of the ops in the graph are expected to be as follows:
+ # separable_conv2d // Conv2D op corresponding to the pointwise conv.
+ # separable_conv2d/depthwise // Concat op for the depthwise outputs.
+ # separable_conv2d/depthwise/depth0 // Conv2D op for depth 0
+ # separable_conv2d/depthwise/depth1 // Conv2D op for depth 1
+ # separable_conv2d/depthwise/depth2 // Conv2D op for depth 2
+ depthwise = depthwise_conv2d(input, depthwise_filter, strides,
+ padding, name="depthwise")
+ return nn_ops.conv2d(depthwise, pointwise_filter, [1, 1, 1, 1],
+ padding="VALID", name=name)
+
+
+def moments(x, axes, name=None):
+ """Calculate the mean and variance of `x`.
+
+ The mean and variance are calculated by aggregating the contents of `x`
+ across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean
+ and variance of a vector.
+
+ For so-called "global normalization" needed for convolutional filters pass
+ `axes=[0, 1, 2]` (batch, height, width). For batch normalization pass
+ `axes=[0]` (batch).
+
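+ For example, a minimal sketch (shapes illustrative): for activations `x` of
+ shape `[batch, height, width, channels]`,
+ `mean, variance = tf.nn.moments(x, [0, 1, 2])` returns one mean and one
+ variance per channel.
+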
+ Args:
+ x: A `Tensor`.
+ axes: array of ints. Axes along which to compute mean and
+ variance.
+ name: Name used to scope the operations that compute the moments.
+
+ Returns:
+ Two `Tensors`: `mean` and `variance`.
+ """
+ with ops.op_scope([x, axes], name, "moments"):
+ x = ops.convert_to_tensor(x, name="x")
+ divisor = 1.0
+ for d in xrange(len(x.get_shape())):
+ if d in axes:
+ divisor *= x.get_shape()[d].value
+ divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
+ axes = constant_op.constant(axes, name="axes")
+ # Note: We do not use Mean here because it is very slow on GPU.
+ # Note 2: The expression below is potentially more stable.
+ # It is however a bit slower and stability doesn't appear to be an issue.
+ # mean = math_ops.reduce_sum(math_ops.mul(x, divisor), axes, name="mean")
+ # var = math_ops.reduce_sum(math_ops.mul(math_ops.square(x - mean),
+ # divisor), axes,
+ # name="variance")
+ mean = math_ops.mul(math_ops.reduce_sum(x, axes), divisor, name="mean")
+ var = math_ops.mul(math_ops.reduce_sum(math_ops.square(x - mean), axes),
+ divisor, name="variance")
+ return mean, var
+
+
+def _sum_rows(x):
+ """Returns a vector summing up each row of the matrix x."""
+ # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
+ # a matrix. The gradient of _sum_rows(x) is more efficient than
+ # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
+ # we use _sum_rows(x) in the nce_loss() computation since the loss
+ # is mostly used for training.
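+ # For example (illustrative values): for x = [[1., 2.], [3., 4.]], multiplying
+ # by a [2, 1] matrix of ones gives [[3.], [7.]], which reshapes to [3., 7.].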
+ cols = array_ops.shape(x)[1]
+ ones_shape = array_ops.pack([cols, 1])
+ ones = array_ops.ones(ones_shape, x.dtype)
+ return array_ops.reshape(math_ops.matmul(x, ones), [-1])
+
+
+def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
+ num_classes, num_true=1,
+ sampled_values=None,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ name=None):
+ """Helper function for nce_loss and sampled_softmax_loss functions.
+
+ Computes sampled output training logits and labels suitable for implementing
+ e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
+ sampled_softmax_loss).
+
+ Note: In the case where num_true > 1, we assign to each target class
+ the target probability 1 / num_true so that the target probabilities
+ sum to 1 per-example.
+
+ Args:
+ weights: tensor of label embeddings with shape = [num_classes, dim]
+ biases: tensor of num_classes label biases
+ inputs: tensor with shape = [batch_size, dim] corresponding to forward
+ activations of the input network
+ labels: int tensor with shape [batch_size, num_true]
+ num_sampled: number of label classes to sample per batch
+ num_classes: number of possible label classes in the data (e.g. vocab size)
+ num_true: number of target classes per example (default: 1)
+ sampled_values: a tuple of (sampled_candidates, true_expected_count,
+ sampled_expected_count) returned by a *CandidateSampler function to use
+ (if None, we default to LogUniformCandidateSampler)
+ subtract_log_q: subtract the log expected count of the labels in the sample
+ to get the logits of the true labels (default: True).
+ Turn off for Negative Sampling.
+ remove_accidental_hits: whether to remove "accidental hits" where a sampled
+ label equals the true labels (bool, default: False)
+ name: name for this op
+
+ Returns:
+ out_logits, out_labels: tensors with shape [batch_size, num_true +
+ num_sampled] for passing to either SigmoidCrossEntropyWithLogits (NCE)
+ or SoftmaxCrossEntropyWithLogits (sampled softmax).
+
+ """
+
+ with ops.op_scope(
+ [weights, biases, inputs, labels], name, "compute_sampled_logits"):
+ if labels.dtype != types.int64:
+ labels = math_ops.cast(labels, types.int64)
+ labels_flat = array_ops.reshape(labels, [-1])
+
+ # Sample the negative labels.
+ # sampled shape: num_sampled vector
+ # true_expected_count shape = [batch_size, 1]
+ # sampled_expected_count shape = num_sampled vector
+ if sampled_values is None:
+ sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
+ true_classes=labels,
+ num_true=num_true,
+ num_sampled=num_sampled,
+ unique=True,
+ range_max=num_classes)
+ # NOTE: pylint cannot tell that 'sampled_values' is a sequence
+ # pylint: disable=unpacking-non-sequence
+ sampled, true_expected_count, sampled_expected_count = sampled_values
+ # pylint: enable=unpacking-non-sequence
+
+ # weights shape is [num_classes, dim]
+ # labels_flat is a [batch_size * num_true] vector
+ # true_w shape is [batch_size * num_true, dim]
+ # true_b is a [batch_size * num_true] vector
+ true_w = embedding_ops.embedding_lookup(weights, labels_flat)
+ true_b = embedding_ops.embedding_lookup(biases, labels_flat)
+
+ # inputs shape is [batch_size, dim]
+ # true_w shape is [batch_size * num_true, dim]
+ # row_wise_dots is [batch_size, num_true, dim]
+ dim = array_ops.shape(true_w)[1:2]
+ new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
+ row_wise_dots = math_ops.mul(
+ array_ops.expand_dims(inputs, 1),
+ array_ops.reshape(true_w, new_true_w_shape))
+ # We want the row-wise dot plus biases which yields a
+ # [batch_size, num_true] tensor of true_logits.
+ dots_as_matrix = array_ops.reshape(row_wise_dots,
+ array_ops.concat(0, [[-1], dim]))
+ true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
+ true_b = array_ops.reshape(true_b, [-1, num_true])
+ true_logits += true_b
+
+ # Lookup weights and biases for sampled labels.
+ # sampled is a num_sampled int vector
+ # sampled_w shape is [num_sampled, dim]
+ # sampled_b is a num_sampled float vector
+ sampled_w = embedding_ops.embedding_lookup(weights, sampled)
+ sampled_b = embedding_ops.embedding_lookup(biases, sampled)
+
+ # inputs has shape [batch_size, dim]
+ # sampled_w has shape [num_sampled, dim]
+ # sampled_b has shape [num_sampled]
+ # Apply X*W'+B, which yields [batch_size, num_sampled]
+ sampled_logits = math_ops.matmul(inputs,
+ sampled_w,
+ transpose_b=True) + sampled_b
+
+ if remove_accidental_hits:
+ acc_hits = candidate_sampling_ops.compute_accidental_hits(
+ labels, sampled, num_true=num_true)
+ acc_indices, acc_ids, acc_weights = acc_hits
+
+ # This is how SparseToDense expects the indices.
+ acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
+ acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
+ acc_ids, types.int32), [-1, 1])
+ sparse_indices = array_ops.concat(
+ 1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
+ # Create sampled_logits_shape = [batch_size, num_sampled]
+ sampled_logits_shape = array_ops.concat(
+ 0,
+ [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
+ sampled_logits += sparse_ops.sparse_to_dense(
+ sparse_indices, sampled_logits_shape, acc_weights, 0.0)
+
+ if subtract_log_q:
+ # Subtract log of Q(l), prior probability that l appears in sampled.
+ true_logits -= math_ops.log(true_expected_count)
+ sampled_logits -= math_ops.log(sampled_expected_count)
+
+ # Construct output logits and labels. The true labels/logits start at col 0.
+ out_logits = array_ops.concat(1, [true_logits, sampled_logits])
+ # true_logits is a float tensor, ones_like(true_logits) is a float tensor
+ # of ones. We then divide by num_true to ensure the per-example labels sum
+ # to 1.0, i.e. form a proper probability distribution.
+ out_labels = array_ops.concat(
+ 1, [array_ops.ones_like(true_logits) / num_true,
+ array_ops.zeros_like(sampled_logits)])
+
+ return out_logits, out_labels
+
+
+def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=1,
+ sampled_values=None,
+ remove_accidental_hits=False,
+ name="nce_loss"):
+ """Computes and returns the noise-contrastive estimation training loss.
+
+ See [Noise-contrastive estimation: A new estimation principle for
+ unnormalized statistical models]
+ (http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
+ Also see our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+ Note: In the case where num_true > 1, we assign to each target class
+ the target probability 1 / num_true so that the target probabilities
+ sum to 1 per-example.
+
+ Note: It would be useful to allow a variable number of target classes per
+ example. We hope to provide this functionality in a future release.
+ For now, if you have a variable number of target classes, you can pad them
+ out to a constant number by either repeating them or by padding
+ with an otherwise unused class.
+
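+ A minimal usage sketch (all names and sizes below are illustrative
+ assumptions, not part of this API):
+
+ # embeddings: [vocab_size, dim], nce_biases: [vocab_size]
+ # hidden: [batch_size, dim], labels: [batch_size, 1] of type int64
+ loss = tf.reduce_mean(
+ tf.nn.nce_loss(embeddings, nce_biases, hidden, labels,
+ num_sampled=64, num_classes=vocab_size))
+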
+ Args:
+ weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
+ biases: A `Tensor` of shape [num_classes]. The class biases.
+ inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ activations of the input network.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
+ sampled_values: a tuple of `(sampled_candidates, true_expected_count,
+ sampled_expected_count)` returned by a *_candidate_sampler function.
+ (if None, we default to LogUniformCandidateSampler)
+ remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+ where a sampled class equals one of the target classes. If set to
+ `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
+ learning to generate log-odds instead of log probabilities. See
+ our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf).
+ Default is False.
+ name: A name for the operation (optional).
+
+ Returns:
+ A batch_size 1-D tensor of per-example NCE losses.
+ """
+ logits, labels = _compute_sampled_logits(
+ weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=num_true,
+ sampled_values=sampled_values,
+ subtract_log_q=True,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ sampled_losses = sigmoid_cross_entropy_with_logits(logits,
+ labels,
+ name="sampled_losses")
+ # sampled_losses is batch_size x {true_loss, sampled_losses...}
+ # We sum out true and sampled losses.
+ return _sum_rows(sampled_losses)
+
+
+def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
+ num_classes, num_true=1,
+ sampled_values=None,
+ remove_accidental_hits=True,
+ name="sampled_softmax_loss"):
+ """Computes and returns the sampled softmax training loss.
+
+ This is a faster way to train a softmax classifier over a huge number of
+ classes.
+
+ This operation is for training only. It is generally an underestimate of
+ the full softmax loss.
+
+ At inference time, you can compute full softmax probabilities with the
+ expression `tf.nn.softmax(tf.matmul(inputs, weights, transpose_b=True) +
+ biases)`.
+
+ See our [Candidate Sampling Algorithms Reference]
+ (http://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
+
+ Args:
+ weights: A `Tensor` of shape [num_classes, dim]. The class embeddings.
+ biases: A `Tensor` of shape [num_classes]. The class biases.
+ inputs: A `Tensor` of shape [batch_size, dim]. The forward
+ activations of the input network.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes. Note that this format differs from
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
+ sampled_values: a tuple of `(sampled_candidates, true_expected_count,
+ sampled_expected_count)` returned by a *_candidate_sampler function.
+ (if None, we default to LogUniformCandidateSampler)
+ remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+ where a sampled class equals one of the target classes. Default is
+ True.
+ name: A name for the operation (optional).
+
+ Returns:
+ A batch_size 1-D tensor of per-example sampled softmax losses.
+
+ """
+ logits, labels = _compute_sampled_logits(
+ weights, biases, inputs, labels, num_sampled, num_classes,
+ num_true=num_true,
+ sampled_values=sampled_values,
+ subtract_log_q=True,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
+ # sampled_losses is a batch_size vector.
+ return sampled_losses
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
new file mode 100644
index 0000000000..0cf867d217
--- /dev/null
+++ b/tensorflow/python/ops/nn_grad.py
@@ -0,0 +1,229 @@
+"""Gradients for operators defined in nn_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import gen_nn_ops
+
+
+@ops.RegisterGradient("Conv2DBackpropInput")
+def _DeConv2DGrad(op, grad):
+ """The derivatives for deconvolution.
+
+ Args:
+ op: the Deconvolution op.
+ grad: the tensor representing the gradient w.r.t. the output
+
+ Returns:
+ the gradients w.r.t. the input and the filter
+ """
+ return [None,
+ nn_ops.conv2d_backprop_filter(grad,
+ array_ops.shape(op.inputs[1]),
+ op.inputs[2],
+ op.get_attr("strides"),
+ op.get_attr("padding")),
+ nn_ops.conv2d(grad,
+ op.inputs[1],
+ op.get_attr("strides"),
+ op.get_attr("padding"))]
+
+
+@ops.RegisterGradient("Softmax")
+def _SoftmaxGrad(op, grad_softmax):
+ """The derivative of the softmax nonlinearity.
+
+ We assume that probs is of shape [batch_size, dim].
+ The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
+ This matrix is diagonal minus a rank one matrix, so it is easy to implement
+ as follows:
+
+ grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax
+
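+ (Writing s = softmax and g = grad_softmax, the chain rule gives
+ dL/dx_i = sum_j g_j * s_j * (delta_ij - s_i) = s_i * (g_i - sum_j g_j * s_j),
+ which is the row-wise form of the expression above.)
+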
+ Args:
+ op: the Softmax op.
+ grad_softmax: the tensor representing the gradient w.r.t. the
+ softmax output.
+
+ Returns:
+ gradient w.r.t the input to the softmax
+
+ """
+ # TODO(ilyasu): assert that the tensor has two dimensions at
+ # graph-construction time? Alternatively: do different things
+ # depending on the dimensionality of the input tensors.
+ softmax = op.outputs[0]
+ grad_x = ((grad_softmax -
+ array_ops.reshape(math_ops.reduce_sum(grad_softmax * softmax, [1]),
+ [-1, 1]))
+ * softmax)
+ return grad_x
+
+
+@ops.RegisterGradient("BiasAdd")
+def _BiasAddGrad(unused_bias_op, received_grad):
+ """Return the gradients for the 2 inputs of bias_op.
+
+ The first input of unused_bias_op is the tensor t, and its gradient is
+ just the gradient the unused_bias_op received.
+
+ The second input of unused_bias_op is the bias vector, which has rank 1 and
+ matches the last dimension of "received_grad". Its gradient is the received
+ gradient summed over all dimensions except the last one.
+
+ Args:
+ unused_bias_op: The BiasOp for which we need to generate gradients.
+ received_grad: Tensor. The gradients passed to the BiasOp.
+
+ Returns:
+ Two tensors, the first one for the "tensor" input of the BiasOp,
+ the second one for the "bias" input of the BiasOp.
+ """
+ reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1)
+ return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
+
+
+def _VerifyTensor(t, name, msg):
+ """Assert that the tensor does not contain any NaN's.
+
+ Args:
+ t: Tensor
+ name: name
+ msg: message to log
+ Returns:
+ Tensor, but verified
+ """
+ with ops.name_scope(name):
+ with ops.device(t.device or ops.get_default_graph().get_default_device()):
+ verify_input = array_ops.check_numerics(t, message=msg)
+ out = control_flow_ops.with_dependencies([verify_input], t)
+ return out
+
+
+@ops.RegisterGradient("Relu")
+def _ReluGrad(op, grad):
+ t = _VerifyTensor(op.inputs[0], op.name, "ReluGrad input is not finite.")
+ return gen_nn_ops._relu_grad(grad, t)
+
+
+@ops.RegisterGradient("Relu6")
+def _Relu6Grad(op, grad):
+ return gen_nn_ops._relu6_grad(grad, op.inputs[0])
+
+
+@ops.RegisterGradient("Softplus")
+def _SoftplusGrad(op, grad):
+ return gen_nn_ops._softplus_grad(grad, op.inputs[0])
+
+
+@ops.RegisterGradient("ReluGrad")
+def _ReluGradGrad(op, grad):
+ x = op.inputs[1]
+ return (gen_nn_ops._relu_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
+def _BroadcastMul(vec, mat):
+ """Multiply after broadcasting vec to match dimensions of mat.
+
+ Args:
+ vec: A 1-D tensor of dimension [D0]
+ mat: A 2-D tensor of dimension [D0, D1]
+
+ Returns:
+ A tensor of dimension [D0, D1], the result of vec * mat
+ """
+ # Reshape vec to [D0, 1]
+ vec = array_ops.expand_dims(vec, -1)
+ return vec * mat
+
+
+@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
+def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
+ # grad_0 is the backprop for cost, and we multiply it with the gradients
+ # (which is output[1])
+ # There is no gradient for the labels
+ return _BroadcastMul(grad_0, op.outputs[1]), None
+
+
+@ops.RegisterGradient("Conv2D")
+def _Conv2DGrad(op, grad):
+ return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]),
+ op.inputs[1],
+ grad,
+ op.get_attr("strides"),
+ op.get_attr("padding")),
+ nn_ops.conv2d_backprop_filter(op.inputs[0],
+ array_ops.shape(op.inputs[1]),
+ grad,
+ op.get_attr("strides"),
+ op.get_attr("padding"))]
+
+
+@ops.RegisterGradient("LRN")
+def _LRNGrad(op, grad):
+ depth_radius = op.get_attr("depth_radius")
+ bias = op.get_attr("bias")
+ alpha = op.get_attr("alpha")
+ beta = op.get_attr("beta")
+ return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0],
+ depth_radius, bias, alpha, beta)]
+
+
+@ops.RegisterGradient("AvgPool")
+def _AvgPoolGrad(op, grad):
+ return gen_nn_ops._avg_pool_grad(array_ops.shape(op.inputs[0]), grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"))
+
+
+@ops.RegisterGradient("MaxPool")
+def _MaxPoolGrad(op, grad):
+ return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
+
+
+@ops.RegisterGradient("BatchNormWithGlobalNormalization")
+def _BatchNormWithGlobalNormalizationGrad(op, grad):
+ """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.
+
+ We do not backprop anything for the mean and var intentionally as they are
+ not being trained with backprop in the operation.
+
+ Args:
+ op: The BatchNormOp for which we need to generate gradients.
+ grad: Tensor. The gradients passed to the BatchNormOp.
+
+ Returns:
+ dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
+ dm: Backprop for mean, which is
+ sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
+ dv: Backprop for variance, which is
+ sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
+ db: Backprop for beta, which is grad reduced in all except the
+ last dimension.
+ dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
+ """
+ dx, dm, dv, db, dg = gen_nn_ops._batch_norm_with_global_normalization_grad(
+ op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
+ op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
+ return dx, dm, dv, db, dg
+
+
+@ops.RegisterGradient("L2Loss")
+def _L2LossGrad(op, grad):
+ """Return the gradients for L2Loss.
+
+ Args:
+ op: The L2LossOp for which we need to generate gradients.
+ grad: Tensor containing a single number.
+
+ Returns:
+ The gradient, which is (x * grad).
+ """
+ return op.inputs[0] * grad
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
new file mode 100644
index 0000000000..0ffe95de2b
--- /dev/null
+++ b/tensorflow/python/ops/nn_ops.py
@@ -0,0 +1,365 @@
+"""Wrappers for primitive Neural Net (NN) Operations."""
+
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_nn_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_nn_ops import *
+
+
+# Aliases for some automatically-generated names.
+local_response_normalization = gen_nn_ops.lrn
+
+
+def deconv2d(value, filter, output_shape, strides, padding="SAME",
+ name=None):
+ """The transpose of `conv2d`.
+
+ This used to be called "deconvolution", but it is actually the transpose
+ (gradient) of `conv2d`, not an actual deconvolution.
+
+ Args:
+ value: A 4-D `Tensor` of type `float` and shape
+ `[batch, height, width, in_channels]`.
+ filter: A 4-D `Tensor` with the same type as `value` and shape
+ `[height, width, output_channels, in_channels]`. `filter`'s
+ `in_channels` dimension must match that of `value`.
+ output_shape: A 1-D `Tensor` representing the output shape of the
+ deconvolution op.
+ strides: A list of ints. The stride of the sliding window for each
+ dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the returned tensor.
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+
+ Raises:
+ ValueError: If input/output depth does not match `filter`'s shape, or if
+ padding is other than `'VALID'` or `'SAME'`.
+ """
+ with ops.op_scope([value, filter, output_shape], name, "DeConv2D") as name:
+ value = ops.convert_to_tensor(value, name="value")
+ filter = ops.convert_to_tensor(filter, name="filter")
+ if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]):
+ raise ValueError(
+ "input channels does not match filter's input channels, "
+ "{} != {}".format(value.get_shape()[3], filter.get_shape()[3]))
+
+ output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
+ if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)):
+ raise ValueError("output_shape must have shape (4,), got {}"
+ .format(output_shape_.get_shape()))
+
+ if isinstance(output_shape, (list, np.ndarray)):
+ # output_shape's shape should be == [4] if reached this point.
+ if not filter.get_shape()[2].is_compatible_with(output_shape[3]):
+ raise ValueError(
+ "output_shape does not match filter's output channels, "
+ "{} != {}".format(output_shape[3], filter.get_shape()[2]))
+
+ if padding != "VALID" and padding != "SAME":
+ raise ValueError("padding must be either VALID or SAME:"
+ " {}".format(padding))
+
+ return gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_,
+ filter=filter,
+ out_backprop=value,
+ strides=strides,
+ padding=padding,
+ name=name)
+
+# pylint: disable=protected-access
+def bias_add(value, bias, name=None):
+ """Adds `bias` to `value`.
+
+ This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
+ Broadcasting is supported, so `value` may have any number of dimensions.
+ Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
+ case where both types are quantized.
+
+ Args:
+ value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
+ `int16`, `int8`, or `complex64`.
+ bias: A 1-D `Tensor` with size matching the last dimension of `value`.
+ Must be the same type as `value` unless `value` is a quantized type,
+ in which case a different quantized type may be used.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+ """
+ with ops.op_scope([value, bias], name, "BiasAdd") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
+ return gen_nn_ops._bias_add(value, bias, name=name)
+
+
+ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
+
+
+
+def relu6(features, name=None):
+ """Computes Rectified Linear 6: `min(max(features, 0), 6)`.
+
+ Args:
+ features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
+ `int16`, or `int8`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` with the same type as `features`.
+ """
+ with ops.op_scope([features], name, "Relu6") as name:
+ features = ops.convert_to_tensor(features, name="features")
+ return gen_nn_ops._relu6(features, name=name)
+
+
+def softmax_cross_entropy_with_logits(logits, labels, name=None):
+ """Computes softmax cross entropy between `logits` and `labels`.
+
+ Measures the probability error in discrete classification tasks in which the
+ classes are mutually exclusive (each entry is in exactly one class). For
+ example, each CIFAR-10 image is labeled with one and only one label: an image
+ can be a dog or a truck, but not both.
+
+ **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+ on `logits` internally for efficiency. Do not call this op with the
+ output of `softmax`, as it will produce incorrect results.
+
+ `logits` and `labels` must have the same shape `[batch_size, num_classes]`
+ and the same dtype (either `float32` or `float64`).
+
+ Args:
+ logits: Unscaled log probabilities.
+ labels: Each row `labels[i]` must be a valid probability distribution.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+ softmax cross entropy loss.
+ """
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+ logits, labels, name=name)
+ return cost
+
+
+@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
+def _SoftmaxCrossEntropyWithLogitsShape(op):
+ """Shape function for SoftmaxCrossEntropyWithLogits op."""
+ logits_shape = op.inputs[0].get_shape()
+ labels_shape = op.inputs[1].get_shape()
+ input_shape = logits_shape.merge_with(labels_shape).with_rank(2)
+ batch_size = input_shape[0]
+ return [tensor_shape.vector(batch_size.value), input_shape]
+
+
+def avg_pool(value, ksize, strides, padding, name=None):
+ """Performs the average pooling on the input.
+
+ Each entry in `output` is the mean of the corresponding size `ksize`
+ window in `value`.
+
+ Args:
+ value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
+ `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
+ ksize: A list of ints that has length >= 4.
+ The size of the window for each dimension of the input tensor.
+ strides: A list of ints that has length >= 4.
+ The stride of the sliding window for each dimension of the
+ input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` with the same type as `value`. The average pooled output tensor.
+ """
+ with ops.op_scope([value], name, "AvgPool") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ return gen_nn_ops._avg_pool(value, ksize=ksize, strides=strides,
+ padding=padding,
+ name=name)
+
+
+def max_pool(value, ksize, strides, padding, name=None):
+ """Performs the max pooling on the input.
+
+ Args:
+ value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and
+ type `float32`, `float64`, `qint8`, `quint8`, `qint32`.
+ ksize: A list of ints that has length >= 4. The size of the window for
+ each dimension of the input tensor.
+ strides: A list of ints that has length >= 4. The stride of the sliding
+ window for each dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` with the same type as `value`. The max pooled output tensor.
+ """
+ with ops.op_scope([value], name, "MaxPool") as name:
+ value = ops.convert_to_tensor(value, name="input")
+ return gen_nn_ops._max_pool(value, ksize=ksize, strides=strides,
+ padding=padding,
+ name=name)
+
+
+ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
+ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
+ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("ReluGrad")
+@ops.RegisterShape("Relu6Grad")
+@ops.RegisterShape("SoftplusGrad")
+def _BinaryElementwiseShape(op):
+ """Returns same shape as both inputs to op.
+
+ Args:
+ op: Input operation.
+
+ Returns:
+ Shape of both inputs to `op`.
+ """
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+
+
+ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
+
+
+ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
+
+
+@ops.RegisterShape("LRNGrad")
+def _LRNGradShape(op):
+ """Shape function for LRNGrad op."""
+ in_grads_shape = op.inputs[0].get_shape().with_rank(4)
+ in_image_shape = op.inputs[1].get_shape().with_rank(4)
+ out_image_shape = op.inputs[2].get_shape().with_rank(4)
+ return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
+
+
+ops.RegisterShape("Softmax")(
+ common_shapes.unchanged_shape_with_rank(2))
+
+
+@ops.RegisterShape("InTopK")
+def _InTopKShape(op):
+ """Shape function for InTopK op."""
+ predictions_shape = op.inputs[0].get_shape().with_rank(2)
+ targets_shape = op.inputs[1].get_shape().with_rank(1)
+ batch_size = predictions_shape[0].merge_with(targets_shape[0])
+ return [tensor_shape.vector(batch_size.value)]
+
+
+@ops.RegisterShape("TopK")
+def _TopKShape(op):
+ """Shape function for TopK op."""
+ input_shape = op.inputs[0].get_shape().with_rank(2)
+ k = op.get_attr("k")
+ num_rows = input_shape[0]
+ num_cols = input_shape[1]
+ if num_cols.value is not None and num_cols.value < k:
+ raise ValueError("input must have at least k (%d) columns" % k)
+ return [tensor_shape.TensorShape([num_rows, k]),
+ tensor_shape.TensorShape([num_rows, k])]
+
+
+@ops.RegisterShape("BatchNormWithGlobalNormalization")
+def _BatchNormShape(op):
+ """Shape function for BatchNormWithGlobalNormalization op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ mean_shape = op.inputs[1].get_shape().with_rank(1)
+ var_shape = op.inputs[2].get_shape().with_rank(1)
+ beta_shape = op.inputs[3].get_shape().with_rank(1)
+ gamma_shape = op.inputs[4].get_shape().with_rank(1)
+ mean_shape[0].merge_with(input_shape[3])
+ var_shape[0].merge_with(input_shape[3])
+ beta_shape[0].merge_with(input_shape[3])
+ gamma_shape[0].merge_with(input_shape[3])
+ return [input_shape]
+
+
+@ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")
+def _BatchNormGradShape(op):
+ """Shape function for BatchNormWithGlobalNormalizationGrad op."""
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ mean_shape = op.inputs[1].get_shape().with_rank(1)
+ var_shape = op.inputs[2].get_shape().with_rank(1)
+ beta_shape = op.inputs[3].get_shape().with_rank(1)
+ out_backprop_shape = op.inputs[4].get_shape().with_rank(4)
+ input_shape = input_shape.merge_with(out_backprop_shape)
+ vector_dim = input_shape[3]
+ vector_dim = vector_dim.merge_with(mean_shape[0])
+ vector_dim = vector_dim.merge_with(var_shape[0])
+ vector_dim = vector_dim.merge_with(beta_shape[0])
+ return [input_shape] + ([tensor_shape.vector(vector_dim)] * 4)
+
+
+ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape)
+ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
+ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
+
+
+@ops.RegisterShape("MaxPoolWithArgmax")
+def _MaxPoolWithArgMaxShape(op):
+ """Shape function for MaxPoolWithArgmax op."""
+ return common_shapes.max_pool_shape(op) * 2
+
+
+@ops.RegisterShape("AvgPoolGrad")
+def _AvgPoolGradShape(op):
+ """Shape function for the AvgPoolGrad op."""
+ orig_input_shape = tensor_util.ConstantValue(op.inputs[0])
+ if orig_input_shape is not None:
+ return [tensor_shape.TensorShape(orig_input_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know orig_input_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("Conv2DBackpropFilter")
+def _Conv2DBackpropFilterShape(op):
+ """Shape function for the Conv2DBackpropFilter op."""
+ filter_shape = tensor_util.ConstantValue(op.inputs[1])
+ if filter_shape is not None:
+ return [tensor_shape.TensorShape(filter_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know filter_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("Conv2DBackpropInput")
+def _Conv2DBackpropInputShape(op):
+ """Shape function for the Conv2DBackpropInput op."""
+ input_shape = tensor_util.ConstantValue(op.inputs[0])
+ if input_shape is not None:
+ return [tensor_shape.TensorShape(input_shape.tolist())]
+ else:
+ # NOTE(mrry): We could in principle work out the shape from the
+ # gradients and the attrs, but if we do not know input_shape
+ # statically, then we are unlikely to know the shape of the
+ # gradients either.
+ return [tensor_shape.unknown_shape(ndims=4)]
+
+
+@ops.RegisterShape("MaxPoolGrad")
+@ops.RegisterShape("MaxPoolGradWithArgmax")
+def _MaxPoolGradShape(op):
+ """Shape function for the MaxPoolGrad op."""
+ orig_input_shape = op.inputs[0].get_shape().with_rank(4)
+ return [orig_input_shape]
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
new file mode 100644
index 0000000000..11ce56e359
--- /dev/null
+++ b/tensorflow/python/ops/nn_test.py
@@ -0,0 +1,882 @@
+"""Tests for tensorflow.ops.nn."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.kernel_tests import gradient_checker as gc
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_grad
+from tensorflow.python.platform import googletest
+
+exp = math.exp
+log = math.log
+
+
+class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
+
+ def _SigmoidCrossEntropyWithLogits(self, logits, targets):
+ assert len(logits) == len(targets)
+ pred = [1 / (1 + exp(-x)) for x in logits]
+ eps = 0.0001
+ pred = [min(max(p, eps), 1 - eps) for p in pred]
+ return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]
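+    # Worked example (comment only): for logit x = 0 and target z = 1,
+    # sigmoid(0) = 0.5, so the loss is -1 * log(0.5) - 0 * log(0.5) ~= 0.693.
+    # The eps clamp matters for saturated logits, e.g. x = 100 with z = 0,
+    # where pred rounds to 1.0 and log(1 - pred) would otherwise be -inf.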
+
+ def _Inputs(self, x=None, y=None, dtype=types.float64, sizes=None):
+ x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
+ y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
+ assert len(x) == len(y)
+ sizes = sizes if sizes else [len(x)]
+ logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
+ targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
+ losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
+ return logits, targets, losses
+
+ def testConstructionNamed(self):
+ with self.test_session():
+ logits, targets, _ = self._Inputs()
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets,
+ name="mylogistic")
+ self.assertEqual("mylogistic", loss.op.name)
+
+ def testLogisticOutput(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=types.float32)
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
+
+ def testLogisticOutputMultiDim(self):
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ logits, targets, losses = self._Inputs(dtype=types.float32,
+ sizes=[2, 2, 2])
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ np_loss = np.array(losses).astype(np.float32)
+ tf_loss = loss.eval()
+ self.assertAllClose(np_loss, tf_loss, atol=0.001)
+
+ def testGradient(self):
+ sizes = [4, 2]
+ with self.test_session():
+ logits, targets, _ = self._Inputs(sizes=sizes)
+ loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ err = gc.ComputeGradientError(logits, sizes, loss, sizes)
+ print "logistic loss gradient err = ", err
+ self.assertLess(err, 1e-7)
+
+
+class ZeroFractionTest(test_util.TensorFlowTestCase):
+
+ def _ZeroFraction(self, x):
+ assert x.shape
+ total_elements = float(np.prod(x.shape))
+ nonzeros = float(np.count_nonzero(x.flatten()))
+ return 1.0 - (nonzeros / total_elements)
+
+ def testZeroFraction(self):
+ x_shape = [5, 17]
+ x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
+ y_np = self._ZeroFraction(x_np)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ x_tf.set_shape(x_shape)
+ y_tf = nn.zero_fraction(x_tf)
+ y_tf_np = y_tf.eval()
+ eps = 1e-8
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ def testZeroFractionEmpty(self):
+ with self.test_session():
+ x = np.zeros(0)
+ y = nn.zero_fraction(x).eval()
+ self.assertTrue(np.isnan(y))
+
+
+class SoftmaxTest(test_util.TensorFlowTestCase):
+
+ def _softmax(self, x):
+ assert len(x.shape) == 2
+ m = x.max(1)[:, np.newaxis]
+ u = np.exp(x - m)
+ z = u.sum(1)[:, np.newaxis]
+ return u / z
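+    # Note (sketch): subtracting the per-row max before exponentiating does
+    # not change the result, since exp(x - m) / sum(exp(x - m)) equals
+    # exp(x) / sum(exp(x)); it only keeps each intermediate exp() bounded by 1
+    # so large logits do not overflow.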
+
+ def testSoftmax(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ y_np = self._softmax(x_np)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn.softmax(x_tf)
+ y_tf_np = y_tf.eval()
+ eps = 1e-3
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ def testGradient(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float64)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn.softmax(x_tf)
+ err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape)
+ eps = 1e-8
+ self.assertLess(err, eps)
+
+
+class DeConv2DTest(test_util.TensorFlowTestCase):
+
+ def testDeConv2DSingleStride(self):
+ with self.test_session():
+ strides = [1, 1, 1, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 6, 4, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ value = output.eval()
+
+ # We count the number of cells being added at the locations in the output.
+ # At the center, #cells=kernel_height * kernel_width
+ # At the corners, #cells=ceil(kernel_height/2) * ceil(kernel_width/2)
+ # At the borders, #cells=ceil(kernel_height/2)*kernel_width or
+ # kernel_height * ceil(kernel_width/2)
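+      # Concretely (all-ones input of depth 3, all-ones 3x3 filter): corners
+      # sum 4 cells * 3.0 = 12.0, borders sum 6 cells * 3.0 = 18.0, and the
+      # interior sums 9 cells * 3.0 = 27.0, which is what the target
+      # accumulation below reproduces.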
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(y_shape[2]):
+ for h in xrange(y_shape[1]):
+ target = 4 * 3.0
+ h_in = h > 0 and h < y_shape[1] - 1
+ w_in = w > 0 and w < y_shape[2] - 1
+ if h_in and w_in:
+ target += 5 * 3.0
+ elif h_in or w_in:
+ target += 2 * 3.0
+ self.assertAllClose(target, value[n, h, w, k])
+
+ def testDeConv2DSame(self):
+ with self.test_session():
+ strides = [1, 2, 2, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 12, 8, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ value = output.eval()
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(y_shape[2]):
+ for h in xrange(y_shape[1]):
+ target = 3.0
+ # We add a case for locations divisible by the stride.
+ h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
+ w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
+ if h_in and w_in:
+ target += 9.0
+ elif h_in or w_in:
+ target += 3.0
+ self.assertAllClose(target, value[n, h, w, k])
+
+ def testDeConv2DValid(self):
+ with self.test_session():
+ strides = [1, 2, 2, 1]
+
+ # Input, output: [batch, height, width, depth]
+ x_shape = [2, 6, 4, 3]
+ y_shape = [2, 13, 9, 2]
+
+ # Filter: [kernel_height, kernel_width, output_depth, input_depth]
+ f_shape = [3, 3, 2, 3]
+
+ x = constant_op.constant(1.0, shape=x_shape, name="x",
+ dtype=types.float32)
+ f = constant_op.constant(1.0, shape=f_shape, name="filter",
+ dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID")
+ value = output.eval()
+
+ cache_values = np.zeros(y_shape, dtype=np.float32)
+
+ # The amount of padding added
+ pad = 1
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[2]):
+ for w in xrange(pad, y_shape[2] - pad):
+ for h in xrange(pad, y_shape[1] - pad):
+ target = 3.0
+ # We add a case for locations divisible by the stride.
+ h_in = h % strides[
+ 1] == 0 and h > pad and h < y_shape[1] - 1 - pad
+ w_in = w % strides[
+ 2] == 0 and w > pad and w < y_shape[2] - 1 - pad
+ if h_in and w_in:
+ target += 9.0
+ elif h_in or w_in:
+ target += 3.0
+ cache_values[n, h, w, k] = target
+
+ # copy values in the border
+ cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
+ cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
+ cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
+ cache_values[n, -1, :, k] = cache_values[n, -2, :, k]
+
+ self.assertAllClose(cache_values, value)
+
+ def testGradient(self):
+ x_shape = [2, 6, 4, 3]
+ f_shape = [3, 3, 2, 3]
+ y_shape = [2, 12, 8, 2]
+ strides = [1, 2, 2, 1]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ f_val = np.random.random_sample(f_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x", dtype=types.float32)
+ f = constant_op.constant(f_val, name="f", dtype=types.float32)
+ output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
+ err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape)
+ print "DeConv gradient err = %g " % err
+ err_tolerance = 0.0005
+ self.assertLess(err, err_tolerance)
+
+
+class L2LossTest(test_util.TensorFlowTestCase):
+
+ def testL2Loss(self):
+ with self.test_session():
+ x = constant_op.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x")
+ l2loss = nn.l2_loss(x)
+ value = l2loss.eval()
+ self.assertAllClose(7.0, value)
+
+ def testGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x")
+ output = nn.l2_loss(x)
+ err = gc.ComputeGradientError(x, x_shape, output, [1])
+ print "L2Loss gradient err = %g " % err
+ err_tolerance = 1e-11
+ self.assertLess(err, err_tolerance)
+
+
+class L2NormalizeTest(test_util.TensorFlowTestCase):
+
+ def _l2Normalize(self, x, dim):
+ norm = np.apply_along_axis(np.linalg.norm, dim, x)
+ return x / np.expand_dims(norm, dim)
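+    # Quick example (comment only): normalizing [3.0, 4.0] along its only
+    # axis divides by the L2 norm 5.0 and gives [0.6, 0.8].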
+
+ def testL2Normalize(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float32)
+ for dim in range(len(x_shape)):
+ y_np = self._l2Normalize(x_np, dim)
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn.l2_normalize(x_tf, dim)
+ self.assertAllClose(y_np, y_tf.eval())
+
+ def testL2NormalizeGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float64)
+ for dim in range(len(x_shape)):
+ with self.test_session():
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn.l2_normalize(x_tf, dim)
+ err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape)
+ print "L2Normalize gradient err = %g " % err
+ self.assertLess(err, 1e-4)
+
+
+class DropoutTest(test_util.TensorFlowTestCase):
+
+ def testDropout(self):
+    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
+    # validates that it is producing approximately the right number of ones
+    # over a large number of samples, based on the keep probability.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print rel_error
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropout(self):
+    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
+    # validates that it is producing approximately the right number of ones
+    # over a large number of samples, based on the keep probability. This time
+    # the noise is shaped.
+ x_dim = 40 * 30
+ y_dim = 3
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print rel_error
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropoutCorrelation(self):
+ # Runs a shaped dropout and tests that the correlations are correct.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+          # Verifies that each y column has only one type of activation.
+ for i in xrange(x_dim):
+ sorted_value = np.unique(np.sort(value[i, :]))
+ self.assertEqual(sorted_value.size, 1)
+
+ def testShapedDropoutShapeError(self):
+ # Runs shaped dropout and verifies an error is thrown on misshapen noise.
+ x_dim = 40
+ y_dim = 30
+ keep_prob = 0.5
+ with self.test_session():
+ t = constant_op.constant(1.0,
+ shape=[x_dim, y_dim],
+ dtype=types.float32)
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+ with self.assertRaises(ValueError):
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
+ # test that broadcasting proceeds
+ _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
+ _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
+
+
+class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
+
+ def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
+ scale_after_normalization):
+ y = (x - m) / np.sqrt(v + epsilon)
+ y = y * gamma if scale_after_normalization else y
+ y += beta
+ return y
+
+ def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
+ scale_after_normalization):
+ y = (x - m) * math_ops.rsqrt(v + epsilon)
+ if scale_after_normalization:
+ y = gamma * y
+ y += beta
+ return y
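+    # Quick numeric check (comment only): with x = 1, m = 0, v = 1, beta = 0,
+    # gamma = 2 and epsilon = 0, both helpers compute
+    # (1 - 0) / sqrt(1 + 0) * 2 + 0 = 2 when scale_after_normalization is
+    # True, and 1 when it is False.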
+
+ def testBatchNorm(self):
+ x_shape = [3, 5, 4, 2]
+ param_shape = [2]
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ m_val = np.random.random_sample(param_shape).astype(np.float32)
+ v_val = np.random.random_sample(param_shape).astype(np.float32)
+ beta_val = np.random.random_sample(param_shape).astype(np.float32)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ epsilon = 0.001
+ for scale_after_normalization in [True, False]:
+ bn = nn.batch_norm_with_global_normalization(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ on = self._opsBatchNorm(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ np_batch_norm = self._npBatchNorm(
+ x_val, m_val, v_val, beta_val, gamma_val, epsilon,
+ scale_after_normalization)
+ tf_batch_norm, ops_batch_norm = sess.run([bn, on])
+ self.assertAllClose(np_batch_norm, tf_batch_norm, atol=0.000001)
+ self.assertAllClose(np_batch_norm, ops_batch_norm, atol=0.000001)
+ self.assertAllClose(tf_batch_norm, ops_batch_norm, atol=0.000001)
+
+ def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
+ err_tolerance=1e-11):
+ x_shape = [3, 5, 4, 5]
+ param_shape = [5]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ m_val = np.random.random_sample(param_shape).astype(np.float64)
+ v_val = np.random.random_sample(param_shape).astype(np.float64)
+ beta_val = np.random.random_sample(param_shape).astype(np.float64)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float64)
+ with self.test_session():
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ epsilon = 0.001
+ # If scale_after_normalization is False, backprop for gamma
+ # will be 0. gamma is unchanged.
+ output = nn.batch_norm_with_global_normalization(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ all_params = [x, m, v, beta, gamma]
+ all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
+ err = gc.ComputeGradientError(all_params[param_index],
+ all_shapes[param_index], output, x_shape)
+ print "Batch normalization %s gradient %s scale err = " % (
+ tag, "with" if scale_after_normalization else "without"
+ ), err
+ self.assertLess(err, err_tolerance)
+
+ def testBatchNormInputGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(0, "x", scale_after_normalization)
+
+ def testBatchNormMeanGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(1, "mean", scale_after_normalization)
+
+ def testBatchNormVarianceGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(2, "variance", scale_after_normalization,
+ err_tolerance=1e-03)
+
+ def testBatchNormBetaGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(3, "beta", scale_after_normalization)
+
+ def testBatchNormGammaGradient(self):
+ for scale_after_normalization in [True, False]:
+ self._testBatchNormGradient(4, "gamma", scale_after_normalization)
+
+ def testBatchNormGradImpl(self):
+ x_shape = [7, 5, 4, 6]
+ param_shape = [6]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ m_val = np.random.random_sample(param_shape).astype(np.float32)
+ v_val = np.random.random_sample(param_shape).astype(np.float32)
+ beta_val = np.random.random_sample(param_shape).astype(np.float32)
+ gamma_val = np.random.random_sample(param_shape).astype(np.float32)
+ backprop_val = np.random.random_sample(x_shape).astype(np.float32)
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ x = constant_op.constant(x_val, name="x")
+ m = constant_op.constant(m_val, name="m")
+ v = constant_op.constant(v_val, name="v")
+ beta = constant_op.constant(beta_val, name="beta")
+ gamma = constant_op.constant(gamma_val, name="gamma")
+ backprop = constant_op.constant(backprop_val, name="backprop")
+ epsilon = 0.001
+ for scale_after_normalization in [True, False]:
+ dx, dm, dv, db, dg = (
+ gen_nn_ops._batch_norm_with_global_normalization_grad(
+ x, m, v, gamma, backprop, epsilon, scale_after_normalization))
+ on = self._opsBatchNorm(
+ x, m, v, beta, gamma, epsilon, scale_after_normalization)
+ odx, odm, odv, odb, odg = gradients.gradients(
+ [on], [x, m, v, beta, gamma], [backprop])
+ if scale_after_normalization:
+ all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
+ to_check = ["dx", "dm", "dv", "db", "dg"]
+ else:
+ all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
+ to_check = ["dx", "dm", "dv", "db"]
+ for i, n in enumerate(to_check):
+ print n
+ self.assertAllClose(
+ all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
+
+
+class MomentsTest(test_util.TensorFlowTestCase):
+
+ def RunMomentTest(self, shape, global_norm):
+ with self.test_session():
+ # shape = [batch, width, height, depth]
+ assert len(shape) == 4
+
+ x_numpy = np.random.normal(size=shape).astype(np.float32)
+ x = constant_op.constant(x_numpy)
+ x.set_shape(shape)
+ axes = [0, 1, 2] if global_norm else [0]
+ mean, var = nn.moments(x, axes)
+
+ num_elements = np.prod([shape[i] for i in axes])
+
+ ax = (0, 1, 2) if global_norm else (0)
+ expected_mean = np.sum(x_numpy, axis=ax) / num_elements
+ expected_mean_squared = np.multiply(expected_mean, expected_mean)
+ expected_x_squared = np.sum(
+ np.multiply(x_numpy, x_numpy), axis=ax) / num_elements
+ expected_variance = expected_x_squared - expected_mean_squared
+
+ # Check that the moments are correct.
+ self.assertAllClose(expected_mean, mean.eval())
+ self.assertAllClose(expected_variance, var.eval())
+
+ def testBasic(self):
+ self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=False)
+
+ def testGlobalNormalization(self):
+ self.RunMomentTest(shape=[2, 3, 5, 4], global_norm=True)
+
+ def _testGlobalGradient(self, from_y="mean"):
+ with self.test_session():
+ x_shape = [3, 5, 4, 2]
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ x = constant_op.constant(x_val)
+ x.set_shape(x_shape)
+
+ axes = [0, 1, 2]
+ y_shape = [2] # Depth of x
+ out_mean, out_var = nn.moments(x, axes)
+ if from_y == "mean":
+ y = out_mean
+ elif from_y == "var":
+ y = out_var
+ err = gc.ComputeGradientError(x, x_shape, y, y_shape)
+ print "Moments %s gradient err = %g" % (from_y, err)
+ self.assertLess(err, 1e-11)
+
+ def testMeanGlobalGradient(self):
+ self._testGlobalGradient(from_y="mean")
+
+ def testVarGlobalGradient(self):
+ self._testGlobalGradient(from_y="var")
+
+
+class ComputeSampledLogitsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._num_classes = 5
+ self._dim = 10
+ self._batch_size = 3
+
+ def _GenerateTestInputs(self):
+ np.random.seed(0)
+ weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
+ biases = np.random.randn(self._num_classes).astype(np.float32)
+ hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
+ np.float32)
+
+ return weights, biases, hidden_acts
+
+ def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
+ hidden_acts,
+ num_true=1,
+ true_expected=None,
+ sampled_expected=None):
+
+ batch_size, dim = hidden_acts.shape
+ true_logits = np.sum(
+ hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
+ (batch_size, num_true, dim)),
+ axis=2)
+ true_b = true_b.reshape((batch_size, num_true))
+ true_logits += true_b
+ sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b
+
+ if true_expected is not None:
+ true_logits -= np.log(true_expected)
+ if sampled_expected is not None:
+ sampled_logits -= np.log(sampled_expected[np.newaxis, :])
+
+ out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
+ out_labels = np.hstack((np.ones_like(true_logits) / num_true,
+ np.zeros_like(sampled_logits)))
+
+ return out_logits, out_labels
+
+ def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
+ num_sampled, num_classes, num_true, sampled_vals,
+ subtract_log_q, remove_accidental_hits,
+ name="sampled_loss_TF"):
+ # Should be called from within a `with test_session():` block
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ hidden_acts_tf = constant_op.constant(hidden_acts,
+ shape=(self._batch_size, self._dim))
+ labels_tf = constant_op.constant(labels, dtype=types.int64,
+ shape=(self._batch_size, num_true))
+
+ pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits(
+ weights_tf, biases_tf, hidden_acts_tf, labels_tf, num_sampled,
+ num_classes, num_true, sampled_vals,
+ subtract_log_q=subtract_log_q,
+ remove_accidental_hits=remove_accidental_hits,
+ name=name)
+ return pred_logits_tf, pred_labels_tf
+
+ def testComputeSampledLogitsShapes(self):
+ # We just check that the shapes of the returned values are correct.
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = sampled_exp = [1., 1., 1., 1.]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+
+ with self.test_session() as sess:
+ for num_true_test in range(1, 5):
+ labels = np.random.randint(low=0, high=self._num_classes,
+ size=self._batch_size * num_true_test)
+ true_w, true_b = weights[labels], biases[labels]
+
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test)
+
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ remove_accidental_hits=True,
+ subtract_log_q=False)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertEqual(logits_np.shape, logits_tf_val.shape)
+ self.assertEqual(labels_np.shape, labels_tf_val.shape)
+
+ def testComputeSampledLogitsValues(self):
+ # Here we check the actual numerics.
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ eps = 1e-3
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
+ true_exp.fill(0.5)
+ sampled_exp = np.empty([num_sampled], dtype=np.float32)
+ sampled_exp.fill(0.5)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session() as sess:
+ for num_true_test in range(1, 5):
+ # Generate test data for this run
+ labels = np.random.randint(low=0, high=self._num_classes,
+ size=self._batch_size * num_true_test)
+ true_w, true_b = weights[labels], biases[labels]
+
+ # Test 1: Without accidental hit removal or subtract_log_q
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test)
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ name="sampled_loss_test1_num_true%d" % num_true_test)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertAllClose(logits_np, logits_tf_val, eps)
+ self.assertAllClose(labels_np, labels_tf_val, eps)
+
+ # Test 2: With accidental hit removal, no subtract_log_q
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=True,
+ name="sampled_loss_test2_num_true%d" % num_true_test)
+
+ # Test that the exponentiated logits of accidental hits are near 0.
+ # First we need to find the hits in this random test run:
+ labels_reshape = labels.reshape((self._batch_size, num_true_test))
+ logits_tf_np = logits_tf.eval()
+ for row in xrange(self._batch_size):
+ row_labels = labels_reshape[row, :]
+ for col in xrange(num_sampled):
+ if sampled[col] in row_labels:
+ # We need to add the num_true_test offset into logits_*
+ self.assertNear(
+ np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)
+
+ # Test 3: With subtract_log_q, no accidental hit removal
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ num_true=num_true_test,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ logits_tf, labels_tf = self._ComputeSampledLogitsTF(
+ weights, biases, hidden_acts, labels, num_sampled,
+ self._num_classes,
+ num_true=num_true_test,
+ sampled_vals=test_sampled_vals,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ name="sampled_loss_test3_num_true%d" % num_true_test)
+
+ logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
+ self.assertAllClose(logits_np, logits_tf_val, eps)
+ self.assertAllClose(labels_np, labels_tf_val, eps)
+
+ def testNCELoss(self):
+ # A simple test to verify the numerics.
+
+ def _SigmoidCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ pred = 1. / (1. + np.exp(-logits))
+ eps = 0.0001
+ pred = np.minimum(np.maximum(pred, eps), 1 - eps)
+ return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)
+
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ labels = [0, 1, 2]
+ true_w, true_b = weights[labels], biases[labels]
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
+ true_exp.fill(0.5)
+ sampled_exp = np.empty([num_sampled], dtype=np.float32)
+ sampled_exp.fill(0.5)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session():
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ nce_loss_np = np.sum(
+ _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)
+
+ labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1))
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ inputs_tf = constant_op.constant(hidden_acts)
+
+ nce_loss_tf = nn.nce_loss(
+ weights_tf, biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals)
+
+ self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)
+
+ def testSampledSoftmaxLoss(self):
+ # A simple test to verify the numerics.
+
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(logits - np.amax(
+ logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ weights, biases, hidden_acts = self._GenerateTestInputs()
+ labels = [0, 1, 2]
+ true_w, true_b = weights[labels], biases[labels]
+ sampled = [1, 0, 2, 3]
+ num_sampled = len(sampled)
+ true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
+ sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ test_sampled_vals = (sampled, true_exp, sampled_exp)
+
+ with self.test_session():
+ logits_np, labels_np = self._ComputeSampledLogitsNP(
+ true_w, true_b, sampled_w, sampled_b, hidden_acts,
+ true_expected=true_exp,
+ sampled_expected=sampled_exp)
+ sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
+ labels_np)
+
+ labels_tf = constant_op.constant(labels, shape=(self._batch_size, 1))
+ weights_tf = constant_op.constant(weights)
+ biases_tf = constant_op.constant(biases)
+ inputs_tf = constant_op.constant(hidden_acts)
+
+ sampled_softmax_loss_tf = nn.sampled_softmax_loss(
+ weights_tf, biases_tf, inputs_tf, labels_tf,
+ num_sampled=1,
+ num_classes=self._num_classes,
+ num_true=1,
+ sampled_values=test_sampled_vals,
+ remove_accidental_hits=False)
+
+ self.assertAllClose(
+ sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
new file mode 100644
index 0000000000..93f5d5db20
--- /dev/null
+++ b/tensorflow/python/ops/numerics.py
@@ -0,0 +1,50 @@
+"""Connects all float and double tensors to CheckNumericsOp."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+def verify_tensor_all_finite(t, msg, name=None):
+ """Assert that the tensor does not contain any NaN's or Inf's.
+
+ Args:
+ t: Tensor to check.
+ msg: Message to log on failure.
+ name: A name for this operation (optional).
+
+ Returns:
+ Same tensor as `t`.
+ """
+ with ops.op_scope([t], name, "VerifyFinite") as name:
+ t = ops.convert_to_tensor(t, name="t")
+ with ops.device(t.device or t.graph.get_default_device()):
+ verify_input = array_ops.check_numerics(t, message=msg)
+ out = control_flow_ops.with_dependencies([verify_input], t)
+ return out
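+# Usage sketch (comment only; `loss` is a hypothetical tensor):
+#
+#   loss = verify_tensor_all_finite(loss, "loss contains NaN or Inf")
+#
+# Anything computed from the returned tensor now depends on the CheckNumerics
+# assertion, so evaluation fails fast if `loss` is not finite.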
+
+
+def add_check_numerics_ops():
+ """Connect a check_numerics to every floating point tensor.
+
+ `check_numerics` operations themselves are added for each `float` or `double`
+ tensor in the graph. For all ops in the graph, the `check_numerics` op for
+ all of its (`float` or `double`) inputs is guaranteed to run before the
+ `check_numerics` op on any of its outputs.
+
+ Returns:
+ A `group` op depending on all `check_numerics` ops added.
+ """
+ check_op = []
+ # This code relies on the ordering of ops in get_operations().
+  # The producer of a tensor always comes before that tensor's consumer in
+  # this list. This is true because get_operations() returns ops in the order
+  # they were added, and an op can only be added once its inputs are added.
+ for op in ops.get_default_graph().get_operations():
+ for output in op.outputs:
+ if output.dtype in [types.float32, types.float64]:
+ message = op.name + ":" + str(output.value_index)
+ with ops.control_dependencies(check_op):
+ check_op = [array_ops.check_numerics(output, message=message)]
+ return control_flow_ops.group(*check_op)
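+
+
+# Usage sketch (comment only; `train_op` and `sess` are illustrative names):
+#
+#   check_op = add_check_numerics_ops()  # after the graph is fully built
+#   sess.run([train_op, check_op])
+#
+# Running the returned group op alongside the training step verifies every
+# float32/float64 tensor produced during that step.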
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
new file mode 100644
index 0000000000..5947b6df89
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library.py
@@ -0,0 +1,640 @@
+"""Class to hold a library of OpDefs and use it to create Brain operations."""
+
+import numbers
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types as types_lib
+from tensorflow.python.ops import constant_op
+from tensorflow.python.platform import logging
+
+
+def _Attr(op_def, name):
+ for attr in op_def.attr:
+ if attr.name == name:
+ return attr
+ raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
+ (op_def.name, name))
+
+
+def _AttrValue(attr_protos, name):
+ if name in attr_protos:
+ return attr_protos[name]
+ raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
+ (name, attr_protos))
+
+
+def _SatisfiesTypeConstraint(dtype, attr_def):
+ if attr_def.HasField("allowed_values"):
+ allowed_list = attr_def.allowed_values.list.type
+ if dtype not in allowed_list:
+ raise TypeError(
+ "DataType %s for attr '%s' not in list of allowed values: %s" %
+ (types_lib.as_dtype(dtype).name, attr_def.name,
+ ", ".join(types_lib.as_dtype(x).name for x in allowed_list)))
+
+
+def _IsListParameter(arg):
+ if arg.number_attr:
+ return True
+ elif arg.type_list_attr:
+ return True
+ return False
+
+
+def _NumTypeFields(arg):
+ num = 0
+ if arg.type != types_pb2.DT_INVALID: num += 1
+ if arg.type_attr: num += 1
+ if arg.type_list_attr: num += 1
+ return num
+
+
+def _IsListValue(v):
+ return isinstance(v, (list, tuple))
+
+
+def _Flatten(l):
+ """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
+ # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
+ l_of_l = [x if _IsListValue(x) else [x] for x in l]
+ # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
+ return [item for sublist in l_of_l for item in sublist]
+
+
+def _Restructure(l, structure):
+ """Returns the elements of list l structured according to the given structure.
+
+ A structure is represented by a list whose elements are either
+ `None` or a non-negative integer. `None` corresponds to a single
+ element in the output list, and an integer N corresponds to a nested
+ list of length N.
+
+ The function returns a data structure whose shape is given by
+ `structure`, and whose elements are taken from `l`. If `structure`
+ is a singleton, the function returns the single data structure
+ implied by the 0th element of `structure`. For example:
+
+ _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
+ -> ["foo", ["bar", "baz"], "qux"]
+
+ _Restructure(["foo"], [None]) -> "foo"
+
+ _Restructure(["foo"], [1]) -> ["foo"]
+
+ _Restructure([], [0]) -> []
+
+ Args:
+ l: A list.
+ structure: A list whose elements are either `None` or a non-negative
+ integer.
+
+ Returns:
+ The elements of `l`, restructured according to `structure`. If
+ `structure` is a list of length 1, this function returns the
+ single data structure implied by `structure[0]`.
+
+ """
+ result = []
+ current_index = 0
+ for element in structure:
+ if element is None:
+ result.append(l[current_index])
+ current_index += 1
+ else:
+ result.append(l[current_index:current_index+element])
+ current_index += element
+
+ if len(result) == 1:
+ return result[0]
+ else:
+ return tuple(result)
+
+
+def _MakeFloat(v, arg_name):
+ if not isinstance(v, numbers.Real):
+ raise TypeError("Expected float for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return float(v)
+
+
+def _MakeInt(v, arg_name):
+ if isinstance(v, basestring):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ try:
+ return int(v)
+ except (ValueError, TypeError):
+ raise TypeError("Expected int for argument '%s' not %s." %
+ (arg_name, repr(v)))
+
+
+def _MakeStr(v, arg_name):
+ if not isinstance(v, basestring):
+ raise TypeError("Expected string for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return str(v) # Convert unicode strings to bytes.
+
+
+def _MakeBool(v, arg_name):
+ if not isinstance(v, bool):
+ raise TypeError("Expected bool for argument '%s' not %s." %
+ (arg_name, repr(v)))
+ return v
+
+
+def _MakeType(v, attr_def):
+ try:
+ v = types_lib.as_dtype(v)
+ except TypeError:
+ raise TypeError("Expected DataType for argument '%s' not %s." %
+ (attr_def.name, repr(v)))
+ i = v.as_datatype_enum
+ _SatisfiesTypeConstraint(i, attr_def)
+ return i
+
+
+def _MakeShape(v, arg_name):
+  """Convert v into a TensorShapeProto.
+
+  Args:
+    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
+    arg_name: String, for error messages.
+
+  Returns:
+    A TensorShapeProto.
+  """
+ if isinstance(v, tensor_shape_pb2.TensorShapeProto):
+ for d in v.dim:
+ if d.name:
+ logging.warning("Warning: TensorShapeProto with a named dimension: %s",
+ str(v))
+ break
+ return v
+ s = tensor_shape.as_shape(v)
+ ret = tensor_shape_pb2.TensorShapeProto()
+ for i in s.as_dimension_list():
+    ret.dim.add(size=i)
+ return ret
+
+
+def _MakeTensor(v, arg_name):
+ """Ensure v is a TensorProto."""
+ if isinstance(v, tensor_pb2.TensorProto):
+ return v
+ raise TypeError(
+ "Don't know how to convert %s to a TensorProto for argument '%s'" %
+ (repr(v), arg_name))
+
+
+class _OpInfo(object):
+ """All per-Op state we would like to precompute/validate."""
+
+ def __init__(self, op_def):
+ self.op_def = op_def
+ # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
+ # here, instead of these checks.
+ for arg in list(op_def.input_arg) + list(op_def.output_arg):
+ num_type_fields = _NumTypeFields(arg)
+ if num_type_fields != 1:
+ raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
+ (arg.name, op_def.name, num_type_fields))
+ if arg.type_attr:
+ attr_type = _Attr(op_def, arg.type_attr).type
+ if attr_type != "type":
+ raise TypeError("Attr '%s' of '%s' used as a type_attr "
+ "but has type %s" %
+ (arg.type_attr, op_def.name, attr_type))
+ if arg.type_list_attr:
+ attr_type = _Attr(op_def, arg.type_list_attr).type
+ if attr_type != "list(type)":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
+ (arg.type_attr, op_def.name, attr_type))
+ if arg.number_attr:
+ attr_type = _Attr(op_def, arg.number_attr).type
+ if attr_type != "int":
+ raise TypeError(
+ "Attr '%s' of '%s' used as a number_attr but has type %s" %
+ (arg.number_attr, op_def.name, attr_type))
+
+
+class OpDefLibrary(object):
+ """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
+
+ def __init__(self):
+ self._ops = {}
+
+ def add_op(self, op_def):
+ """Register an OpDef. May call apply_op with the name afterwards."""
+ if not isinstance(op_def, op_def_pb2.OpDef):
+ raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
+ (op_def, type(op_def)))
+ if not op_def.name:
+ raise ValueError("%s missing name." % op_def)
+ if op_def.name in self._ops:
+ raise RuntimeError("Op name %s registered twice." % op_def.name)
+ self._ops[op_def.name] = _OpInfo(op_def)
+
+ def add_op_list(self, op_list):
+ """Register the OpDefs from an OpList."""
+ if not isinstance(op_list, op_def_pb2.OpList):
+ raise TypeError("%s is %s, not an op_def_pb2.OpList" %
+ (op_list, type(op_list)))
+ for op_def in op_list.op:
+ self.add_op(op_def)
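+
+  # Registration sketch (comment only; the op text below is illustrative):
+  #
+  #   op_def = op_def_pb2.OpDef()
+  #   text_format.Merge("name: 'Simple' "
+  #                     "input_arg { name: 'a' type: DT_INT32 } "
+  #                     "output_arg { name: 'out' type: DT_FLOAT }", op_def)
+  #   library.add_op(op_def)
+  #
+  # where `library` is an OpDefLibrary instance and text_format comes from
+  # google.protobuf.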
+
+ def apply_op(self, op_type_name, g=None, name=None, **keywords):
+ # pylint: disable=g-doc-args
+ """Add a node invoking a registered Op to a graph.
+
+ Config proto extensions must be provided via the 'ext' keyword argument.
+ Example usage:
+ # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
+ # will convert to a Tensor.
+ op_def_library.apply_op("op", input1=input1, input2=input2)
+ # If none of the inputs are Tensors and your session doesn't have a
+ # default graph, you will have to specify the graph.
+ op_def_library.apply_op("op", input1=input1, g=g)
+ # Can specify a node name.
+ op_def_library.apply_op("op", input1=input1, name="node_name")
+ # Must use keyword arguments, with the names specified in the OpDef.
+ op_def_library.apply_op("op", input_name=input, attr_name=attr)
+
+ All attrs must either be inferred from an input or specified.
+ (If inferred, the attr must not be specified.) If an attr has a default
+ value specified in the Op's OpDef, then you may pass None as the value
+ of that attr to get the default.
+
+ Args:
+ op_type_name: string. Must match the name field of a registered Op.
+ g: The graph context (optional)
+ name: string. Optional name of the created op.
+ **keywords: input Tensor and attr arguments specified by name,
+ and optional parameters to pass when constructing the Operation.
+
+ Returns:
+ The Tensor(s) representing the output of the operation, or the Operation
+ itself if there are no outputs.
+
+ Raises:
+ RuntimeError: On some errors.
+ TypeError: On some errors.
+ ValueError: On some errors.
+ """
+ op_info = self._ops.get(op_type_name, None)
+ if op_info is None:
+ raise RuntimeError("Unrecognized Op name " + op_type_name)
+ op_def = op_info.op_def
+
+ # Determine the graph context.
+ try:
+ # Need to flatten all the arguments into a list.
+ # pylint: disable=protected-access
+ g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
+      # pylint: enable=protected-access
+ except AssertionError as e:
+ raise RuntimeError(
+ "Need to specify g=graph to Op '%s' (could not determine graph due "
+ "to: %s)" % (op_type_name, e.message))
+
+ # Default name if not specified.
+ if name is None:
+ name = op_type_name
+
+ # Requires that op_def has passed validation (using the C++
+ # ValidateOpDef() from ../framework/op_def_util.h).
+ attrs = {}
+ inputs = []
+ input_types = []
+ with g.as_default(), ops.name_scope(name) as scope:
+
+ # Perform input type inference
+ inferred_from = {}
+ for input_arg in op_def.input_arg:
+ input_name = input_arg.name
+ if input_name in keywords:
+ values = keywords.pop(input_name)
+ elif input_name + "_" in keywords:
+ # Handle the case where the name is a keyword or built-in
+ # for Python so we use the name + _ instead.
+ input_name += "_"
+ values = keywords.pop(input_name)
+ else:
+ raise TypeError("No argument for input " + input_name)
+
+ # Goals:
+ # * Convert values to Tensors if it contains constants.
+ # * Verify that values is a list if that matches the input_arg's
+ # type.
+ # * If the input_arg's type is determined by attrs, either set
+ # those attrs and validate those attr values are legal (if
+ # they have not yet been set) or validate the input matches
+ # the type indicated by the attrs (if they have already been
+ # inferred via an earlier input).
+ # * If the input_arg has an explicit type, make sure the input
+ # conforms.
+
+ if _IsListParameter(input_arg):
+ if not _IsListValue(values):
+ raise TypeError(
+ "Expected list for '%s' argument to '%s' Op, not %s." %
+ (input_name, op_type_name, values))
+ # In cases where we expect all elements of the list to have the
+ # same dtype, try to cast non-Tensor elements to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.number_attr:
+ if input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+ else:
+ for t in values:
+ if isinstance(t, ops.Tensor):
+ dtype = t.dtype
+ break
+
+ try:
+ values = ops.convert_n_to_tensor_or_indexed_slices(
+ values, name=input_arg.name,
+ dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
+ except (TypeError, ValueError):
+ assert dtype is not None, "Should not fail if dtype is None"
+ assert input_arg.number_attr, "Should be number_attr case"
+ # What types does the conversion function think values have?
+ values = ops.convert_n_to_tensor_or_indexed_slices(values)
+ observed = ", ".join(v.dtype.base_dtype.name for v in values)
+
+ prefix = (
+ "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s that do not match expected type %s." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ elif input_arg.type_attr in attrs:
+ raise TypeError("%s that do not match type %s inferred from "
+ "earlier arguments." %
+ (prefix, types_lib.as_dtype(dtype).name))
+ else:
+ raise TypeError("%s that don't all match." % prefix)
+
+ types = [x.dtype for x in values]
+ inputs.extend(values)
+ else:
+ # In cases where we have an expected type, try to convert non-Tensor
+ # arguments to that type.
+ dtype = None
+ if input_arg.type != types_pb2.DT_INVALID:
+ dtype = input_arg.type
+ elif input_arg.type_attr in attrs:
+ dtype = attrs[input_arg.type_attr]
+
+ try:
+ values = ops.convert_to_tensor(
+ values, name=input_arg.name, dtype=dtype)
+ except ValueError:
+ # What type does convert_to_tensor think it has?
+ observed = ops.convert_to_tensor(values).dtype.name
+ prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
+ (input_name, op_type_name, observed))
+ if input_arg.type != types_pb2.DT_INVALID:
+ raise TypeError("%s expected type of %s." %
+ (prefix, types_lib.as_dtype(input_arg.type).name))
+ else:
+ raise TypeError(
+ "%s type %s of argument '%s'." %
+ (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name,
+ inferred_from[input_arg.type_attr]))
+
+ types = [values.dtype]
+ inputs.append(values)
+ base_types = [x.base_dtype for x in types]
+
+ if input_arg.number_attr:
+ # <number-attr> * <type> or <number-attr> * <type-attr>
+ if input_arg.number_attr in attrs:
+ if len(values) != attrs[input_arg.number_attr]:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d must match "
+ "length %d of argument '%s'." %
+ (input_name, op_type_name, len(values),
+ attrs[input_arg.number_attr],
+ inferred_from[input_arg.number_attr]))
+ else:
+ attrs[input_arg.number_attr] = len(values)
+ inferred_from[input_arg.number_attr] = input_name
+ num_attr = _Attr(op_def, input_arg.number_attr)
+ if num_attr.has_minimum and len(values) < num_attr.minimum:
+ raise ValueError(
+ "List argument '%s' to '%s' Op with length %d shorter "
+ "than minimum length %d." %
+ (input_name, op_type_name, len(values), num_attr.minimum))
+ # All tensors must have the same base type.
+ if any([bt != base_types[0] for bt in base_types]):
+ raise TypeError(
+ "All tensors passed to '%s' of '%s' Op "
+ "must have the same type." %
+ (input_name, op_type_name))
+ if input_arg.type != types_pb2.DT_INVALID:
+ # <number-attr> * <type> case
+ if base_types and base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+ elif input_arg.type_attr in attrs:
+ # <number-attr> * <type-attr> case, where <type-attr> already
+ # has an inferred value.
+ if base_types and base_types[0] != attrs[input_arg.type_attr]:
+ assert False, "Unreachable"
+ else:
+ # <number-attr> * <type-attr> case, where we are now setting
+ # the <type-attr> based on this input
+ if not base_types:
+ raise TypeError(
+ "Don't know how to infer type variable from empty input "
+ "list passed to input '%s' of '%s' Op." %
+ (input_name, op_type_name))
+ attrs[input_arg.type_attr] = base_types[0]
+ inferred_from[input_arg.type_attr] = input_name
+ type_attr = _Attr(op_def, input_arg.type_attr)
+ _SatisfiesTypeConstraint(base_types[0], type_attr)
+ elif input_arg.type_attr:
+ # <type-attr>
+ attr_value = base_types[0]
+ if input_arg.type_attr in attrs:
+ if attrs[input_arg.type_attr] != attr_value:
+ assert False, "Unreachable"
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_attr))
+ attrs[input_arg.type_attr] = attr_value
+ inferred_from[input_arg.type_attr] = input_name
+ elif input_arg.type_list_attr:
+ # <type-list-attr>
+ attr_value = base_types
+ if input_arg.type_list_attr in attrs:
+ if attrs[input_arg.type_list_attr] != attr_value:
+ raise TypeError(
+ "Input '%s' of '%s' Op has type list of %s that does not "
+ "match type list %s of argument '%s'." %
+ (input_name, op_type_name,
+ ", ".join(types_lib.as_dtype(x).name for x in attr_value),
+ ", ".join(types_lib.as_dtype(x).name
+ for x in attrs[input_arg.type_list_attr]),
+ inferred_from[input_arg.type_list_attr]))
+ else:
+ for base_type in base_types:
+ _SatisfiesTypeConstraint(base_type,
+ _Attr(op_def, input_arg.type_list_attr))
+ attrs[input_arg.type_list_attr] = attr_value
+ inferred_from[input_arg.type_list_attr] = input_name
+ else:
+ # single Tensor with specified type
+ if base_types[0] != input_arg.type:
+ assert False, "Unreachable"
+
+ if input_arg.is_ref:
+ if not all(x.is_ref_dtype for x in types):
+ raise TypeError(
+ "Input '%s' of '%s' Op requires l-value input" %
+ (input_name, op_type_name))
+ input_types.extend(types)
+ else:
+ input_types.extend(base_types)
+
+ # Process remaining attrs
+ for attr in op_def.attr:
+ # Skip attrs that have already had their values inferred
+ if attr.name in attrs:
+ if attr.name in keywords:
+ raise TypeError(
+ "Should not specify value for inferred attr '%s'." % attr.name)
+ continue
+ if attr.name in keywords:
+ attrs[attr.name] = keywords.pop(attr.name)
+ elif attr.name + "_" in keywords:
+ # Attrs whose names match Python keywords have an extra '_'
+ # appended, so we must check for that as well.
+ attrs[attr.name] = keywords.pop(attr.name + "_")
+ else:
+ raise TypeError("No argument for attr " + attr.name)
+
+ # Convert attr values to AttrValue protos.
+ attr_protos = {}
+ for attr_def in op_def.attr:
+ key = attr_def.name
+ value = attrs[key]
+ attr_value = attr_value_pb2.AttrValue()
+ if attr_def.HasField("default_value") and value is None:
+ attr_value.CopyFrom(attr_def.default_value)
+ attr_protos[key] = attr_value
+ continue
+ if attr_def.type.startswith("list("):
+ if not _IsListValue(value):
+ raise TypeError("Expected list for attr " + key)
+ if attr_def.has_minimum:
+ if len(value) < attr_def.minimum:
+ raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
+ "less than minimum %d." %
+ (key, op_type_name, len(value),
+ attr_def.minimum))
+ if attr_def.type == "string":
+ attr_value.s = _MakeStr(value, key)
+ if attr_def.HasField("allowed_values"):
+ if attr_value.s not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, attr_value.s,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "list(string)":
+ attr_value.list.s.extend([_MakeStr(x, key) for x in value])
+ if attr_def.HasField("allowed_values"):
+ for x in attr_value.list.s:
+ if x not in attr_def.allowed_values.list.s:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
+ (key, op_type_name, x,
+ '", "'.join(attr_def.allowed_values.list.s)))
+ elif attr_def.type == "int":
+ attr_value.i = _MakeInt(value, key)
+ if attr_def.has_minimum:
+ if attr_value.i < attr_def.minimum:
+ raise ValueError(
+ "Attr '%s' of '%s' Op passed %d less than minimum %d." %
+ (key, op_type_name, attr_value.i, attr_def.minimum))
+ elif attr_def.type == "list(int)":
+ attr_value.list.i.extend([_MakeInt(x, key) for x in value])
+ elif attr_def.type == "float":
+ attr_value.f = _MakeFloat(value, key)
+ elif attr_def.type == "list(float)":
+ attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
+ elif attr_def.type == "bool":
+ attr_value.b = _MakeBool(value, key)
+ elif attr_def.type == "list(bool)":
+ attr_value.list.b.extend([_MakeBool(x, key) for x in value])
+ elif attr_def.type == "type":
+ attr_value.type = _MakeType(value, attr_def)
+ elif attr_def.type == "list(type)":
+ attr_value.list.type.extend(
+ [_MakeType(x, attr_def) for x in value])
+ elif attr_def.type == "shape":
+ attr_value.shape.CopyFrom(_MakeShape(value, key))
+ elif attr_def.type == "list(shape)":
+ attr_value.list.shape.extend(
+ [_MakeShape(x, key) for x in value])
+ elif attr_def.type == "tensor":
+ attr_value.tensor.CopyFrom(_MakeTensor(value, key))
+ elif attr_def.type == "list(tensor)":
+ attr_value.list.tensor.extend(
+ [_MakeTensor(x, key) for x in value])
+ else:
+ raise TypeError("Unrecognized Attr type " + attr_def.type)
+
+ attr_protos[key] = attr_value
+ del attrs # attrs is no longer authoritative, use attr_protos instead
+
+ # Determine output types (possibly using attrs)
+ output_types = []
+ output_structure = []
+ for arg in op_def.output_arg:
+ types = []
+ if arg.number_attr:
+ n = _AttrValue(attr_protos, arg.number_attr).i
+ if arg.type_attr:
+ types = [_AttrValue(attr_protos, arg.type_attr).type] * n
+ else:
+ types = [arg.type] * n
+ output_structure.append(n)
+ elif arg.type_attr:
+ t = _AttrValue(attr_protos, arg.type_attr)
+ types = [t.type]
+ output_structure.append(None)
+ elif arg.type_list_attr:
+ t = _AttrValue(attr_protos, arg.type_list_attr)
+ types = t.list.type
+ output_structure.append(len(t.list.type))
+ else:
+ types = [arg.type]
+ output_structure.append(None)
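+    # Reference outputs report the corresponding _ref dtype.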
+ if arg.is_ref:
+ types = [types_lib.as_dtype(x).as_ref for x in types]
+ output_types.extend(types)
+
+ if keywords:
+ raise TypeError("apply_op() got unexpected keyword arguments: " +
+ ", ".join(sorted(keywords.keys())))
+
+ # Add Op to graph
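+  # With outputs, return them regrouped according to output_structure (a list
+  # of tensors per list output, a single tensor otherwise); with no outputs,
+  # return the Operation itself.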
+ if output_structure:
+ op = g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
+ outputs = op.outputs
+ return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
+ output_structure)
+ else:
+ return g.create_op(op_type_name, inputs, output_types, name=scope,
+ input_types=input_types, attrs=attr_protos,
+ op_def=op_def)
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
new file mode 100644
index 0000000000..72de4586a3
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -0,0 +1,1402 @@
+"""Tests for tensorflow.python.ops.op_def_library."""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops.op_def_library import OpDefLibrary
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+ops.RegisterShape("Attr")(None)
+ops.RegisterShape("AttrBool")(None)
+ops.RegisterShape("AttrBoolList")(None)
+ops.RegisterShape("AttrDefault")(None)
+ops.RegisterShape("AttrEmptyListDefault")(None)
+ops.RegisterShape("AttrEnum")(None)
+ops.RegisterShape("AttrEnumList")(None)
+ops.RegisterShape("AttrFloat")(None)
+ops.RegisterShape("AttrListDefault")(None)
+ops.RegisterShape("AttrListMin")(None)
+ops.RegisterShape("AttrMin")(None)
+ops.RegisterShape("AttrShape")(None)
+ops.RegisterShape("AttrShapeList")(None)
+ops.RegisterShape("Binary")(None)
+ops.RegisterShape("ComplexStruct")(None)
+ops.RegisterShape("InPolymorphicTwice")(None)
+ops.RegisterShape("MixedStruct")(None)
+ops.RegisterShape("NInPolymorphicTwice")(None)
+ops.RegisterShape("NInTwice")(None)
+ops.RegisterShape("NInTwoTypeVariables")(None)
+ops.RegisterShape("NIntsIn")(None)
+ops.RegisterShape("NIntsOut")(None)
+ops.RegisterShape("NIntsOutDefault")(None)
+ops.RegisterShape("NPolymorphicIn")(None)
+ops.RegisterShape("NPolymorphicOut")(None)
+ops.RegisterShape("NPolymorphicOutDefault")(None)
+ops.RegisterShape("NPolymorphicRestrictIn")(None)
+ops.RegisterShape("NPolymorphicRestrictOut")(None)
+ops.RegisterShape("OutT")(None)
+ops.RegisterShape("OutTypeList")(None)
+ops.RegisterShape("OutTypeListRestrict")(None)
+ops.RegisterShape("Polymorphic")(None)
+ops.RegisterShape("PolymorphicDefaultOut")(None)
+ops.RegisterShape("PolymorphicOut")(None)
+ops.RegisterShape("RefIn")(None)
+ops.RegisterShape("RefOut")(None)
+ops.RegisterShape("ReservedAttr")(None)
+ops.RegisterShape("ReservedInput")(None)
+ops.RegisterShape("Restrict")(None)
+ops.RegisterShape("Simple")(None)
+ops.RegisterShape("SimpleStruct")(None)
+ops.RegisterShape("TypeList")(None)
+ops.RegisterShape("TypeListRestrict")(None)
+ops.RegisterShape("TypeListTwice")(None)
+
+
+class OpDefLibraryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
+ self._default_graph_controller = self._g.as_default()
+ self._default_graph_controller.__enter__()
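+    # Entering the graph context manually makes self._g the default graph for
+    # the duration of each test; tearDown exits it.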
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def tearDown(self):
+ self._default_graph_controller.__exit__(None, None, None)
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
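+  # Helper: builds a tensor of dtype `t` via the polymorphic 'OutT' op
+  # registered in setUp, so tests can supply typed Tensor inputs.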
+ def Tensor(self, t, name="in"):
+ return self._lib.apply_op("OutT", T=t, name=name)
+
+ def testNoRegisteredOpFails(self):
+ with self.assertRaises(RuntimeError) as cm:
+ self._lib.apply_op("unknown", g=self._g)
+ self.assertEqual(cm.exception.message, "Unrecognized Op name unknown")
+
+ def testAddOpValidation(self):
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingTypeAttr' "
+ "input_arg { name: 'a' type_attr: 'T' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingTypeAttr', "
+ "missing attr 'T'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadTypeAttr' "
+ "output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'int' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingNumberAttr' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingNumberAttr', "
+ "missing attr 'N'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadNumberAttr' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'type' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesA' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesA' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesB' "
+ "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesB' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'ThreeTypes' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' "
+ "type_list_attr: 'U' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'U' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'ThreeTypes' must have one type field not 3")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'NoTypes' output_arg { name: 'a' } ")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'NoTypes' must have one type field not 0")
+
+ def testSimple(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'Simple' op: 'Simple' input: 'Simple/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=4)
+ self.assertProtoEquals("""
+ name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=5, name="named")
+ self.assertProtoEquals("""
+ name: 'named' op: 'Simple' input: 'named/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d")
+ self.assertProtoEquals("""
+ name: 'two_d' op: 'Simple' input: 'two_d/a'
+ """, out.op.node_def)
+
+ def testSimpleFailures(self):
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a="Bad string")
+ self.assertEqual(cm.exception.message,
+ "Expected int32, got 'Bad string' instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=self.Tensor(types.string))
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'Simple' Op has type string "
+ "that does not match expected type of int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra="bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra1, "
+ "extra2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple")
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", wrong=7)
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=[self.Tensor(types.int32)])
+ self.assertStartsWith(cm.exception.message, "Expected int32, got")
+
+ def testReservedInput(self):
+ self._add_op("name: 'ReservedInput' "
+ "input_arg { name: 'input' type: DT_INT32 } ")
+ op = self._lib.apply_op("ReservedInput", input_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedInput' input: 'x/input'
+ """, op.node_def)
+
+ def testPolymorphic(self):
+ self._add_op("name: 'Polymorphic' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Polymorphic", a=7, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'Polymorphic' input: 'p/a'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a="s", name="q")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'Polymorphic' input: 'q/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'Polymorphic' input: 'r/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a="s", T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Should not specify value for inferred attr 'T'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)])
+ self.assertEqual(cm.exception.message,
+ "List of Tensors when single Tensor expected")
+
+ def testPolymorphicOut(self):
+ self._add_op("name: 'PolymorphicOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut")
+ self.assertEqual(cm.exception.message,
+ "No argument for attr T")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut", T=None)
+ self.assertEqual(cm.exception.message,
+ "Expected DataType for argument 'T' not None.")
+
+ def testPolymorphicDefaultOut(self):
+ self._add_op("name: 'PolymorphicDefaultOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_STRING } }")
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool,
+ name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ def testBinary(self):
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Binary", a=8, b=9, name="b")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Binary", a="left", b="right", name="c")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a="left", b=12)
+ self.assertEqual(cm.exception.message,
+ "Expected string, got 12 instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a=self.Tensor(types.string),
+ b=self.Tensor(types.int32))
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'Binary' Op has type int32 "
+ "that does not match type string of argument 'a'.")
+
+ def testRestrict(self):
+ self._add_op("name: 'Restrict' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out = self._lib.apply_op("Restrict", a="foo", name="g")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'g' op: 'Restrict' input: 'g/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Restrict", a=True, name="h")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'h' op: 'Restrict' input: 'h/a'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Restrict", a=17)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: "
+ "string, bool")
+
+ def testTypeList(self):
+ self._add_op("name: 'TypeList' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeList", a=["foo"], name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeList' input: 'z/a_0'
+ attr { key: 'T' value { list { type: DT_STRING } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[True, 12], name="y")
+ self.assertProtoEquals("""
+ name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'
+ attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeList", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' "
+ "argument to 'TypeList' Op, not ")
+
+ def testTypeListTwice(self):
+ self._add_op("name: 'TypeListTwice' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "input_arg { name: 'b' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", False],
+ name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeListTwice'
+ input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'TypeListTwice' Op has type list of "
+ "string, int32 that does not match type list "
+ "string, bool of argument 'a'.")
+
+ def testOutTypeList(self):
+ self._add_op("name: 'OutTypeList' "
+ "output_arg { name: 'out' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x")
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'x' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_FLOAT } } }
+ """, out.op.node_def)
+
+ out1, out2 = self._lib.apply_op("OutTypeList",
+ T=[types.int32, types.bool],
+ name="w")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'w' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
+ """, out1.op.node_def)
+
+ out = self._lib.apply_op("OutTypeList", T=[], name="empty")
+ self.assertEqual([], out)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeList", T=types.int32)
+ self.assertEqual(cm.exception.message, "Expected list for attr T")
+
+ def testTypeListRestrict(self):
+ self._add_op("name: 'TypeListRestrict' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v")
+ self.assertProtoEquals("""
+ name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListRestrict", a=[True, 12])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testOutTypeListRestrict(self):
+ self._add_op("name: 'OutTypeListRestrict' "
+ "output_arg { name: 'out' type_list_attr: 't' } "
+ "attr { name: 't' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out1, out2 = self._lib.apply_op("OutTypeListRestrict",
+ t=[types.bool, types.string],
+ name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'OutTypeListRestrict'
+ attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeListRestrict",
+ t=[types.string, types.int32])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 't' "
+ "not in list of allowed values: string, bool")
+
+ def testAttr(self):
+ self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }")
+ op = self._lib.apply_op("Attr", a=12, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not 'bad'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=[12])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not [12].")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=None)
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not None.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr")
+ self.assertEqual(cm.exception.message, "No argument for attr a")
+
+ def testAttrFloat(self):
+ self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")
+
+ op = self._lib.apply_op("AttrFloat", a=1.2, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrFloat", a=12, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrFloat", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected float for argument 'a' not 'bad'.")
+
+ def testAttrBool(self):
+ self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }")
+
+ op = self._lib.apply_op("AttrBool", a=True, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBool", a=False, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=0)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=1)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 1.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=[])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not [].")
+
+ def testAttrBoolList(self):
+ self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }")
+
+ op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBoolList'
+ attr { key: 'a' value { list { b: true b: false b:true } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBoolList", a=[], name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBoolList", a=[0])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ def testAttrMin(self):
+ self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' "
+ "has_minimum: true minimum: 5 }")
+ op = self._lib.apply_op("AttrMin", a=12, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrMin", a=2)
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")
+
+ def testAttrListMin(self):
+ self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' "
+ "has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r")
+ self.assertProtoEquals("""
+ name: 'r' op: 'AttrListMin'
+ attr { key: 'a' value { list { i: 1 i: 2 } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrListMin", a=[17])
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrListMin' Op "
+ "passed list of length 1 less than minimum 2.")
+
+ def testAttrEnum(self):
+ self._add_op("name: 'AttrEnum' "
+ "attr { name: 'a' type: 'string' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnum", a="oranges", name="e")
+ self.assertProtoEquals("""
+ name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnum", a="invalid")
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnum\' Op '
+ 'passed string \'invalid\' not in: '
+ '"apples", "oranges".')
+
+ def testAttrEnumList(self):
+ self._add_op("name: 'AttrEnumList' "
+ "attr { name: 'a' type: 'list(string)' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f")
+ self.assertProtoEquals("""
+ name: 'f' op: 'AttrEnumList'
+ attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnumList\' Op '
+ 'passed string \'invalid\' not '
+ 'in: "apples", "oranges".')
+
+ def testAttrShape(self):
+ self._add_op("name: 'AttrShape' attr { name: 'a' type: 'shape' }")
+
+ op = self._lib.apply_op("AttrShape", a=[5], name="s1")
+ self.assertProtoEquals("""
+ name: 's1' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 5 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2")
+ self.assertProtoEquals("""
+ name: 's2' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op(
+ "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")
+ self.assertProtoEquals("""
+ name: 's3' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=[], name="s4")
+ self.assertProtoEquals("""
+ name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }
+ """, op.node_def)
+
+ shape = tensor_shape_pb2.TensorShapeProto()
+ shape.dim.add().size = 6
+ shape.dim.add().size = 3
+ op = self._lib.apply_op("AttrShape", a=shape, name="s5")
+ self.assertProtoEquals("""
+ name: 's5' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }
+ """, op.node_def)
+
+ # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes.
+ # with self.assertRaises(TypeError) as cm:
+ # self._lib.apply_op("AttrShape", a=5)
+ # self.assertEqual(cm.exception.message,
+ # "Don't know how to convert 5 to a TensorShapeProto for "
+ # "argument 'a'")
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrShape", a="ABC")
+
+ def testAttrShapeList(self):
+ self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }")
+
+ op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl")
+ self.assertProtoEquals("""
+ name: 'sl' op: 'AttrShapeList'
+ attr { key: 'a' value { list {
+ shape { dim { size: 3 } dim { size: 2 } }
+ shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShapeList", a=[], name="esl")
+ self.assertProtoEquals("""
+ name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrDefault(self):
+ self._add_op("name: 'AttrDefault' "
+ "attr { name: 'a' type: 'string' "
+ " default_value { s: 'banana' } }")
+
+ op = self._lib.apply_op("AttrDefault", a=None, name="d")
+ self.assertProtoEquals("""
+ name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrDefault", a="kiwi", name="c")
+ self.assertProtoEquals("""
+ name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }
+ """, op.node_def)
+
+ def testAttrListDefault(self):
+ self._add_op("name: 'AttrListDefault' "
+ "attr { name: 'a' type: 'list(int)' "
+ " default_value { list { i: 5 i: 15 } } }")
+
+ op = self._lib.apply_op("AttrListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 5 i: 15 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrEmptyListDefault(self):
+ self._add_op("name: 'AttrEmptyListDefault' "
+ "attr { name: 'a' type: 'list(float)' "
+ " default_value { list { } } }")
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { f: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testReservedAttr(self):
+ self._add_op("name: 'ReservedAttr' "
+ "attr { name: 'range' type: 'int' } ")
+ op = self._lib.apply_op("ReservedAttr", range_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
+ """, op.node_def)
+
+ def testNIntsIn(self):
+ self._add_op("name: 'NIntsIn' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=["foo", "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[string, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have "
+ "types [string, string] that do not match expected type "
+ "int32.")
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NIntsIn' Op "
+ "with length 1 shorter than "
+ "minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[int32, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op "
+ "have types [int32, string] that do not match expected "
+ "type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NIntsIn' Op, not ")
+
+ def testNPolymorphicIn(self):
+ self._add_op("name: 'NPolymorphicIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn",
+ a=[1, self.Tensor(types.float32, name="x")],
+ name="q")
+ self.assertProtoEquals("""
+ name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NPolymorphicIn' Op with length 1 "
+ "shorter than minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "All tensors passed to 'a' of 'NPolymorphicIn' "
+ "Op must have the same type.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=[38, self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [int32, string] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=["abcd", self.Tensor(types.int32)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [string, int32] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NPolymorphicIn' Op, not ")
+
+ def testNPolymorphicRestrictIn(self):
+ self._add_op("name: 'NPolymorphicRestrictIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=[False, True, False],
+ name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'NPolymorphicRestrictIn'
+ input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testNInTwice(self):
+ self._add_op("name: 'NInTwice' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwice", a=[1, 2], b=["one", "two"], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwice", a=[], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwice' Op "
+ "with length 1 must match "
+ "length 3 of argument 'a'.")
+
+ def testNInPolymorphicTwice(self):
+ self._add_op("name: 'NInPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInPolymorphicTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInPolymorphicTwice' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
+ "Op have types [string, string] that do not match type "
+ "int32 inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of "
+ "'NInPolymorphicTwice' Op have types [string] that do not "
+ "match type int32 inferred from earlier arguments.")
+
+ def testNInTwoTypeVariables(self):
+ self._add_op("name: 'NInTwoTypeVariables' "
+ "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'S' type: 'type' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[True, False],
+ name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwoTypeVariables'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwoTypeVariables'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables",
+ a=[self.Tensor(types.int32, name="q")],
+ b=[self.Tensor(types.string, name="r")],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 1 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwoTypeVariables' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ def testInPolymorphicTwice(self):
+ self._add_op("name: 'InPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } "
+ "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ")
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'InPolymorphicTwice'
+ input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 3 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
+ self.assertEqual(cm.exception.message,
+ "Don't know how to infer type variable from empty input "
+ "list passed to input 'a' of 'InPolymorphicTwice' Op.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
+ "have types [string, string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' "
+ "Op have types [string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ def testNIntsOut(self):
+ self._add_op("name: 'NIntsOut' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3, out4, out5 = self._lib.apply_op(
+ "NIntsOut", N=5, name="o")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertEquals(types.int32, out4.dtype)
+ self.assertEquals(types.int32, out5.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
+ """, out5.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsOut", N=1)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsOut", N=[3])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'N' not [3].")
+
+ def testNIntsOutDefault(self):
+ self._add_op("name: 'NIntsOutDefault' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2"
+ " default_value { i:3 } }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NIntsOutDefault", N=None, name="z")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
+ """, out2.op.node_def)
+
+ def testNPolymorphicOut(self):
+ self._add_op("name: 'NPolymorphicOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2,
+ T=types.int32, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOut", T=types.string, N=3, name="o")
+ self.assertEquals(types.string, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertEquals(types.string, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 3 } }
+ """, out3.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=1, T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NPolymorphicOut' Op "
+ "passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string])
+ self.assertEqual(
+ cm.exception.message,
+ "Expected DataType for argument 'T' not [tf.string].")
+
+ def testNPolymorphicOutDefault(self):
+ self._add_op("name: 'NPolymorphicOutDefault' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type'"
+ " default_value { type: DT_BOOL } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 "
+ " default_value { i: 2 } }")
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=None, name="r")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=None, name="s")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 's' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=types.int32, name="t")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 't' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=types.int32, name="u")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ def testNPolymorphicRestrictOut(self):
+ self._add_op("name: 'NPolymorphicRestrictOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicRestrictOut", N=3, T=types.bool, name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicRestrictOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testRef(self):
+ self._add_op("name: 'RefIn' "
+ "input_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+ self._add_op("name: 'RefOut' "
+ "output_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+
+ out = self._lib.apply_op("RefOut", T=types.bool, name="o")
+ self.assertEquals(types.bool_ref, out.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'RefOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ op = self._lib.apply_op("RefIn", a=out, name="i")
+ self.assertProtoEquals("""
+ name: 'i' op: 'RefIn' input: 'o'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, op.node_def)
+
+ # Can pass ref to non-ref input.
+ out = self._lib.apply_op("RefOut", T=types.int32, name="r")
+ out = self._lib.apply_op("Simple", a=out, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'Simple' input: 'r'
+ """, out.op.node_def)
+
+ # Can't pass non-ref to ref input.
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("RefIn", a=2)
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'RefIn' Op requires l-value input")
+
+ def testSpecifyDevice(self):
+ with self._g.device("ADevice"):
+ self._lib.apply_op("Simple", a=3)
+ # We look at the whole graph here to make sure the Const op is also given
+ # the specified device.
+ graph_def = self._g.as_graph_def()
+ self.assertEqual(len(graph_def.node), 2)
+ for node in graph_def.node:
+ self.assertEqual(node.device, "ADevice")
+
+ def testStructuredOutputSingleList(self):
+ self._add_op("name: 'SimpleStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a = self._lib.apply_op("SimpleStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+
+ def testStructuredOutputListAndSingle(self):
+ self._add_op("name: 'MixedStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_FLOAT } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a, b = self._lib.apply_op("MixedStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertTrue(isinstance(b, ops.Tensor))
+ self.assertEqual(types.float32, b.dtype)
+
+ def testStructuredOutputMultipleLists(self):
+ self._add_op("name: 'ComplexStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } "
+ "output_arg { name: 'c' type_list_attr: 't_c' } "
+ "attr { name: 'n_a' type: 'int' } "
+ "attr { name: 'n_b' type: 'int' } "
+ "attr { name: 't_c' type: 'list(type)' }")
+ for n_a in [0, 1, 3]:
+ for n_b in [0, 1, 3]:
+ for t_c in [[],
+ [types.int32],
+ [types.int32, types.float32]]:
+ a, b, c = self._lib.apply_op("ComplexStruct",
+ n_a=n_a, n_b=n_b, t_c=t_c)
+
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertEqual(n_b, len(b))
+ self.assertTrue(all(x.dtype == types.int64 for x in b))
+ self.assertEqual(t_c, [x.dtype for x in c])
+
+
+class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
+ def testNoGraph(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, ops.get_default_graph())
+
+ def testDefaultGraph(self):
+ with self._g.as_default():
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, self._g)
+
+ def testIgnoreDefaultGraphWithGraphArgument(self):
+ default_g = ops.Graph()
+ with default_g.as_default():
+ out = self._lib.apply_op("Simple", a=3, g=self._g)
+ self.assertEquals(ops.get_default_graph(), default_g)
+ self.assertEquals(out.graph, self._g)
+
+ def testDifferentGraphFails(self):
+ a = self._lib.apply_op("Simple", a=3, g=self._g)
+ other_g = ops.Graph()
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b)
+ self.assertTrue("must be from the same graph" in cm.exception.message)
+
+ def testDifferentGraphFailsWithGraphArgument(self):
+ other_g = ops.Graph()
+ a = self._lib.apply_op("Simple", a=3, g=other_g)
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b, g=self._g)
+ self.assertTrue(
+ "not from the passed-in graph" in cm.exception.message)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
new file mode 100644
index 0000000000..dc954a3776
--- /dev/null
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -0,0 +1,390 @@
+"""Parsing Ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_parsing_ops import *
+
+
+ops.NoGradient("DecodeRaw")
+ops.NoGradient("StringToNumber")
+
+
+# pylint: disable=protected-access
+def parse_example(serialized,
+ names=None,
+ sparse_keys=None,
+ sparse_types=None,
+ dense_keys=None,
+ dense_types=None,
+ dense_defaults=None,
+ dense_shapes=None,
+ name="ParseExample"):
+ """Parse Example protos.
+
+ Args:
+ serialized: string vector, a batch of binary serialized Example protos.
+ names: A string vector, the names of the serialized protos.
+ "names" may contain, e.g., table key (descriptive) names for the
+ corresponding serialized protos. These are purely useful for debugging
+ purposes, and the presence of values here has no effect on the output.
+ "names" may be an empty vector, if no names are available.
+ If non-empty, this vector must be the same length as "serialized".
+ sparse_keys: A string list of keys in the Examples' features.
+ These keys are associated with sparse values.
+ sparse_types: A list of DTypes.
+ This list's length must match that of sparse_keys. Currently
+ parse_example supports tf.float32 (FloatList), tf.int64 (Int64List),
+ and tf.string (BytesList).
+ dense_keys: A string list of keys in the Examples' features.
+ These keys are associated with dense values.
+ dense_types: A list of DTypes.
+ This list's length must match that of dense_keys. Currently
+ parse_example supports tf.float32 (FloatList), tf.int64 (Int64List),
+ and tf.string (BytesList).
+ dense_defaults: A dict of {key:Tensor} (some may be missing).
+ The keys of the dict must match the dense_keys of the feature.
+ If a key is not present in this dictionary, the corresponding dense
+ Feature is required in all elements of serialized.
+ dense_shapes: A list of tuples.
+ Entries provide the shape of data in each dense Feature in features.
+ The length of dense_shapes must be the same as the length of dense_keys.
+ The number of elements in the Feature corresponding to dense_key[j]
+ must always have np.prod(dense_shapes[j]) entries.
+      If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output
+ Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):
+ The dense outputs are just the inputs row-stacked by batch.
+ name: (Optional) Name of Op in the graph.
+
+ Returns:
+ A dictionary mapping keys to Tensors and SparseTensors.
+
+ The key dense_keys[j] is mapped to a tensor of type dense_types[j] and
+ of shape (serialized.size(),) + dense_shapes[j] (i.e., the dense outputs are
+ inputs, reshaped in row-major format and then row-stacked by batch).
+
+ The key sparse_keys[j] is mapped to a SparseTensor of type sparse_types[j].
+ The SparseTensor represents a ragged matrix. Its indices are [batch, index]
+ where "batch" is is the batch entry the value is from, and "index" is the
+ value's index in the list of values associated with that feature
+ and example. For example, if one expects a tf.float32 sparse feature "ft"
+ and three serialized examples are provided:
+
+ serialized = [
+ features:
+ { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] },
+ features:
+ { feature: [] },
+ features:
+ { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] }
+ ]
+
+ then the output will look like:
+
+ {"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
+ values=[1.0, 2.0, 3.0],
+ shape=(3, 2)) }
+
+ Raises:
+ ValueError: If sparse and dense keys intersect, or input lengths do not
+ match up for sparse_* (similarly for dense_*).
+ TypeError: If an input is malformed.
+
+ Example input, format, and output: Just Sparse Inputs
+ =====================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } }
+ feature: { key: "gps" value: { float_list: { value: [] } } }
+ },
+ features: {
+ feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } }
+ feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } }
+ feature: { key: "gps" value: { } }
+ }]
+ names: ["input0", "input1"],
+ sparse_keys: ["kw", "dank", "gps"]
+ sparse_types: [DT_STRING, DT_INT64, DT_FLOAT]
+
+ Then the expected output is a dictionary:
+ {
+ "kw": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["knit", "big", "emmy"]
+ shape=[2, 2]),
+ "dank": SparseTensor(
+ indices=[[1, 0]],
+ values=[42],
+ shape=[2, 1]),
+ "gps": SparseTensor(
+ indices=[],
+ values=[],
+ shape=[2, 0]),
+ }
+
+
+ Example input, format, and output: Dense Inputs (without defaults)
+ ==================================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "age" value: { int64_list: { value: [ 0 ] } } }
+ feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ },
+ features: {
+ feature: { key: "age" value: { int64_list: { value: [] } } }
+ feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
+ }]
+ names: ["input0", "input1"],
+ dense_keys: np.array(["age", "gender"])
+ dense_types: [tf.int64, tf.string]
+ dense_defaults: {
+ "age": -1 # defaults to -1 if missing
+ # "gender" has no specified default so it's required
+ }
+ dense_shapes: [(1,), (1,)]  # age, gender
+
+ Then the expected output is a dictionary:
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ }
+
+
+ Example input, format, and output: Dense Inputs (with defaults)
+ ===============================================================
+
+ Given two brain.Example input protos:
+
+ serialized: // serialized versions of the protos below
+ [features: {
+ feature: { key: "weight" value: { float_list: { value: [ 1.0 ] } } }
+ },
+ features: {
+ feature: { key: "label" value: { float_list: { value: [ -1.0, 0.0 ] } } }
+ }]
+ names: ["input0", "input1"],
+ dense_keys: np.array(["label", "weight"])
+ dense_defaults: {
+ "label": [1.0, 2.0], # float (default: vector)
+ "weight": 5.0 # float (default: scalar, 5.0)
+ }
+ dense_shapes: [(2,), (1,)]  # label, weight
+
+ Then the expected output is a dictionary:
+ {
+ "label": [[1.0, 2.0], [-1.0, 0.0]],
+ "weight": [[1.0], [5.0]],
+ }
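+
+ Example call: a sketch of invoking this op for the dense-with-defaults case
+ above (the keys, dtypes, and shapes are taken from that example and are
+ illustrative only):
+
+ parsed = parse_example(
+ serialized,
+ dense_keys=["label", "weight"],
+ dense_types=[tf.float32, tf.float32],
+ dense_defaults={"label": [1.0, 2.0], "weight": 5.0},
+ dense_shapes=[(2,), (1,)])
+ # parsed["label"] and parsed["weight"] are dense Tensors of shape
+ # (2, 2) and (2, 1), respectively.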
+ """
+ names = [] if names is None else names
+ dense_defaults = {} if dense_defaults is None else dense_defaults
+ sparse_keys = [] if sparse_keys is None else sparse_keys
+ sparse_types = [] if sparse_types is None else sparse_types
+ dense_keys = [] if dense_keys is None else dense_keys
+ dense_types = [] if dense_types is None else dense_types
+ dense_shapes = [
+ []] * len(dense_keys) if dense_shapes is None else dense_shapes
+
+ num_dense = len(dense_keys)
+ num_sparse = len(sparse_keys)
+
+ if len(dense_shapes) != num_dense:
+ raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
+ % (len(dense_shapes), num_dense))
+ if len(dense_types) != num_dense:
+ raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
+ % (len(dense_types), num_dense))
+ if len(sparse_types) != num_sparse:
+ raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
+ % (len(sparse_types), num_sparse))
+ if num_dense + num_sparse == 0:
+ raise ValueError("Must provide at least one sparse key or dense key")
+ if not set(dense_keys).isdisjoint(set(sparse_keys)):
+ raise ValueError(
+ "Dense and sparse keys must not intersect; intersection: %s" %
+ set(dense_keys).intersection(set(sparse_keys)))
+
+ dense_defaults_vec = []
+ for i, key in enumerate(dense_keys):
+ default_value = dense_defaults.get(key)
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key)
+ default_value = array_ops.reshape(default_value, dense_shapes[i])
+
+ dense_defaults_vec.append(default_value)
+
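+ # Convert any Python list/tuple shapes to TensorShapeProto, which is the
+ # form expected for the dense_shapes attr of the underlying op.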
+ dense_shapes = [tensor_util.MakeTensorShapeProto(shape)
+ if isinstance(shape, (list, tuple)) else shape
+ for shape in dense_shapes]
+
+ outputs = gen_parsing_ops._parse_example(
+ serialized=serialized,
+ names=names,
+ dense_defaults=dense_defaults_vec,
+ sparse_keys=sparse_keys,
+ sparse_types=sparse_types,
+ dense_keys=dense_keys,
+ dense_shapes=dense_shapes,
+ name=name)
+
+ (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
+
+ sparse_tensors = [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
+ in zip(sparse_indices, sparse_values, sparse_shapes)]
+
+ return dict(
+ zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+
+
+def parse_single_example(serialized, # pylint: disable=invalid-name
+ names=None,
+ sparse_keys=None,
+ sparse_types=None,
+ dense_keys=None,
+ dense_types=None,
+ dense_defaults=None,
+ dense_shapes=None,
+ name="ParseSingleExample"):
+ """Identical to parse_example but for scalar serialized and names.
+
+ Args:
+ serialized: A scalar string, a single serialized Example.
+ See parse_example documentation for more details.
+ names: (Optional) A scalar string, the associated name.
+ See parse_example documentation for more details.
+ sparse_keys: See parse_example documentation for more details.
+ sparse_types: See parse_example documentation for more details.
+ dense_keys: See parse_example documentation for more details.
+ dense_types: See parse_example documentation for more details.
+ dense_defaults: See parse_example documentation for more details.
+ dense_shapes: See parse_example documentation for more details.
+ name: Optional op name.
+
+ Returns:
+ A dictionary mapping keys to Tensors and SparseTensors.
+
+ For dense tensors, the Tensor is identical to the output of parse_example,
+ except it is one less dimension (the first, batch, dimension is removed).
+
+ For SparseTensors:
+ The first (batch) column of the indices matrix is removed
+ (it is now a column vector).
+ The values vector is unchanged.
+ The first (batch_size) entry of the shape vector is removed
+ (it is now a single element vector).
+
+ Raises:
+ ValueError: if "scalar" or "names" have known shapes, and are not scalars.
+ """
+ with ops.op_scope([serialized], name, "parse_single_example"):
+ serialized = ops.convert_to_tensor(serialized)
+ serialized_shape = serialized.get_shape()
+ if serialized_shape.ndims is not None:
+ if serialized_shape.ndims != 0:
+ raise ValueError("Input serialized must be a scalar")
+ else:
+ serialized = control_flow_ops.with_dependencies(
+ [logging_ops.Assert(
+ math_ops.equal(array_ops.rank(serialized), 0),
+ ["Input serialized must be a scalar"],
+ name="SerializedIsScalar")],
+ serialized,
+ name="SerializedDependencies")
+ serialized = array_ops.expand_dims(serialized, 0)
+ if names is not None:
+ names = ops.convert_to_tensor(names)
+ names_shape = names.get_shape()
+ if names_shape.ndims is not None:
+ if names_shape.ndims != 0:
+ raise ValueError("Input names must be a scalar")
+ else:
+ names = control_flow_ops.with_dependencies(
+ [logging_ops.Assert(
+ math_ops.equal(array_ops.rank(names), 0),
+ ["Input names must be a scalar"],
+ name="NamesIsScalar")],
+ names,
+ name="NamesDependencies")
+ names = array_ops.expand_dims(names, 0)
+
+ outputs = parse_example(serialized,
+ names=names,
+ sparse_keys=sparse_keys,
+ sparse_types=sparse_types,
+ dense_keys=dense_keys,
+ dense_types=dense_types,
+ dense_defaults=dense_defaults,
+ dense_shapes=dense_shapes,
+ name=name)
+ if dense_keys is not None:
+ for d in dense_keys:
+ outputs[d] = array_ops.squeeze(outputs[d], [0], name="Squeeze_%s" % d)
+ if sparse_keys is not None:
+ for s in sparse_keys:
+ outputs[s] = ops.SparseTensor(
+ array_ops.slice(outputs[s].indices,
+ [0, 1], [-1, -1], name="Slice_Indices_%s" % s),
+ outputs[s].values,
+ array_ops.slice(outputs[s].shape,
+ [1], [-1], name="Squeeze_Shape_%s" % s))
+ return outputs
+
+
+@ops.RegisterShape("ParseExample")
+def _ParseExampleShape(op):
+ """Shape function for the ParseExample op."""
+ input_shape = op.inputs[0].get_shape().with_rank(1)
+ num_sparse = op.get_attr("Nsparse")
+ num_dense = op.get_attr("Ndense")
+ dense_shapes = op.get_attr("dense_shapes")
+ sparse_index_shapes = [
+ tensor_shape.matrix(None, 2) for _ in range(num_sparse)]
+ sparse_value_shapes = [tensor_shape.vector(None) for _ in range(num_sparse)]
+ sparse_shape_shapes = [tensor_shape.vector(2) for _ in range(num_sparse)]
+ assert num_dense == len(dense_shapes)
+ dense_shapes = [
+ input_shape.concatenate((d.size for d in dense_shape.dim))
+ for dense_shape in dense_shapes]
+ return (sparse_index_shapes + sparse_value_shapes + sparse_shape_shapes +
+ dense_shapes)
+
+
+ops.RegisterShape("StringToNumber")(
+ common_shapes.unchanged_shape)
+
+
+@ops.RegisterShape("DecodeRaw")
+def _DecodeRawShape(op):
+ """Shape function for the DecodeRaw op."""
+ # NOTE(mrry): Last dimension is data-dependent.
+ return [op.inputs[0].get_shape().concatenate([None])]
+
+
+@ops.RegisterShape("DecodeCSV")
+def _DecodeCSVShape(op):
+ """Shape function for the DecodeCSV op."""
+ input_shape = op.inputs[0].get_shape()
+ # When statically known, check that each of the other inputs (the defaults)
+ # is a length-0 or length-1 vector.
+ for default_input in op.inputs[1:]:
+ default_input_shape = default_input.get_shape().with_rank(1)
+ if default_input_shape[0] > 1:
+ raise ValueError(
+ "Shape of a default must be a length-0 or length-1 vector.")
+ return [input_shape] * len(op.outputs)
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
new file mode 100644
index 0000000000..6bd8dd9e3d
--- /dev/null
+++ b/tensorflow/python/ops/random_ops.py
@@ -0,0 +1,181 @@
+"""Operations for generating random numbers."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_random_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_random_ops import *
+# pylint: enable=wildcard-import
+
+
+def _ShapeTensor(shape):
+ """Convert to an int32 or int64 tensor, defaulting to int32 if empty."""
+ if isinstance(shape, (tuple, list)) and not shape:
+ dtype = types.int32
+ else:
+ dtype = None
+ return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
+
+# pylint: disable=protected-access
+def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+ seed=None, name=None):
+ """Outputs random values from a normal distribution.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
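+
+ For example (an illustrative sketch; the shape, moments, and seed are
+ arbitrary):
+
+ ```python
+ # A [2, 3] float32 tensor of samples drawn from N(mean=-1, stddev=4).
+ norm = tf.random_normal([2, 3], mean=-1.0, stddev=4.0, seed=1234)
+ ```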
+ """
+ with ops.op_scope([shape, mean, stddev], name, "random_normal") as name:
+ shape_tensor = _ShapeTensor(shape)
+ mean_tensor = ops.convert_to_tensor(
+ mean, dtype=dtype, name="mean")
+ stddev_tensor = ops.convert_to_tensor(
+ stddev, dtype=dtype, name="stddev")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._random_standard_normal(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * stddev_tensor
+ value = math_ops.add(mul, mean_tensor, name=name)
+ return value
+
+
+ops.NoGradient("RandomStandardNormal")
+
+
+def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32,
+ seed=None, name=None):
+ """Outputs random values from a truncated normal distribution.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the truncated normal distribution.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.op_scope([shape, mean, stddev], name, "truncated_normal") as name:
+ shape_tensor = _ShapeTensor(shape)
+ mean_tensor = ops.convert_to_tensor(
+ mean, dtype=dtype, name="mean")
+ stddev_tensor = ops.convert_to_tensor(
+ stddev, dtype=dtype, name="stddev")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._truncated_normal(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * stddev_tensor
+ value = math_ops.add(mul, mean_tensor, name=name)
+ return value
+
+
+ops.NoGradient("TruncatedNormal")
+
+
+def random_uniform(shape, minval=0.0, maxval=1.0,
+ dtype=types.float32, seed=None,
+ name=None):
+ """Outputs random values from a uniform distribution.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
+ the range of random values to generate.
+ dtype: The type of the output.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
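+
+ For example (an illustrative sketch; the shape, bounds, and seed are
+ arbitrary):
+
+ ```python
+ # A [4] float32 tensor of samples drawn uniformly from [0, 10).
+ u = tf.random_uniform([4], minval=0.0, maxval=10.0, seed=1)
+ ```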
+ """
+ with ops.op_scope([shape, minval, maxval], name, "random_uniform") as name:
+ shape_tensor = _ShapeTensor(shape)
+ min_tensor = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ range_tensor = ops.convert_to_tensor(
+ maxval - minval, dtype=dtype, name="range")
+ seed1, seed2 = random_seed.get_seed(seed)
+ rnd = gen_random_ops._random_uniform(shape_tensor, dtype,
+ seed=seed1,
+ seed2=seed2)
+ mul = rnd * range_tensor
+ value = math_ops.add(mul, min_tensor, name=name)
+ return value
+
+
+def random_shuffle(value, seed=None, name=None):
+ """Randomly shuffles a tensor along its first dimension.
+
+ The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
+ to one and only one `output[i]`. For example, a mapping that might occur for a
+ 3x2 tensor is:
+
+ ```python
+ [[1, 2], [[5, 6],
+ [3, 4], ==> [1, 2],
+ [5, 6]] [3, 4]]
+ ```
+
+ Args:
+ value: A Tensor to be shuffled.
+ seed: A Python integer. Used to create a random seed for the distribution.
+ See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of same shape and type as `value`, shuffled along its first
+ dimension.
+ """
+ seed1, seed2 = random_seed.get_seed(seed)
+ return gen_random_ops._random_shuffle(value, seed=seed1, seed2=seed2,
+ name=name)
+
+
+ops.NoGradient("RandomUniform")
+
+
+@ops.RegisterShape("TruncatedNormal")
+@ops.RegisterShape("RandomStandardNormal")
+@ops.RegisterShape("RandomUniform")
+def _RandomShape(op):
+ shape_val = tensor_util.ConstantValue(op.inputs[0])
+ if shape_val is not None:
+ return [tensor_shape.TensorShape(shape_val.tolist())]
+ else:
+ shape_shape = op.inputs[0].get_shape().with_rank_at_most(1)
+ return [tensor_shape.unknown_shape(ndims=shape_shape.num_elements())]
+
+
+ops.RegisterShape("RandomShuffle")(common_shapes.unchanged_shape)
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
new file mode 100644
index 0000000000..3685b671b7
--- /dev/null
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -0,0 +1,12 @@
+"""Gradients for operators defined in sparse_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import sparse_ops
+
+
+ops.NoGradient("SparseToDense")
+
+
+ops.NoGradient("SparseConcat")
+
+
+ops.NoGradient("SparseReorder")
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
new file mode 100644
index 0000000000..c0dca6156d
--- /dev/null
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -0,0 +1,458 @@
+"""## Sparse Tensor Representation.
+
+TensorFlow supports a `SparseTensor` representation for data that is sparse
+in multiple dimensions. Contrast this representation with `IndexedSlices`,
+which is efficient for representing tensors that are sparse in their first
+dimension, and dense along all other dimensions.
+
+@@SparseTensor
+@@SparseTensorValue
+
+## Sparse to Dense Conversion.
+
+@@sparse_to_dense
+@@sparse_tensor_to_dense
+@@sparse_to_indicator
+
+## Manipulation.
+
+@@sparse_concat
+@@sparse_reorder
+@@sparse_retain
+@@sparse_fill_empty_rows
+"""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_sparse_ops import *
+# pylint: enable=wildcard-import
+# pylint: disable=protected-access
+
+
+def sparse_concat(concat_dim, sp_inputs, name=None):
+ """Concatenates a list of `SparseTensor` along the specified dimension.
+
+ Concatenation is with respect to the dense versions of each sparse input.
+ It is assumed that each input is a `SparseTensor` whose elements are ordered
+ along increasing dimension number.
+
+ All inputs' shapes must match, except for the concat dimension. The
+ `indices`, `values`, and `shapes` lists must have the same length.
+
+ The output shape is identical to the inputs', except along the concat
+ dimension, where it is the sum of the inputs' sizes along that dimension.
+
+ The output elements will be re-sorted to preserve the sort order along
+ increasing dimension number.
+
+ This op runs in `O(M log M)` time, where `M` is the total number of non-empty
+ values across all inputs. This is due to the need for an internal sort in
+ order to concatenate efficiently across an arbitrary dimension.
+
+ For example, if `concat_dim = 1` and the inputs are
+
+ sp_inputs[0]: shape = [2, 3]
+ [0, 2]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ sp_inputs[1]: shape = [2, 4]
+ [0, 1]: "d"
+ [0, 2]: "e"
+
+ then the output will be
+
+ shape = [2, 7]
+ [0, 2]: "a"
+ [0, 4]: "d"
+ [0, 5]: "e"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ Graphically this is equivalent to doing
+
+ [ a] concat [ d e ] = [ a d e ]
+ [b c ] [ ] [b c ]
+
+ Args:
+ concat_dim: Dimension to concatenate along.
+ sp_inputs: List of `SparseTensor` to concatenate.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A `SparseTensor` with the concatenated output.
+
+ Raises:
+ TypeError: If `sp_inputs` is not a list of `SparseTensor`.
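+
+ Example call: a sketch of producing the concatenation pictured above,
+ assuming `sp_a` and `sp_b` are the two `SparseTensor` inputs:
+
+ sp_out = sparse_concat(1, [sp_a, sp_b])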
+ """
+ if not isinstance(sp_inputs, list):
+ raise TypeError("Inputs must be a list")
+ if not all(isinstance(sp_input, ops.SparseTensor) for sp_input in sp_inputs):
+ raise TypeError("All inputs must be SparseTensors")
+
+ if len(sp_inputs) == 1: # Degenerate case of one tensor.
+ return sp_inputs[0]
+
+ inds = [sp_input.indices for sp_input in sp_inputs]
+ vals = [sp_input.values for sp_input in sp_inputs]
+ shapes = [sp_input.shape for sp_input in sp_inputs]
+
+ output_ind, output_val, output_shape = (
+ gen_sparse_ops._sparse_concat(
+ inds,
+ vals,
+ shapes,
+ concat_dim,
+ name=name))
+
+ return ops.SparseTensor(output_ind, output_val, output_shape)
+
+
+@ops.RegisterShape("SparseConcat")
+def _SparseConcatShape(op):
+ """Shape function for SparseConcat op."""
+ num_inputs = int(op.get_attr("N"))
+
+ # TF flattens and concatenates all list inputs, so reconstruct the lists here.
+ ind_shapes = [ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]]
+ val_shapes = [val.get_shape().with_rank(1)
+ for val in op.inputs[num_inputs:2 * num_inputs]]
+ shape_shapes = [shape.get_shape().with_rank(1)
+ for shape in op.inputs[2 * num_inputs:]]
+
+ output_ind_rows = tensor_shape.Dimension(0)
+ output_ind_cols = tensor_shape.Dimension(None)
+ output_val_elems = tensor_shape.Dimension(0)
+ output_shape_shape = tensor_shape.TensorShape(None)
+
+ for i in range(num_inputs):
+ num_elems_i = ind_shapes[i][0].merge_with(val_shapes[i][0])
+ output_ind_rows += num_elems_i
+ output_ind_cols = output_ind_cols.merge_with(ind_shapes[i][1])
+ output_val_elems += num_elems_i
+ output_shape_shape = output_shape_shape.merge_with(shape_shapes[i])
+
+ output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols)
+ output_val_shape = tensor_shape.vector(output_val_elems)
+
+ return [output_ind_shape, output_val_shape, output_shape_shape]
+
+
+def sparse_reorder(sp_input, name=None):
+ """Reorders a `SparseTensor` into the canonical, row-major ordering.
+
+ Note that by convention, all sparse ops preserve the canonical ordering
+ along increasing dimension number. The only time ordering can be violated
+ is during manual manipulation of the indices and values to add entries.
+
+ Reordering does not affect the shape of the `SparseTensor`.
+
+ For example, if sp_input has shape `[4, 5]` and `indices` / `values`:
+
+ [0, 3]: b
+ [0, 1]: a
+ [3, 1]: d
+ [2, 0]: c
+
+ then the output will be a `SparseTensor` of shape `[4, 5]` and
+ `indices` / `values`:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ Args:
+ sp_input: The input `SparseTensor`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ A `SparseTensor` with the same shape and non-empty values, but in
+ canonical ordering.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ reordered_ind, reordered_val = (
+ gen_sparse_ops._sparse_reorder(
+ sp_input.indices,
+ sp_input.values,
+ sp_input.shape,
+ name=name))
+
+ return ops.SparseTensor(
+ reordered_ind, reordered_val, array_ops.identity(sp_input.shape))
+
+
+@ops.RegisterShape("SparseReorder")
+def _SparseReorderShape(op):
+ """Shape function for SparseReorder op."""
+ input_indices_shape = op.inputs[0].get_shape().with_rank(2)
+ input_values_shape = op.inputs[1].get_shape().with_rank(1)
+ unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
+
+ return [input_indices_shape, input_values_shape]
+
+
+@ops.RegisterShape("SparseToDense")
+def _SparseToDenseShape(op):
+ input_shape = tensor_util.ConstantValue(op.inputs[1])
+ if input_shape is not None:
+ if np.ndim(input_shape) > 1:
+ raise ValueError("Input shape should be a vector")
+ return [tensor_shape.TensorShape(input_shape.tolist())]
+ else:
+ input_shape_shape = op.inputs[1].get_shape().with_rank_at_most(1)
+ return [tensor_shape.unknown_shape(ndims=input_shape_shape.num_elements())]
+
+
+def sparse_tensor_to_dense(sp_input, default_value, name=None):
+ """Converts a `SparseTensor` into a dense tensor.
+
+ This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s.
+
+ For example, if `sp_input` has shape `[3, 5]` and non-empty string values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+
+ and `default_value` is `x`, then the output will be a dense `[3, 5]`
+ string tensor with values:
+
+ [[x a x b x]
+ [x x x x x]
+ [c x x x x]]
+
+ Args:
+ sp_input: The input `SparseTensor`.
+ default_value: Scalar value to set for indices not specified in
+ `sp_input`.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A dense tensor with shape `sp_input.shape` and values specified by
+ the non-empty values in `sp_input`. Indices not in `sp_input` are assigned
+ `default_value`.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ return gen_sparse_ops.sparse_to_dense(
+ sp_input.indices,
+ sp_input.shape,
+ sp_input.values,
+ default_value,
+ name=name)
+
+
+def sparse_to_indicator(sp_input, vocab_size, name=None):
+ """Converts a `SparseTensor` of ids into a dense bool indicator tensor.
+
+ The last dimension of `sp_input` is discarded and replaced with the values of
+ `sp_input`. If `sp_input.shape = [D0, D1, ..., Dn, K]`, then
+ `output.shape = [D0, D1, ..., Dn, vocab_size]`, where
+
+ output[d_0, d_1, ..., d_n, sp_input[d_0, d_1, ..., d_n, k]] = True
+
+ and False elsewhere in `output`.
+
+ For example, if `sp_input.shape = [2, 3, 4]` with non-empty values:
+
+ [0, 0, 0]: 0
+ [0, 1, 0]: 10
+ [1, 0, 3]: 103
+ [1, 1, 2]: 112
+ [1, 1, 3]: 113
+ [1, 2, 1]: 121
+
+ and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool
+ tensor with False everywhere except at positions
+
+ (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121).
+
+ This op is useful for converting `SparseTensor`s into dense formats for
+ compatibility with ops that expect dense tensors.
+
+ The input `SparseTensor` must be in row-major order.
+
+ Args:
+ sp_input: A `SparseTensor` of type `int32` or `int64`.
+ vocab_size: The new size of the last dimension, with
+ `all(0 <= sp_input.values < vocab_size)`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ A dense bool indicator tensor representing the indices with specified value.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ with ops.op_scope([sp_input], name, "SparseToIndicator") as name:
+ indices_shape = array_ops.shape(sp_input.indices)
+ num_entries = indices_shape[0]
+ rank = indices_shape[1]
+
+ ids = sp_input.values
+ if ids.dtype != types.int64:
+ ids = math_ops.cast(ids, types.int64)
+
+ # Slice off the last dimension of indices, then tack on the ids.
+ indices_columns_to_preserve = array_ops.slice(
+ sp_input.indices, [0, 0], array_ops.pack([-1, rank - 1]))
+ new_indices = array_ops.concat(
+ 1, [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])])
+
+ new_values = array_ops.fill(array_ops.expand_dims(num_entries, 0), True)
+ new_shape = array_ops.concat(
+ 0, [array_ops.slice(sp_input.shape, [0],
+ array_ops.expand_dims(rank - 1, 0)), [vocab_size]])
+
+ sp_new = ops.SparseTensor(new_indices, new_values, new_shape)
+
+ return sparse_tensor_to_dense(sp_new, False, name=name)
+
+
+def sparse_retain(sp_input, to_retain):
+ """Retains specified non-empty values within a `SparseTensor`.
+
+ For example, if `sp_input` has shape `[4, 5]` and 4 non-empty string values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ and `to_retain = [True, False, False, True]`, then the output will
+ be a `SparseTensor` of shape `[4, 5]` with 2 non-empty values:
+
+ [0, 1]: a
+ [3, 1]: d
+
+ Args:
+ sp_input: The input `SparseTensor` with `N` non-empty elements.
+ to_retain: A bool vector of length `N` with `M` true values.
+
+ Returns:
+ A `SparseTensor` with the same shape as the input and `M` non-empty
+ elements corresponding to the true positions in `to_retain`.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ to_retain = ops.convert_to_tensor(to_retain)
+
+ # Shape checking, if shape is known at graph construction time
+ retain_shape = to_retain.get_shape()
+ retain_shape.assert_has_rank(1)
+ sp_input.values.get_shape()[0].merge_with(retain_shape[0])
+
+ where_true = array_ops.reshape(array_ops.where(to_retain), [-1])
+ new_indices = array_ops.gather(sp_input.indices, where_true)
+ new_values = array_ops.gather(sp_input.values, where_true)
+ return ops.SparseTensor(
+ new_indices, new_values, array_ops.identity(sp_input.shape))
+
+
+def sparse_fill_empty_rows(sp_input, default_value, name=None):
+ """Fills empty rows in the input 2-D `SparseTensor` with a default value.
+
+ This op adds entries with the specified `default_value` at index
+ `[row, 0]` for any row in the input that does not already have a value.
+
+ For example, suppose `sp_input` has shape `[5, 6]` and non-empty values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [2, 0]: c
+ [3, 1]: d
+
+ Rows 1 and 4 are empty, so the output will be of shape `[5, 6]` with values:
+
+ [0, 1]: a
+ [0, 3]: b
+ [1, 0]: default_value
+ [2, 0]: c
+ [3, 1]: d
+ [4, 0]: default_value
+
+ Note that the input may have empty columns at the end, with no effect on
+ this op.
+
+ The output `SparseTensor` will be in row-major order and will have the
+ same shape as the input.
+
+ This op also returns an indicator vector such that
+
+ empty_row_indicator[i] = True iff row i was an empty row.
+
+ Args:
+ sp_input: A `SparseTensor` with shape `[N, M]`.
+ default_value: The value to fill for empty rows, with the same type as
+ `sp_input`.
+ name: A name prefix for the returned tensors (optional)
+
+ Returns:
+ sp_ordered_output: A `SparseTensor` with shape `[N, M]`, and with all empty
+ rows filled in with `default_value`.
+ empty_row_indicator: A bool vector of length `N` indicating whether each
+ input row was empty.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ if not isinstance(sp_input, ops.SparseTensor):
+ raise TypeError("Input must be a SparseTensor")
+
+ with ops.op_scope([sp_input], name, "SparseFillEmptyRows"):
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=sp_input.values.dtype)
+
+ num_rows = math_ops.cast(sp_input.shape[0], types.int32)
+ all_row_indices = math_ops.cast(
+ math_ops.range(0, num_rows, 1), types.int64)
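+ # Rows present in the input are sp_input.indices[:, 0]; the set difference
+ # against all row indices yields exactly the empty rows that need filling.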
+ empty_row_indices, _ = array_ops.list_diff(
+ all_row_indices, sp_input.indices[:, 0])
+ empty_row_indicator = gen_sparse_ops.sparse_to_dense(
+ empty_row_indices, array_ops.expand_dims(sp_input.shape[0], -1), True,
+ False)
+
+ empty_row_indices_as_column = array_ops.reshape(empty_row_indices, [-1, 1])
+ additional_indices = array_ops.concat(
+ 1,
+ [empty_row_indices_as_column,
+ array_ops.zeros_like(empty_row_indices_as_column)])
+ additional_values = array_ops.fill(array_ops.shape(empty_row_indices),
+ default_value)
+
+ all_indices_unordered = array_ops.concat(
+ 0, [sp_input.indices, additional_indices])
+ all_values_unordered = array_ops.concat(
+ 0, [sp_input.values, additional_values])
+ sp_unordered_output = ops.SparseTensor(
+ all_indices_unordered, all_values_unordered, sp_input.shape)
+ sp_ordered_output = sparse_reorder(sp_unordered_output)
+
+ return sp_ordered_output, empty_row_indicator
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
new file mode 100644
index 0000000000..07a5e6c6da
--- /dev/null
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -0,0 +1,212 @@
+"""Tests for Python ops defined in sparse_ops."""
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import googletest
+
+
+class SparseToIndicatorTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self, dtype):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_2x3x4(self, dtype):
+ ind = np.array([
+ [0, 0, 1],
+ [0, 1, 0], [0, 1, 2],
+ [1, 0, 3],
+ [1, 1, 1], [1, 1, 3],
+ [1, 2, 2]])
+ val = np.array([1, 10, 12, 103, 111, 113, 122])
+ shape = np.array([2, 3, 4])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, dtype),
+ constant_op.constant(shape, types.int64))
+
+ def testInt32(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(types.int32)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = ((0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33))
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testInt64(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6(types.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 50).eval()
+
+ expected_output = np.zeros((5, 50), dtype=np.bool)
+ expected_trues = [(0, 0), (1, 10), (1, 13), (1, 14), (3, 32), (3, 33)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+ def testHigherRank(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_2x3x4(types.int64)
+ output = sparse_ops.sparse_to_indicator(sp_input, 200).eval()
+
+ expected_output = np.zeros((2, 3, 200), dtype=np.bool)
+ expected_trues = [(0, 0, 1), (0, 1, 10), (0, 1, 12),
+ (1, 0, 103), (1, 1, 111), (1, 1, 113), (1, 2, 122)]
+ for expected_true in expected_trues:
+ expected_output[expected_true] = True
+
+ self.assertAllEqual(output, expected_output)
+
+
+class SparseRetainTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def testBasic(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 1, 0], dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 4], [3, 2]])
+ self.assertAllEqual(output.values, [0, 14, 32])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testRetainNone(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.zeros((6,), dtype=np.bool)
+ sp_output = sparse_ops.sparse_retain(sp_input, to_retain)
+
+ output = sess.run(sp_output)
+
+ self.assertAllEqual(output.indices, np.array([]).reshape((0, 2)))
+ self.assertAllEqual(output.values, [])
+ self.assertAllEqual(output.shape, [5, 6])
+
+ def testMismatchedRetainShape(self):
+ with self.test_session(use_gpu=False):
+ sp_input = self._SparseTensor_5x6()
+ to_retain = np.array([1, 0, 0, 1, 0], dtype=np.bool)
+ with self.assertRaises(ValueError):
+ sparse_ops.sparse_retain(sp_input, to_retain)
+
+
+class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
+
+ def _SparseTensor_5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array([0, 10, 13, 14, 32, 33])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_String5x6(self):
+ ind = np.array([
+ [0, 0],
+ [1, 0], [1, 3], [1, 4],
+ [3, 2], [3, 3]])
+ val = np.array(["a", "b", "c", "d", "e", "f"])
+ shape = np.array([5, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.string),
+ constant_op.constant(shape, types.int64))
+
+ def _SparseTensor_2x6(self):
+ ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]])
+ val = np.array([0, 10, 13, 14])
+ shape = np.array([2, 6])
+ return ops.SparseTensor(
+ constant_op.constant(ind, types.int64),
+ constant_op.constant(val, types.int32),
+ constant_op.constant(shape, types.int64))
+
+ def testFillNumber(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14, -1, 32, 33, -1])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testFillString(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_String5x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, ""))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(
+ output.indices,
+ [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]])
+ self.assertAllEqual(output.values, ["a", "b", "c", "d", "", "e", "f", ""])
+ self.assertAllEqual(output.shape, [5, 6])
+ self.assertAllEqual(empty_row_indicator_out,
+ np.array([0, 0, 1, 0, 1]).astype(np.bool))
+
+ def testNoEmptyRows(self):
+ with self.test_session(use_gpu=False) as sess:
+ sp_input = self._SparseTensor_2x6()
+ sp_output, empty_row_indicator = (
+ sparse_ops.sparse_fill_empty_rows(sp_input, -1))
+
+ output, empty_row_indicator_out = sess.run(
+ [sp_output, empty_row_indicator])
+
+ self.assertAllEqual(output.indices, [[0, 0], [1, 0], [1, 3], [1, 4]])
+ self.assertAllEqual(output.values, [0, 10, 13, 14])
+ self.assertAllEqual(output.shape, [2, 6])
+ self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
new file mode 100644
index 0000000000..beef8e75b5
--- /dev/null
+++ b/tensorflow/python/ops/standard_ops.py
@@ -0,0 +1,41 @@
+# pylint: disable=wildcard-import,unused-import
+"""Import names of Tensor Flow standard Ops."""
+
+# Imports the following modules so that @RegisterGradient get executed.
+from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import state_grad
+
+from tensorflow.python.ops.array_ops import *
+from tensorflow.python.ops.clip_ops import *
+# TODO(vrv): Switch to import * once we're okay with exposing the module.
+from tensorflow.python.ops.control_flow_ops import group
+from tensorflow.python.ops.control_flow_ops import no_op
+from tensorflow.python.ops.control_flow_ops import tuple
+from tensorflow.python.ops.data_flow_ops import *
+from tensorflow.python.ops.gradients import *
+from tensorflow.python.ops.init_ops import *
+from tensorflow.python.ops.io_ops import *
+from tensorflow.python.ops.linalg_ops import *
+from tensorflow.python.ops.logging_ops import *
+from tensorflow.python.ops.math_ops import *
+from tensorflow.python.ops.numerics import *
+from tensorflow.python.ops.parsing_ops import *
+from tensorflow.python.ops.random_ops import *
+from tensorflow.python.ops.sparse_ops import *
+from tensorflow.python.ops.state_ops import assign
+from tensorflow.python.ops.state_ops import assign_add
+from tensorflow.python.ops.state_ops import assign_sub
+from tensorflow.python.ops.state_ops import count_up_to
+from tensorflow.python.ops.state_ops import scatter_add
+from tensorflow.python.ops.state_ops import scatter_sub
+from tensorflow.python.ops.state_ops import scatter_update
+from tensorflow.python.ops.string_ops import *
+from tensorflow.python.ops.summary_ops import histogram_summary
+from tensorflow.python.ops.summary_ops import image_summary
+from tensorflow.python.ops.summary_ops import merge_all_summaries
+from tensorflow.python.ops.summary_ops import merge_summary
+from tensorflow.python.ops.summary_ops import scalar_summary
+from tensorflow.python.ops.variable_scope import *
+from tensorflow.python.ops.variables import *
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
new file mode 100644
index 0000000000..d9b084693c
--- /dev/null
+++ b/tensorflow/python/ops/state_grad.py
@@ -0,0 +1,18 @@
+"""Gradients for operators defined in state_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import state_ops
+
+ops.NoGradient("Assign")
+
+
+ops.NoGradient("AssignAdd")
+
+
+ops.NoGradient("AssignSub")
+
+
+ops.NoGradient("ScatterAdd")
+
+
+ops.NoGradient("ScatterSub")
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
new file mode 100644
index 0000000000..1c8f38b94c
--- /dev/null
+++ b/tensorflow/python/ops/state_ops.py
@@ -0,0 +1,189 @@
+"""## Variables
+
+@@Variable
+
+## Variable helper functions
+
+TensorFlow provides a set of functions to help manage the set of variables
+collected in the graph.
+
+@@all_variables
+@@trainable_variables
+
+@@initialize_all_variables
+@@initialize_variables
+@@assert_variables_initialized
+
+## Saving and Restoring Variables.
+
+@@Saver
+
+@@latest_checkpoint
+
+@@get_checkpoint_state
+@@update_checkpoint_state
+
+## Sharing Variables
+
+TensorFlow provides several classes and operations that you can use to
+create variables contingent on certain conditions.
+
+@@get_variable
+@@get_variable_scope
+@@variable_scope
+
+@@constant_initializer
+@@random_normal_initializer
+@@truncated_normal_initializer
+@@random_uniform_initializer
+@@uniform_unit_scaling_initializer
+@@zeros_initializer
+
+## Sparse Variable Updates
+
+The sparse update ops modify a subset of the entries in a dense `Variable`,
+either overwriting the entries or adding / subtracting a delta. These are
+useful for training embedding models and similar lookup-based networks, since
+only a small subset of embedding vectors change in any given step.
+
+Since a sparse update of a large tensor may be generated automatically during
+gradient computation (as in the gradient of [`tf.gather`](array_ops.md#gather)),
+an [`IndexedSlices`](#IndexedSlices) class is provided that encapsulates a set
+of sparse indices and values. `IndexedSlices` objects are detected and handled
+automatically by the optimizers in most cases.
+
+@@scatter_update
+@@scatter_add
+@@scatter_sub
+@@sparse_mask
+@@IndexedSlices
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_state_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_state_ops import *
+
+
+# pylint: disable=protected-access
+def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
+ shared_name=""):
+ """Create a variable Operation.
+
+ See also variables.Variable.
+
+ Args:
+ shape: The shape of the tensor managed by this variable
+ dtype: The underlying type of the tensor values.
+ name: optional name to use for the variable op.
+ set_shape: If True, set the shape property of the returned Tensor to
+ the shape argument.
+ container: An optional string. Defaults to "".
+ If non-empty, this variable is placed in the given container.
+ Otherwise, a default container is used.
+ shared_name: An optional string. Defaults to "".
+ If non-empty, this variable is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+
+ Returns:
+ A variable tensor.
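+
+ Example (an illustrative sketch; the shape, dtype, and name are arbitrary):
+
+ v = variable_op([10, 20], tf.float32, name="weights")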
+ """
+ ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
+ container=container, shared_name=shared_name)
+ # TODO(mrry): Move this to where it is used, so we can get rid of this op
+ # wrapper?
+ if set_shape:
+ ret.set_shape(shape)
+ return ret
+
+
+# NOTE(mrry): Shapes are conditionally set in the Python wrapper.
+ops.RegisterShape("Variable")(common_shapes.unknown_shape)
+
+
+@ops.RegisterShape("TemporaryVariable")
+def _TemporaryVariableShape(op):
+ """Shape function for the TemporaryVariable op."""
+ shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
+ return [tensor_shape.TensorShape(shape)]
+
+
+@ops.RegisterShape("DestroyTemporaryVariable")
+def _DestroyTemporaryVariableShape(op):
+ """Shape function for the DestroyTemporaryVariable op."""
+ return [op.inputs[0].get_shape()]
+
+
+def init_variable(v, init, name="init"):
+ """Initializes variable with "init".
+
+ This op does the following:
+ if init is a Tensor, v = init
+ if callable(init): v = init(VariableShape(v), v.dtype)
+
+ Args:
+ v: Variable to initialize
+ init: Tensor to assign to v,
+ or an object convertible to a Tensor (e.g. a numpy ndarray),
+ or an Initializer that generates a tensor given the shape and type of v.
+ An "Initializer" is a callable that returns a tensor that "v" should be
+ set to. It will be called as init(shape, dtype).
+ name: Optional name for the op.
+
+ Returns:
+ The operation that initializes v.
+ """
+ with ops.op_scope([v, init], None, v.op.name + "/"):
+ with ops.name_scope(name) as scope:
+ with ops.device(v.device or ops.get_default_graph().get_default_device()):
+ if callable(init):
+ assert v.get_shape().is_fully_defined(), "Variable shape unknown."
+ # TODO(mrry): Convert to v.shape when the property and
+ # accessor are reconciled (and all initializers support
+ # tf.TensorShape objects).
+ value = init(v.get_shape().as_list(), v.dtype.base_dtype)
+ value = ops.convert_to_tensor(value, name="value")
+ return assign(v, value, name=scope)
+ else:
+ init = ops.convert_to_tensor(init, name="init")
+ return assign(v, init, name=scope)
+
+
+@ops.RegisterShape("Assign")
+def _AssignShape(op):
+ """Shape function for the Assign op."""
+ if op.get_attr("validate_shape"):
+ # NOTE(mrry): Return a known shape here. This makes it awkward to
+ # chain a validated-shape assignment and a reshaping assignment,
+ # but that is a sufficiently niche case that supporting it does
+ # not seem worthwhile.
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+ return [op.inputs[1].get_shape()]
+
+
+@ops.RegisterShape("AssignAdd")
+@ops.RegisterShape("AssignSub")
+def _AssignUpdateShape(op):
+ """Shape function for the AssignAdd and AssignSub dense update ops."""
+ return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
+
+
+@ops.RegisterShape("CountUpTo")
+def _CountUpToShape(op):
+ """Shape function for the CountUpTo op."""
+ return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
+
+
+@ops.RegisterShape("ScatterAdd")
+@ops.RegisterShape("ScatterSub")
+@ops.RegisterShape("ScatterUpdate")
+def _ScatterUpdateShape(op):
+ """Shape function for the sparse update ops."""
+ var_shape = op.inputs[0].get_shape()
+ indices_shape = op.inputs[1].get_shape()
+ unused_updates_shape = op.inputs[2].get_shape().merge_with(
+ indices_shape.concatenate(var_shape[1:]))
+ return [var_shape]
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
new file mode 100644
index 0000000000..8181fe9a2a
--- /dev/null
+++ b/tensorflow/python/ops/string_ops.py
@@ -0,0 +1,12 @@
+"""String Ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_string_ops
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.python.ops.gen_string_ops import *
+
+ops.NoGradient("StringToHashBucket")
+
+ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
new file mode 100644
index 0000000000..d65fd1ea7c
--- /dev/null
+++ b/tensorflow/python/ops/summary_ops.py
@@ -0,0 +1,177 @@
+"""Summary Operations."""
+# pylint: disable=wildcard-import,protected-access
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_summary_ops
+from tensorflow.python.ops.gen_summary_ops import *
+
+
+def _Collect(val, collections, default_collections):
+ if collections is None:
+ collections = default_collections
+ for key in collections:
+ ops.add_to_collection(key, val)
+
+
+def histogram_summary(tag, values, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with a histogram.
+
+ The generated
+ [`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+ has one summary value containing a histogram for `values`.
+
+ This op reports an `OutOfRange` error if any value is not finite.
+
+ Args:
+ tag: A `string` `Tensor`. 0-D. Tag to use for the summary value.
+ values: A `float32` `Tensor`. Any shape. Values to use to build the
+ histogram.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
+ """
+ with ops.op_scope([tag, values], name, "HistogramSummary") as scope:
+ val = gen_summary_ops._histogram_summary(
+ tag=tag, values=values, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+def image_summary(tag, tensor, max_images=None, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with images.
+
+ The summary has up to `max_images` summary values containing images. The
+ images are built from `tensor` which must be 4-D with shape `[batch_size,
+ height, width, channels]` and where `channels` can be:
+
+ * 1: `tensor` is interpreted as Grayscale.
+ * 3: `tensor` is interpreted as RGB.
+ * 4: `tensor` is interpreted as RGBA.
+
+ The images have the same number of channels as the input tensor. Their values
+ are normalized, one image at a time, to fit in the range `[0, 255]`. The
+ op uses two different normalization algorithms:
+
+ * If the input values are all positive, they are rescaled so the largest one
+ is 255.
+
+ * If any input value is negative, the values are shifted so input value 0.0
+ is at 127. They are then rescaled so that either the smallest value is 0,
+ or the largest one is 255.
+
+ The `tag` argument is a scalar `Tensor` of type `string`. It is used to
+ build the `tag` of the summary values:
+
+ * If `max_images` is 1, the summary value tag is '*tag*/image'.
+ * If `max_images` is greater than 1, the summary value tags are
+ generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
+
+ Args:
+ tag: A scalar `Tensor` of type `string`. Used to build the `tag`
+ of the summary values.
+ tensor: A 4-D `float32` `Tensor` of shape `[batch_size, height, width,
+ channels]` where `channels` is 1, 3, or 4.
+ max_images: Max number of batch elements to generate images for.
+ collections: Optional list of ops.GraphKeys. The collections to add the
+ summary to. Defaults to [ops.GraphKeys.SUMMARIES]
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
+ """
+ with ops.op_scope([tag, tensor], name, "ImageSummary") as scope:
+ val = gen_summary_ops._image_summary(
+ tag=tag, tensor=tensor, max_images=max_images, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+def merge_summary(inputs, collections=None, name=None):
+ """Merges summaries.
+
+ This op creates a
+ [`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+ protocol buffer that contains the union of all the values in the input
+ summaries.
+
+ When the Op is run, it reports an `InvalidArgument` error if multiple values
+ in the summaries to merge use the same tag.
+
+ Args:
+ inputs: A list of `string` `Tensor` objects containing serialized `Summary`
+ protocol buffers.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer resulting from the merging.
+ """
+ with ops.op_scope(inputs, name, "MergeSummary") as scope:
+ val = gen_summary_ops._merge_summary(inputs=inputs, name=scope)
+ _Collect(val, collections, [])
+ return val
+
+
+def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
+ """Merges all summaries collected in the default graph.
+
+ Args:
+ key: `GraphKey` used to collect the summaries. Defaults to
+ `GraphKeys.SUMMARIES`.
+
+ Returns:
+ If no summaries were collected, returns None. Otherwise returns a scalar
+ `Tensor` of type `string` containing the serialized `Summary` protocol
+ buffer resulting from the merging.
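+
+ A typical usage sketch (illustrative; `loss` is a hypothetical scalar
+ tensor defined elsewhere):
+
+ scalar_summary(["loss"], [loss])
+ merged = merge_all_summaries()  # a single serialized Summary, or None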
+ """
+ summary_ops = ops.get_collection(key)
+ if not summary_ops:
+ return None
+ else:
+ return merge_summary(summary_ops)
+
+
+def scalar_summary(tags, values, collections=None, name=None):
+ """Outputs a `Summary` protocol buffer with scalar values.
+
+ The input `tags` and `values` must have the same shape. The generated
+ summary has a summary value for each tag-value pair in `tags` and `values`.
+
+ Args:
+ tags: A 1-D `string` `Tensor`. Tags for the summaries.
+ values: A 1-D `float32` or `float64` Tensor. Values for the summaries.
+ collections: Optional list of graph collections keys. The new summary op is
+ added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+ buffer.
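+
+ For example (an illustrative sketch; `loss` and `accuracy` are hypothetical
+ scalar tensors defined elsewhere):
+
+ summ = scalar_summary(["loss", "accuracy"], [loss, accuracy])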
+ """
+ with ops.op_scope([tags, values], name, "ScalarSummary") as scope:
+ val = gen_summary_ops._scalar_summary(tags=tags, values=values, name=scope)
+ _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
+ return val
+
+
+ops.NoGradient("HistogramAccumulatorSummary")
+ops.NoGradient("HistogramSummary")
+ops.NoGradient("ImageSummary")
+ops.NoGradient("MergeSummary")
+ops.NoGradient("ScalarSummary")
+
+
+@ops.RegisterShape("HistogramAccumulatorSummary")
+@ops.RegisterShape("HistogramSummary")
+@ops.RegisterShape("ImageSummary")
+@ops.RegisterShape("MergeSummary")
+@ops.RegisterShape("ScalarSummary")
+def _ScalarShape(unused_op):
+ return [tensor_shape.scalar()]
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
new file mode 100644
index 0000000000..c9c2cac0a5
--- /dev/null
+++ b/tensorflow/python/ops/variable_scope.py
@@ -0,0 +1,333 @@
+"""A class to store named variables and a scope operator to manage sharing."""
+
+import contextlib
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import logging
+
+
+class _VariableStore(object):
+ """Variable store that carries a number of named Variables.
+
+ New variable names and new variables can be created; all stored
+ variables are initialized with the initializer passed to __init__.
+
+ Attributes:
+ vars: a dictionary with string names (same as passed to get_variable) as keys
+ and the corresponding TensorFlow Variables as values.
+ """
+
+ def __init__(self):
+ """Create a variable store."""
+ self._vars = {} # A dictionary of the stored TensorFlow variables.
+
+ def get_variable(self, name, shape=None, dtype=types.float32,
+ initializer=None, reuse=None, trainable=True,
+ collections=None):
+ """Gets an existing variable with these parameters or create a new one.
+
+ If a variable with the given name is already stored, we return the stored
+ variable. Otherwise, we create a new one.
+
+ Set `reuse` to `True` when you only want to reuse existing Variables.
+ Set `reuse` to `False` when you only want to create new Variables.
+ If `reuse` is `None` (the default), both new and existing variables are
+ returned.
+
+ If `initializer` is `None` (the default), a new
+ `UniformUnitScalingInitializer` is used.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable.
+ reuse: a Boolean or `None`. Controls reuse or creation of variables.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
+ collections: List of graph collections keys to add the Variable to.
+ Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
+
+ Returns:
+ The created or existing variable.
+
+ Raises:
+ ValueError: when creating a new variable and shape is not declared,
+ when reusing a variable and specifying a conflicting shape,
+ or when violating reuse during variable creation.
+ """
+ should_check = reuse is not None
+ dtype = types.as_dtype(dtype)
+ shape = tensor_shape.as_shape(shape)
+ if name in self._vars:
+ # Here we handle the case when returning an existing variable.
+ if should_check and not reuse:
+ raise ValueError("Over-sharing: Variable %s already exists, disallowed."
+ " Did you mean to set reuse=True in VarScope?" % name)
+ found_var = self._vars[name]
+ if not shape.is_compatible_with(found_var.get_shape()):
+ raise ValueError("Trying to share variable %s, but specified shape %s"
+ " and found shape %s." % (name, str(shape),
+ str(found_var.get_shape())))
+ if not dtype.is_compatible_with(found_var.dtype):
+ dtype_str = dtype.name
+ found_type_str = found_var.dtype.name
+ raise ValueError("Trying to share variable %s, but specified dtype %s"
+ " and found dtype %s." % (name, str(dtype_str),
+ str(found_type_str)))
+ return found_var
+
+ # The code below handles only the case of creating a new variable.
+ if should_check and reuse:
+ raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
+ " Did you mean to set reuse=None in VarScope?" % name)
+ if not shape.is_fully_defined():
+ raise ValueError("Shape of a new variable (%s) must be fully defined, "
+ "but instead was %s." % (name, shape))
+ if initializer is None:
+ initializer = init_ops.uniform_unit_scaling_initializer()
+ with ops.name_scope(name + "/Initializer/"):
+ init_val = initializer(shape.as_list(), dtype=dtype)
+ v = variables.Variable(init_val, name=name, trainable=trainable,
+ collections=collections)
+ self._vars[name] = v
+ logging.info("Created variable %s with shape %s and init %s", v.name,
+ format(shape), str(initializer))
+ return v
+
+
+class _VariableScope(object):
+ """Variable scope object to carry defaults to provide to get_variable.
+
+ Many of the arguments we need for get_variable in a variable store are most
+ easily handled with a context. This object is used for the defaults.
+
+ Attributes:
+ name: name of the current scope, used as prefix in get_variable.
+ initializer: default initializer passed to get_variable.
+ reuse: Boolean or None, setting the reuse in get_variable.
+ """
+
+ def __init__(self, reuse, name="", initializer=None):
+ self._name = name
+ self._initializer = initializer
+ self._reuse = reuse
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def reuse(self):
+ return self._reuse
+
+ @property
+ def initializer(self):
+ return self._initializer
+
+ def reuse_variables(self):
+ """Reuse variables in this scope."""
+ self._reuse = True
+
+ def set_initializer(self, initializer):
+ """Set initializer for this scope."""
+ self._initializer = initializer
+
+ def get_variable(self, var_store, name, shape=None, dtype=types.float32,
+ initializer=None, trainable=True, collections=None):
+ """Gets an existing variable with this name or create a new one."""
+ if initializer is None and self._initializer:
+ initializer = self._initializer
+ full_name = self.name + "/" + name if self.name else name
+ # Variable names only depend on variable_scope (full_name here),
+ # not name_scope, so we reset it below for the time of variable creation.
+ with ops.name_scope(None):
+ return var_store.get_variable(full_name, shape, dtype, initializer,
+ self.reuse, trainable, collections)
+
+
+_VARSTORE_KEY = ("__variable_store",)
+_VARSCOPE_KEY = ("__varscope",)
+
+
+def get_variable_scope():
+ """Returns the current variable scope."""
+ scope = ops.get_collection(_VARSCOPE_KEY)
+ if scope: # This collection has at most 1 element, the default scope at [0].
+ return scope[0]
+ scope = _VariableScope(False)
+ ops.add_to_collection(_VARSCOPE_KEY, scope)
+ return scope
+
+
+def _get_default_variable_store():
+ store = ops.get_collection(_VARSTORE_KEY)
+ if store:
+ return store[0]
+ store = _VariableStore()
+ ops.add_to_collection(_VARSTORE_KEY, store)
+ return store
+
+
+def get_variable(name, shape=None, dtype=types.float32, initializer=None,
+ trainable=True, collections=None):
+ """Gets an existing variable with these parameters or create a new one.
+
+ This function prefixes the name with the current variable scope
+ and performs reuse checks. See the
+ [Variable Scope How To](../../how_tos/variable_scope/index.md)
+ for an extensive description of how reusing works. Here is a basic example:
+
+ ```python
+ with tf.variable_scope("foo"):
+ v = get_variable("v", [1]) # v.name == "foo/v:0"
+ w = get_variable("w", [1]) # w.name == "foo/w:0"
+ with tf.variable_scope("foo", reuse=True)
+ v1 = get_variable("v") # The same as v above.
+ ```
+
+ If `initializer` is `None` (the default), the default initializer set in
+ the current variable scope is used (see `variable_scope`). If that one is
+ `None` too, a `UniformUnitScalingInitializer` will be used.
+
+ Args:
+ name: the name of the new or existing variable.
+ shape: shape of the new or existing variable.
+ dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
+ initializer: initializer for the variable if one is created.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
+ collections: List of graph collections keys to add the Variable to.
+ Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
+
+ Returns:
+ The created or existing variable.
+
+ Raises:
+ ValueError: when creating a new variable and shape is not declared,
+ or when violating reuse during variable creation. Reuse is set inside
+ `variable_scope`.
+ """
+ return get_variable_scope().get_variable(_get_default_variable_store(), name,
+ shape, dtype, initializer,
+ trainable, collections)
+
+
+@contextlib.contextmanager
+def variable_scope(name_or_scope, reuse=None, initializer=None):
+ """Returns a context for variable scope.
+
+ Variable scope allows you to create new variables and to share already
+ created ones, while providing checks against accidental creation or sharing.
+ For details, see the
+ [Variable Scope How To](../../how_tos/variable_scope/index.md); here we
+ present only a few basic examples.
+
+ Simple example of how to create a new variable:
+
+ ```python
+ with tf.variable_scope("foo"):
+ with tf.variable_scope("bar"):
+ v = tf.get_variable("v", [1])
+ assert v.name == "foo/bar/v:0"
+ ```
+
+ Basic example of sharing a variable:
+
+ ```python
+ with tf.variable_scope("foo"):
+ v = get_variable("v", [1])
+ with tf.variable_scope("foo", reuse=True):
+ v1 = tf.get_variable("v", [1])
+ assert v1 == v
+ ```
+
+ Sharing a variable by capturing a scope and setting reuse:
+
+ ```python
+ with tf.variable_scope("foo") as scope.
+ v = get_variable("v", [1])
+ scope.reuse_variables()
+ v1 = tf.get_variable("v", [1])
+ assert v1 == v
+ ```
+
+ To prevent accidental sharing of variables, we raise an exception when
+ getting an existing variable in a non-reusing scope.
+
+ ```python
+ with tf.variable_scope("foo") as scope.
+ v = get_variable("v", [1])
+ v1 = tf.get_variable("v", [1])
+ # Raises ValueError("... v already exists ...").
+ ```
+
+ Similarly, we raise an exception when trying to get a variable that
+ does not exist in reuse mode.
+
+ ```python
+ with tf.variable_scope("foo", reuse=True):
+ v = get_variable("v", [1])
+ # Raises ValueError("... v does not exists ...").
+ ```
+
+ Note that the `reuse` flag is inherited: if we open a reusing scope,
+ then all its sub-scopes become reusing as well.
+
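+ The following sketch illustrates the inherited flag (it assumes
+ `get_variable_scope` is exported as `tf.get_variable_scope`):
+
+ ```python
+ with tf.variable_scope("root"):
+   # The root scope is not reusing.
+   assert tf.get_variable_scope().reuse == False
+   with tf.variable_scope("foo", reuse=True):
+     # This scope is explicitly reusing.
+     assert tf.get_variable_scope().reuse == True
+     with tf.variable_scope("bar"):
+       # Sub-scopes inherit the reuse flag.
+       assert tf.get_variable_scope().reuse == True
+   # Back in "root": not reusing again.
+   assert tf.get_variable_scope().reuse == False
+ ```
+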
+ Args:
+ name_or_scope: `string` or `VariableScope`: the scope to open.
+ reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
+ well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
+ initializer: default initializer for variables within this scope.
+
+ Yields:
+ A scope that can be captured and reused.
+
+ Raises:
+ ValueError: when trying to reuse variables in a non-reusing scope, or to
+ create variables in a reusing scope, or if `reuse` is neither `None` nor `True`.
+ TypeError: when the types of some arguments are not appropriate.
+ """
+ if not isinstance(name_or_scope, (_VariableScope, basestring)):
+ raise TypeError("VariableScope: name_scope must be a string or "
+ "VariableScope.")
+ if reuse not in [None, True]:
+ raise ValueError("VariableScope reuse parameter must be True or None.")
+ if not reuse and isinstance(name_or_scope, _VariableScope):
+ logging.info("Passing VariableScope to a non-reusing scope, intended?")
+ if reuse and isinstance(name_or_scope, basestring):
+ logging.info("Re-using string-named scope, consider capturing as object.")
+ get_variable_scope() # Ensure that a default exists, then get a pointer.
+ default_varscope = ops.get_collection(_VARSCOPE_KEY)
+ try:
+ old = default_varscope[0]
+ reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.
+ if isinstance(name_or_scope, _VariableScope):
+ # Handler for the case when we jump to a shared scope.
+ # In this case, we leave the current name_scope unchanged.
+ # We create a new VariableScope (default_varscope[0]) that contains
+ # a copy of the provided shared scope, possibly with changed reuse
+ # and initializer, if the user requested this.
+ default_varscope[0] = _VariableScope(reuse, name_or_scope.name,
+ name_or_scope.initializer)
+ if initializer:
+ default_varscope[0].set_initializer(initializer)
+ yield default_varscope[0]
+ else:
+ # Handler for the case when we just prolong current variable scope.
+ # In this case we prolong the current name_scope and create a new
+ # VariableScope with name extended by the provided one, and inherited
+ # reuse and initializer (except if the user provided values to set).
+ with ops.name_scope(name_or_scope):
+ new_name = old.name + "/" + name_or_scope if old.name else name_or_scope
+ default_varscope[0] = _VariableScope(reuse, name=new_name,
+ initializer=old.initializer)
+ if initializer:
+ default_varscope[0].set_initializer(initializer)
+ yield default_varscope[0]
+ finally:
+ default_varscope[0] = old
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
new file mode 100644
index 0000000000..dafd3b8bdc
--- /dev/null
+++ b/tensorflow/python/ops/variables.py
@@ -0,0 +1,569 @@
+"""Variable class."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+
+
+class Variable(object):
+ """See the [Variables How To](../../how_tos/variables/index.md) for a high
+ level overview.
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later, you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + b)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `initialize_all_variables()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize all variables.
+ init_op = tf.initialize_all_variables()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes all variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.VARIABLES`. The convenience function
+ `all_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+ between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+
+ Creating a variable.
+
+ @@__init__
+ @@initialized_value
+
+ Changing a variable value.
+
+ @@assign
+ @@assign_add
+ @@assign_sub
+ @@scatter_sub
+ @@count_up_to
+
+ @@eval
+
+ Properties.
+
+ @@name
+ @@dtype
+ @@get_shape
+ @@device
+ @@initializer
+ @@graph
+ @@op
+ """
+
+ def __init__(self, initial_value, trainable=True, collections=None,
+ validate_shape=True, name=None):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
+ The initial value for the Variable. Must have a shape specified unless
+ `validate_shape` is set to False.
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+
+ Returns:
+ A Variable.
+
+ Raises:
+ ValueError: If the initial value does not have a shape and
+ `validate_shape` is `True`.
+ """
+ if collections is None:
+ collections = [ops.GraphKeys.VARIABLES]
+ if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+ # pylint: disable=g-no-augmented-assignment
+ #
+ # Pylint wants us to write collections += [...TRAINABLE_VARIABLES] which
+ # is not the same (it modifies the list in place.) Here, we only want to
+ # modify the value of the variable, not the list.
+ collections = collections + [ops.GraphKeys.TRAINABLE_VARIABLES]
+ # pylint: enable=g-no-augmented-assignment
+ with ops.op_scope([initial_value], name, "Variable") as name:
+ self._initial_value = ops.convert_to_tensor(initial_value,
+ name="initial_value")
+ if not self._initial_value.get_shape().is_fully_defined():
+ if validate_shape:
+ raise ValueError(
+ "initial_value must have a shape specified: %s"
+ % self._initial_value)
+ self._variable = state_ops.variable_op(
+ [], self._initial_value.dtype.base_dtype, set_shape=False,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value, validate_shape=False).op
+ else:
+ self._variable = state_ops.variable_op(
+ self._initial_value.get_shape(),
+ self._initial_value.dtype.base_dtype,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value).op
+ for key in collections:
+ ops.add_to_collection(key, self)
+ self._save_slice_info = None
+
+ def _as_graph_element(self):
+ """Conversion function for Graph.as_graph_element()."""
+ return self._variable
+
+ def _AsTensor(self):
+ """Conversion function for ops.convert_to_tensor()."""
+ return self._variable
+
+ def eval(self, session=None):
+ """In a session, computes and returns the value of this variable.
+
+ This is not a graph construction method; it does not add ops to the graph.
+
+ This convenience method requires a session where the graph containing this
+ variable has been launched. If no session is passed, the default session is
+ used. See the [Session class](../client.md#Session) for more information on
+ launching a graph and on sessions.
+
+ ```python
+ v = tf.Variable([1, 2])
+ init = tf.initialize_all_variables()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ # Usage passing the session explicitly.
+ print v.eval(sess)
+ # Usage with the default session. The 'with' block
+ # above makes 'sess' the default session.
+ print v.eval()
+ ```
+
+ Args:
+ session: The session to use to evaluate this variable. If
+ none, the default session is used.
+
+ Returns:
+ A numpy `ndarray` with a copy of the value of this variable.
+ """
+ return self._variable.eval(session=session)
+
+ def initialized_value(self):
+ """Returns the value of the initialized variable.
+
+ You should use this instead of the variable itself to initialize another
+ variable with a value that depends on the value of this variable.
+
+ ```python
+ # Initialize 'v' with a random tensor.
+ v = tf.Variable(tf.truncated_normal([10, 40]))
+ # Use `initialized_value` to guarantee that `v` has been
+ # initialized before its value is used to initialize `w`.
+ # The random values are picked only once.
+ w = tf.Variable(v.initialized_value() * 2.0)
+ ```
+
+ Returns:
+ A `Tensor` holding the value of this variable after its initializer
+ has run.
+ """
+ return control_flow_ops.with_dependencies(
+ [self._initializer_op], self._variable)
+
+ def assign(self, value, use_locking=False):
+ """Assigns a new value to the variable.
+
+ This is essentially a shortcut for `assign(self, value)`.
+
+ Args:
+ value: A `Tensor`. The new value for this variable.
+ use_locking: If `True`, use locking during the assignment.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the assignment has completed.
+ """
+ return state_ops.assign(self._variable, value, use_locking=use_locking)
+
+ def assign_add(self, delta, use_locking=False):
+ """Adds a value to this variable.
+
+ This is essentially a shortcut for `assign_add(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to add to this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the addition has completed.
+ """
+ return state_ops.assign_add(self._variable, delta, use_locking=use_locking)
+
+ def assign_sub(self, delta, use_locking=False):
+ """Subtracts a value from this variable.
+
+ This is essentially a shortcut for `assign_sub(self, delta)`.
+
+ Args:
+ delta: A `Tensor`. The value to subtract from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the subtraction has completed.
+ """
+ return state_ops.assign_sub(self._variable, delta, use_locking=use_locking)
+
+ def scatter_sub(self, sparse_delta, use_locking=False):
+ """Subtracts `IndexedSlices` from this variable.
+
+ This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
+ sparse_delta.values)`.
+
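+ For example, a minimal sketch (constructing the `IndexedSlices` by hand is an
+ assumption about typical usage, and `tf.IndexedSlices` is taken to be the
+ exported name of `ops.IndexedSlices`):
+
+ ```python
+ v = tf.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+ delta = tf.IndexedSlices(values=tf.constant([[1.0, 1.0]]),
+                          indices=tf.constant([1]))
+ # When run, subtracts [1.0, 1.0] from row 1 of 'v'.
+ update = v.scatter_sub(delta)
+ ```
+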
+ Args:
+ sparse_delta: `IndexedSlices` to be subtracted from this variable.
+ use_locking: If `True`, use locking during the operation.
+
+ Returns:
+ A `Tensor` that will hold the new value of this variable after
+ the scattered subtraction has completed.
+
+ Raises:
+ ValueError: if `sparse_delta` is not an `IndexedSlices`.
+ """
+ if not isinstance(sparse_delta, ops.IndexedSlices):
+ raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
+ return state_ops.scatter_sub(self._variable,
+ sparse_delta.indices,
+ sparse_delta.values,
+ use_locking=use_locking)
+
+ def count_up_to(self, limit):
+ """Increments this variable until it reaches `limit`.
+
+ When the returned Op is run, it tries to increment the variable by `1`. If
+ incrementing the variable would bring it above `limit`, the Op raises
+ an `OutOfRangeError` exception.
+
+ If no error is raised, the Op outputs the value of the variable before
+ the increment.
+
+ This is essentially a shortcut for `count_up_to(self, limit)`.
+
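+ A minimal sketch (count_up_to requires an integer variable; the names below
+ are arbitrary):
+
+ ```python
+ counter = tf.Variable(0, name="counter")
+ next_count = counter.count_up_to(3)
+
+ with tf.Session() as sess:
+   sess.run(tf.initialize_all_variables())
+   print sess.run(next_count)  # 0
+   print sess.run(next_count)  # 1
+   print sess.run(next_count)  # 2
+   # One more run would raise OutOfRangeError.
+ ```
+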
+ Args:
+ limit: value at which incrementing the variable raises an error.
+
+ Returns:
+ A `Tensor` that will hold the variable value before the increment. If no
+ other Op modifies this variable, the values produced will all be
+ distinct.
+ """
+ return state_ops.count_up_to(self._variable, limit=limit)
+
+ # Conversion to tensor.
+ @staticmethod
+ def _TensorConversionFunction(v, dtype=None, name=None):
+ """Utility function for converting a Variable to a Tensor."""
+ _ = name
+ ret = v._AsTensor() # pylint: disable=protected-access
+ if dtype and not dtype.is_compatible_with(v.dtype):
+ raise ValueError(
+ "Incompatible type conversion requested to type '%s' for variable "
+ "of type '%s'" % (dtype.name, v.dtype.name))
+ return ret
+
+ # Operator overloading.
+ #
+ # To carry over all overloaded operators from ops.Tensor to Variable, we
+ # register the _RunOp() static method as the implementation of all operators.
+ # That function dynamically discovers the overloaded operator in ops.Tensor
+ # and invokes it after converting the Variable to a tensor.
+ @staticmethod
+ def _OverloadAllOperators():
+ """Register overloads for all operators."""
+ for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
+ Variable._OverloadOperator(operator)
+
+ @staticmethod
+ def _OverloadOperator(operator):
+ """Register _RunOp as the implementation of 'operator'.
+
+ Args:
+ operator: string. The operator name.
+ """
+ if operator in ["__invert__", "__neg__", "__abs__"]:
+ setattr(Variable, operator, lambda a: Variable._RunOp(operator, a, None))
+ else:
+ setattr(Variable, operator, lambda a, b: Variable._RunOp(operator, a, b))
+
+ @staticmethod
+ def _RunOp(operator, a, b):
+ """Run the operator 'op' for 'a'.
+
+ Args:
+ operator: string. The operator name.
+ a: A Variable.
+ b: Second argument to the operator. None if unary.
+ Returns:
+ The result of the operator.
+ """
+ # pylint: disable=protected-access
+ if b is not None:
+ return getattr(ops.Tensor, operator)(a._AsTensor(), b)
+ else:
+ return getattr(ops.Tensor, operator)(a._AsTensor())
+ # pylint: enable=protected-access
+
+ @property
+ def name(self):
+ """The name of this variable."""
+ return self._variable.name
+
+ @property
+ def initializer(self):
+ """The initializer operation for this variable."""
+ return self._initializer_op
+
+ @property
+ def device(self):
+ """The device of this variable."""
+ return self._variable.device
+
+ @property
+ def dtype(self):
+ """The `DType` of this variable."""
+ return self._variable.dtype
+
+ @property
+ def op(self):
+ """The `Operation` of this variable."""
+ return self._variable.op
+
+ @property
+ def graph(self):
+ """The `Graph` of this variable."""
+ return self._variable.graph
+
+ def get_shape(self):
+ """The `TensorShape` of this variable.
+
+ Returns:
+ A `TensorShape`.
+ """
+ return self._variable.get_shape()
+
+ # Experimental support for saving variables as slices of a larger variable.
+ class SaveSliceInfo(object):
+ """Information on how to save this Variable as a slice."""
+
+ def __init__(self, name, spec):
+ """Create a SliceInfo.
+
+ Args:
+ name: Name of the larger Tensor that this variable is a slice of.
+ spec: Slice specification for the saver.
+ """
+ self.name = name
+ self.spec = spec
+
+ def _set_save_slice_info(self, save_slice_info):
+ """Sets the slice info for this Variable.
+
+ Args:
+ save_slice_info: A Variable.SaveSliceInfo object.
+ """
+ self._save_slice_info = save_slice_info
+
+
+def all_variables():
+ """Returns all variables collected in the graph.
+
+ The `Variable()` constructor automatically adds new variables to the graph
+ collection `GraphKeys.VARIABLES`. This convenience function returns the
+ contents of that collection.
+
+ Returns:
+ A list of `Variable` objects.
+ """
+ return ops.get_collection(ops.GraphKeys.VARIABLES)
+
+
+def trainable_variables():
+ """Returns all variables created with `trainable=True`.
+
+ When passed `trainable=True`, the `Variable()` constructor automatically
+ adds new variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
+ contents of that collection.
+
+ Returns:
+ A list of Variable objects.
+ """
+ return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+
+
+def initialize_variables(var_list, name="init"):
+ """Returns an Op that initializes a list of variables.
+
+ After you launch the graph in a session, you can run the returned Op to
+ initialize all the variables in `var_list`. This Op runs all the
+ initializers of the variables in `var_list` in parallel.
+
+ Calling `initialize_variables()` is equivalent to passing the list of
+ initializers to `group()`.
+
+ If `var_list` is empty, however, the function still returns an Op that can
+ be run. That Op just has no effect.
+
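+ For example, a small sketch (the variables and names are arbitrary; it
+ assumes the function is exported as `tf.initialize_variables`):
+
+ ```python
+ a = tf.Variable(tf.zeros([2]), name="a")
+ b = tf.Variable(tf.ones([2]), name="b")
+ init_ab = tf.initialize_variables([a, b], name="init_ab")
+
+ with tf.Session() as sess:
+   sess.run(init_ab)  # Initializes 'a' and 'b', but no other variables.
+ ```
+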
+ Args:
+ var_list: List of `Variable` objects to initialize.
+ name: Optional name for the returned operation.
+
+ Returns:
+ An Op that runs the initializers of all the specified variables.
+ """
+ if var_list:
+ return control_flow_ops.group(
+ *[v.initializer for v in var_list], name=name)
+ return control_flow_ops.no_op(name=name)
+
+
+def initialize_all_variables():
+ """Returns an Op that initializes all variables.
+
+ This is just a shortcut for `initialize_variables(all_variables())`
+
+ Returns:
+ An Op that initializes all variables in the graph.
+ """
+ return initialize_variables(all_variables())
+
+
+def assert_variables_initialized(var_list=None):
+ """Returns an Op to check if variables are initialized.
+
+ When run, the returned Op will raise the exception `FailedPreconditionError`
+ if any of the variables has not yet been initialized.
+
+ Note: This function is implemented by trying to fetch the values of the
+ variables. If one of the variables is not initialized, a message may be
+ logged by the C++ runtime. This is expected.
+
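+ A minimal sketch of the intended use (it assumes the function is exported as
+ `tf.assert_variables_initialized`):
+
+ ```python
+ v = tf.Variable(tf.zeros([2]))
+ check = tf.assert_variables_initialized([v])
+
+ with tf.Session() as sess:
+   # Running 'check' before the initializer would raise
+   # FailedPreconditionError; after it, the check succeeds.
+   sess.run(tf.initialize_all_variables())
+   sess.run(check)
+ ```
+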
+ Args:
+ var_list: List of `Variable` objects to check. Defaults to the
+ value of `all_variables()`.
+
+ Returns:
+ An Op, or None if there are no variables.
+ """
+ if var_list is None:
+ var_list = all_variables()
+ # Backwards compatibility for old-style variables. TODO(mdevin): remove.
+ if not var_list:
+ var_list = []
+ for op in ops.get_default_graph().get_operations():
+ if op.type in ["Variable", "AutoReloadVariable"]:
+ var_list.append(op.outputs[0])
+ if not var_list:
+ return None
+ else:
+ ranks = []
+ for var in var_list:
+ with ops.device(var.device):
+ ranks.append(array_ops.rank(var))
+ if len(ranks) == 1:
+ return ranks[0]
+ else:
+ return array_ops.pack(ranks)
+
+
+# pylint: disable=protected-access
+ops.register_tensor_conversion_function(Variable,
+ Variable._TensorConversionFunction)
+Variable._OverloadAllOperators()
+# pylint: enable=protected-access