Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--  tensorflow/core/ops/array_ops.cc               892
-rw-r--r--  tensorflow/core/ops/attention_ops.cc            54
-rw-r--r--  tensorflow/core/ops/candidate_sampling_ops.cc  351
-rw-r--r--  tensorflow/core/ops/control_flow_ops.cc        179
-rw-r--r--  tensorflow/core/ops/data_flow_ops.cc           357
-rw-r--r--  tensorflow/core/ops/image_ops.cc               273
-rw-r--r--  tensorflow/core/ops/io_ops.cc                  332
-rw-r--r--  tensorflow/core/ops/linalg_ops.cc               97
-rw-r--r--  tensorflow/core/ops/logging_ops.cc              43
-rw-r--r--  tensorflow/core/ops/math_ops.cc               1053
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                  543
-rw-r--r--  tensorflow/core/ops/no_op.cc                    10
-rw-r--r--  tensorflow/core/ops/parsing_ops.cc             104
-rw-r--r--  tensorflow/core/ops/random_ops.cc              108
-rw-r--r--  tensorflow/core/ops/sendrecv_ops.cc             99
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc              134
-rw-r--r--  tensorflow/core/ops/state_ops.cc               290
-rw-r--r--  tensorflow/core/ops/string_ops.cc               21
-rw-r--r--  tensorflow/core/ops/summary_ops.cc             115
-rw-r--r--  tensorflow/core/ops/training_ops.cc            199
20 files changed, 5254 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
new file mode 100644
index 0000000000..8c0571b50e
--- /dev/null
+++ b/tensorflow/core/ops/array_ops.cc
@@ -0,0 +1,892 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Pack")
+ .Input("values: N * T")
+ .Output("output: T")
+ .Attr("N: int >= 1")
+ .Attr("T: type")
+ .Doc(R"doc(
+Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
+
+Packs the `N` tensors in `values` into a tensor with rank one higher than each
+tensor in `values` and shape `[N] + values[0].shape`. The output satisfies
+`output[i, ...] = values[i][...]`.
+
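+For example, a sketch with hypothetical rank-1 inputs (the names are
+illustrative, not part of the op):
+
+```prettyprint
+# x = [1, 4], y = [2, 5], z = [3, 6]  # three tensors of shape [2]
+pack([x, y, z]) ==> [[1, 4], [2, 5], [3, 6]]  # shape [3, 2]
+```
+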
+This is the opposite of `unpack`.
+
+values: Must all be of the same shape and type.
+output: The packed tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Unpack")
+ .Input("value: T")
+ .Output("output: num * T")
+ .Attr("num: int >= 0")
+ .Attr("T: type")
+ .Doc(R"doc(
+Unpacks the outer dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
+
+Unpacks `num` tensors from `value` by chipping it along the first dimension.
+The i'th tensor in `output` is the slice `value[i, ...]`. Each tensor in
+`output` has shape `value.shape[1:]`.
+
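+For example, a sketch with a hypothetical `value` of shape `[3, 2]`:
+
+```prettyprint
+# value = [[1, 4], [2, 5], [3, 6]], num = 3
+# output[0] ==> [1, 4]
+# output[1] ==> [2, 5]
+# output[2] ==> [3, 6]   # each output has shape [2]
+```
+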
+This is the opposite of `pack`.
+
+value: 1-D or higher, with first dimension `num`.
+output: The list of tensors unpacked from `value`.
+)doc");
+
+// --------------------------------------------------------------------------
+// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
+// in the N == 1 case to remove the node.
+REGISTER_OP("Concat")
+ .Input("concat_dim: int32")
+ .Input("values: N * T")
+ .Output("output: T")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Doc(R"doc(
+Concatenates tensors along one dimension.
+
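+For example, a sketch with hypothetical 2-D inputs `t1` and `t2`:
+
+```prettyprint
+# t1 = [[1, 2, 3], [4, 5, 6]], t2 = [[7, 8, 9], [10, 11, 12]]
+concat(0, [t1, t2]) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
+concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
+```
+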
+concat_dim: 0-D. The dimension along which to concatenate. Must be in the
+ range [0, rank(values)).
+values: The `N` Tensors to concatenate. Their ranks and types must match,
+ and their sizes must match in all dimensions except `concat_dim`.
+output: A `Tensor` with the concatenation of values stacked along the
+ `concat_dim` dimension. This tensor's shape matches that of `values` except
+ in `concat_dim` where it has the sum of the sizes.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Split")
+ .Input("split_dim: int32")
+ .Input("value: T")
+ .Output("output: num_split * T")
+ .Attr("num_split: int >= 1")
+ .Attr("T: type")
+ .Doc(R"doc(
+Splits a tensor into `num_split` tensors along one dimension.
+
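+For example, a sketch with a hypothetical `value` of shape `[4, 6]`:
+
+```prettyprint
+# split_dim = 1, num_split = 3
+# each of the 3 output tensors has shape [4, 2]
+```
+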
+split_dim: 0-D. The dimension along which to split. Must be in the range
+ `[0, rank(value))`.
+num_split: The number of ways to split. Must evenly divide
+ `value.shape[split_dim]`.
+value: The tensor to split.
+output: Identically shaped tensors, whose shape matches that of `value`
+  except along `split_dim`, where their size is
+  `value.shape[split_dim] / num_split`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Const")
+ .Output("output: dtype")
+ .Attr("value: tensor")
+ .Attr("dtype: type")
+ .Doc(R"doc(
+Returns a constant tensor.
+
+value: Attr `value` is the tensor to return.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ZerosLike")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns a tensor of zeros with the same shape and type as x.
+
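+For example, a sketch with a hypothetical input:
+
+```prettyprint
+# 'x' is [[1, 2, 3], [4, 5, 6]]
+zeros_like(x) ==> [[0, 0, 0], [0, 0, 0]]
+```
+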
+x: a tensor of type T.
+y: a tensor of the same shape and type as x but filled with zeros.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Diag")
+ .Input("diagonal: T")
+ .Output("output: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns a diagonal tensor with given diagonal values.
+
+Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+everything else padded with zeros. The diagonal is computed as follows:
+
+Assuming `diagonal` has dimensions `[D1,..., Dk]`, the output is a tensor of
+rank `2k` with dimensions `[D1,..., Dk, D1,..., Dk]` where:
+
+`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.
+
+For example:
+
+```prettyprint
+# 'diagonal' is [1, 2, 3, 4]
+tf.diag(diagonal) ==> [[1, 0, 0, 0]
+ [0, 2, 0, 0]
+ [0, 0, 3, 0]
+ [0, 0, 0, 4]]
+```
+
+diagonal: Rank k tensor where k is at most 3.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Reverse")
+ .Input("tensor: T")
+ .Input("dims: bool")
+ .Output("output: T")
+ .Attr("T: {uint8, int8, int32, bool, float, double}")
+ .Doc(R"Doc(
+Reverses specific dimensions of a tensor.
+
+Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
+of `tensor`, this operation reverses each dimension i of `tensor` where
+`dims[i]` is `True`.
+
+`tensor` can have up to 8 dimensions. The number of dimensions
+of `tensor` must equal the number of elements in `dims`. In other words:
+
+`rank(tensor) = size(dims)`
+
+For example:
+
+```prettyprint
+# tensor 't' is [[[[ 0, 1, 2, 3],
+# [ 4, 5, 6, 7],
+# [ 8, 9, 10, 11]],
+# [[12, 13, 14, 15],
+# [16, 17, 18, 19],
+# [20, 21, 22, 23]]]]
+# tensor 't' shape is [1, 2, 3, 4]
+
+# 'dims' is [False, False, False, True]
+reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
+ [ 7, 6, 5, 4],
+ [ 11, 10, 9, 8]],
+ [[15, 14, 13, 12],
+ [19, 18, 17, 16],
+ [23, 22, 21, 20]]]]
+
+# 'dims' is [False, True, False, False]
+reverse(t, dims) ==> [[[[12, 13, 14, 15],
+ [16, 17, 18, 19],
+ [20, 21, 22, 23]],
+ [[ 0, 1, 2, 3],
+ [ 4, 5, 6, 7],
+ [ 8, 9, 10, 11]]]]
+
+# 'dims' is [False, False, True, False]
+reverse(t, dims) ==> [[[[8, 9, 10, 11],
+ [4, 5, 6, 7],
+ [0, 1, 2, 3]],
+ [[20, 21, 22, 23],
+ [16, 17, 18, 19],
+ [12, 13, 14, 15]]]]
+```
+
+tensor: Up to 8-D.
+dims: 1-D. The dimensions to reverse.
+output: The same shape as `tensor`.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EditDistance")
+ .Input("hypothesis_indices: int64")
+ .Input("hypothesis_values: T")
+ .Input("hypothesis_shape: int64")
+ .Input("truth_indices: int64")
+ .Input("truth_values: T")
+ .Input("truth_shape: int64")
+ .Attr("normalize: bool = True")
+ .Attr("T: type")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes the (possibly normalized) Levenshtein Edit Distance.
+
+The inputs are variable-length sequences provided by SparseTensors
+ (hypothesis_indices, hypothesis_values, hypothesis_shape)
+and
+ (truth_indices, truth_values, truth_shape).
+
+The inputs are:
+
+hypothesis_indices: The indices of the hypothesis list SparseTensor.
+ This is an N x R int64 matrix.
+hypothesis_values: The values of the hypothesis list SparseTensor.
+ This is an N-length vector.
+hypothesis_shape: The shape of the hypothesis list SparseTensor.
+ This is an R-length vector.
+truth_indices: The indices of the truth list SparseTensor.
+ This is an M x R int64 matrix.
+truth_values: The values of the truth list SparseTensor.
+ This is an M-length vector.
+truth_shape: The shape of the truth list SparseTensor.
+ This is an R-length vector.
+normalize: boolean (if true, edit distances are normalized by length of truth).
+
+The output is:
+
+output: A dense float tensor with rank R - 1.
+
+For the example input:
+
+ // hypothesis represents a 2x1 matrix with variable-length values:
+ // (0,0) = ["a"]
+ // (1,0) = ["b"]
+ hypothesis_indices = [[0, 0, 0],
+ [1, 0, 0]]
+ hypothesis_values = ["a", "b"]
+ hypothesis_shape = [2, 1, 1]
+
+ // truth represents a 2x2 matrix with variable-length values:
+ // (0,0) = []
+ // (0,1) = ["a"]
+ // (1,0) = ["b", "c"]
+ // (1,1) = ["a"]
+ truth_indices = [[0, 1, 0],
+ [1, 0, 0],
+ [1, 0, 1],
+ [1, 1, 0]]
+ truth_values = ["a", "b", "c", "a"]
+ truth_shape = [2, 2, 2]
+ normalize = true
+
+The output will be:
+
+ // output is a 2x2 matrix with edit distances normalized by truth lengths.
+ output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
+ [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Fill")
+ .Input("dims: int32")
+ .Input("value: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Creates a tensor filled with a scalar value.
+
+This operation creates a tensor of shape `dims` and fills it with `value`.
+
+For example:
+
+```prettyprint
+# output tensor shape needs to be [2, 3]
+# so 'dims' is [2, 3]
+fill(dims, 9) ==> [[9, 9, 9]
+ [9, 9, 9]]
+```
+
+dims: 1-D. Represents the shape of the output tensor.
+value: 0-D (scalar). Value to fill the returned tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Gather")
+ .Input("params: Tparams")
+ .Input("indices: Tindices")
+ .Output("output: Tparams")
+ .Attr("Tparams: type")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Gather slices from `params` according to `indices`.
+
+`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
+Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
+
+ # Scalar indices
+ output[:, ..., :] = params[indices, :, ..., :]
+
+ # Vector indices
+ output[i, :, ..., :] = params[indices[i], :, ..., :]
+
+ # Higher rank indices
+ output[i, ..., j, :, ..., :] = params[indices[i, ..., j], :, ..., :]
+
+If `indices` is a permutation and `len(indices) == params.shape[0]` then
+this operation will permute `params` accordingly.
+
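+For example, a sketch with hypothetical values:
+
+```prettyprint
+# params = [[1, 2], [3, 4], [5, 6]]
+# indices = [2, 0]
+gather(params, indices) ==> [[5, 6], [1, 2]]
+```
+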
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/Gather.png" alt>
+</div>
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Identity")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Return a tensor with the same shape and contents as the input tensor or value.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefIdentity")
+ .Input("input: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Return the same ref tensor as the input ref tensor.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("StopGradient")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Stops gradient computation.
+
+When executed in a graph, this op outputs its input tensor as-is.
+
+When building ops to compute gradients, this op prevents the contribution of
+its inputs from being taken into account. Normally, the gradient generator adds
+ops to a graph to compute the derivatives of a specified 'loss' by recursively
+finding out inputs that contributed to its computation. If you insert this op
+in the graph, its inputs are masked from the gradient generator. They are not
+taken into account for computing gradients.
+
+This is useful any time you want to compute a value with TensorFlow but need
+to pretend that the value was a constant. Some examples include:
+
+* The *EM* algorithm where the *M-step* should not involve backpropagation
+ through the output of the *E-step*.
+* Contrastive divergence training of Boltzmann machines where, when
+ differentiating the energy function, the training must not backpropagate
+ through the graph that generated the samples from the model.
+* Adversarial training, where no backprop should happen through the adversarial
+ example generation process.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("CheckNumerics")
+ .Input("tensor: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("message: string")
+ .Doc(R"doc(
+Checks a tensor for NaN and Inf values.
+
+When run, reports an `InvalidArgument` error if `tensor` has any values
+that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
+
+message: Prefix of the error message.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Reshape")
+ .Input("tensor: T")
+ .Input("shape: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Reshapes a tensor.
+
+Given `tensor`, this operation returns a tensor that has the same values
+as `tensor` with shape `shape`.
+
+If `shape` is the special value `[-1]`, then `tensor` is flattened and the
+operation outputs a 1-D tensor with all elements of `tensor`.
+
+If `shape` is 1-D or higher, then the operation returns a tensor with shape
+`shape` filled with the values of `tensor`. In this case, the number of elements
+implied by `shape` must be the same as the number of elements in `tensor`.
+
+For example:
+
+```prettyprint
+# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
+# tensor 't' has shape [9]
+reshape(t, [3, 3]) ==> [[1, 2, 3]
+ [4, 5, 6]
+ [7, 8, 9]]
+
+# tensor 't' is [[[1, 1], [2, 2]]
+# [[3, 3], [4, 4]]]
+# tensor 't' has shape [2, 2, 2]
+reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
+ [3, 3, 4, 4]]
+
+# tensor 't' is [[[1, 1, 1],
+# [2, 2, 2]],
+# [[3, 3, 3],
+# [4, 4, 4]],
+# [[5, 5, 5],
+# [6, 6, 6]]]
+# tensor 't' has shape [3, 2, 3]
+# pass '[-1]' to flatten 't'
+reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
+```
+
+shape: Defines the shape of the output tensor.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("InvertPermutation")
+ .Input("x: int32")
+ .Output("y: int32")
+ .Doc(R"doc(
+Computes the inverse permutation of a tensor.
+
+This operation computes the inverse of an index permutation. It takes a 1-D
+integer tensor `x`, which represents the indices of a zero-based array, and
+swaps each value with its index position. In other words, for an ouput tensor
+`y` and an input tensor `x`, this operation computes the following:
+
+`y[x[i]] = i for i in [0, 1, ..., len(x) - 1]`
+
+The values must include 0. There can be no duplicate values or negative values.
+
+For example:
+
+```prettyprint
+# tensor `x` is [3, 4, 0, 2, 1]
+invert_permutation(x) ==> [2, 4, 3, 0, 1]
+```
+
+x: 1-D.
+y: 1-D.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Transpose")
+ .Input("x: T")
+ .Input("perm: int32")
+ .Output("y: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Shuffle dimensions of x according to a permutation.
+
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+ `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
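+
+For example, a sketch with a hypothetical `x` of shape `[1, 2, 3]`:
+
+```prettyprint
+# perm = [2, 0, 1]
+# y = transpose(x, perm) has shape [3, 1, 2]
+```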
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Unique")
+ .Input("x: T")
+ .Output("y: T")
+ .Output("idx: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Finds unique elements in a 1-D tensor.
+
+This operation returns a tensor `y` containing all of the unique elements of `x`
+in the same order that they first occur in `x`. This operation also returns a
+tensor `idx` the same size as `x` that contains the index of each value of `x`
+in the unique output `y`. In other words:
+
+`y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]`
+
+For example:
+
+```prettyprint
+# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
+y, idx = unique(x)
+y ==> [1, 2, 4, 7, 8]
+idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
+```
+
+x: 1-D.
+y: 1-D.
+idx: 1-D.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Shape")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the shape of a tensor.
+
+This operation returns a 1-D integer tensor representing the shape of `input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+shape(t) ==> [2, 2, 3]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ReverseSequence")
+ .Input("input: T")
+ .Input("seq_lengths: int64")
+ .Output("output: T")
+ .Attr("seq_dim: int")
+ .Attr("T: type")
+ .Doc(R"doc(
+Reverses variable length slices in dimension `seq_dim`.
+
+This op first slices `input` along the first dimension, and for each slice `i`,
+reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`.
+
+The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims(seq_dim)`,
+and `seq_lengths` must be a vector of length `input.dims(0)`.
+
+The output slice `i` along dimension 0 is then given by input slice `i`, with
+the first `seq_lengths[i]` slices along dimension `seq_dim` reversed.
+
+For example:
+
+```prettyprint
+# Given this:
+seq_dim = 1
+input.dims = (4, ...)
+seq_lengths = [7, 2, 3, 5]
+
+# then slices of input are reversed on seq_dim, but only up to seq_lengths:
+output[0, 0:7, :, ...] = input[0, 6::-1, :, ...]
+output[1, 0:2, :, ...] = input[1, 1::-1, :, ...]
+output[2, 0:3, :, ...] = input[2, 2::-1, :, ...]
+output[3, 0:5, :, ...] = input[3, 4::-1, :, ...]
+
+# while entries past seq_lengths are copied through:
+output[0, 7:, :, ...] = input[0, 7:, :, ...]
+output[1, 2:, :, ...] = input[1, 2:, :, ...]
+output[2, 3:, :, ...] = input[2, 3:, :, ...]
+output[3, 5:, :, ...] = input[3, 5:, :, ...]
+```
+
+input: The input to reverse.
+seq_lengths: 1-D with length `input.dims(0)` and
+ `max(seq_lengths) < input.dims(seq_dim)`
+seq_dim: The dimension which is partially reversed.
+output: The partially reversed input. It has the same shape as `input`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Rank")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the rank of a tensor.
+
+This operation returns an integer representing the rank of `input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+# shape of tensor 't' is [2, 2, 3]
+rank(t) ==> 3
+```
+
+**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
+of a tensor is the number of indices required to uniquely select each element
+of the tensor. Rank is also known as "order", "degree", or "ndims."
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Size")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the size of a tensor.
+
+This operation returns an integer representing the number of elements in
+`input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+size(t) ==> 12
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Slice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .Doc(R"doc(
+Return a slice from 'input'.
+
+The output tensor is a tensor with dimensions described by 'size'
+whose values are extracted from 'input' starting at the offsets in
+'begin'.
+
+*Requirements*:
+ 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
+
+begin: begin[i] specifies the offset into the 'i'th dimension of
+ 'input' to slice from.
+size: size[i] specifies the number of elements of the 'i'th dimension
+ of 'input' to slice. If size[i] is -1, all remaining elements in dimension
+ i are included in the slice (i.e. this is equivalent to setting
+ size[i] = input.dim_size(i) - begin[i]).
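+
+For example, a sketch with a hypothetical 3x3 `input`:
+
+```prettyprint
+# input = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
+slice(input, [1, 0], [2, 2]) ==> [[4, 5], [7, 8]]
+slice(input, [0, 1], [-1, 1]) ==> [[2], [5], [8]]  # size -1 takes the rest of dim 0
+```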
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Tile")
+ .Input("input: T")
+ .Input("multiples: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Constructs a tensor by tiling a given tensor.
+
+This operation creates a new tensor by replicating `input` `multiples` times.
+The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
+and the values of `input` are replicated `multiples[i]` times along the 'i'th
+dimension. For example, tiling `[a b c d]` by `[2]` produces
+`[a b c d a b c d]`.
+
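+A 2-D sketch with hypothetical values:
+
+```prettyprint
+# 'input' is [[1, 2], [3, 4]], 'multiples' is [1, 2]
+tile(input, multiples) ==> [[1, 2, 1, 2], [3, 4, 3, 4]]
+```
+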
+input: 1-D or higher.
+multiples: 1-D. Length must be the same as the number of dimensions in `input`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("TileGrad")
+ .Input("input: T")
+ .Input("multiples: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the gradient of `Tile`.
+
+Since `Tile` takes an input and repeats the input `multiples` times
+along each dimension, `TileGrad` takes in `multiples` and aggregates
+each repeated tile of `input` into `output`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Where")
+ .Input("input: bool")
+ .Output("index: int64")
+ .Doc(R"doc(
+Returns locations of true values in a boolean tensor.
+
+This operation returns the coordinates of true elements in `input`. The
+coordinates are returned in a 2-D tensor where the first dimension (rows)
+represents the number of true elements, and the second dimension (columns)
+represents the coordinates of the true elements. Keep in mind, the shape of
+the output tensor can vary depending on how many true values there are in
+`input`. Indices are output in row-major order.
+
+For example:
+
+```prettyprint
+# 'input' tensor is [[True, False]
+# [True, False]]
+# 'input' has two true values, so output has two coordinates.
+# 'input' has rank of 2, so coordinates have two indices.
+where(input) ==> [[0, 0],
+ [1, 0]]
+
+# `input` tensor is [[[True, False]
+# [True, False]]
+# [[False, True]
+# [False, True]]
+# [[False, False]
+# [False, True]]]
+# 'input' has 5 true values, so output has 5 coordinates.
+# 'input' has rank of 3, so coordinates have three indices.
+where(input) ==> [[0, 0, 0],
+ [0, 1, 0],
+ [1, 0, 1],
+ [1, 1, 1],
+ [2, 1, 1]]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("BroadcastGradientArgs")
+ .Input("s0: int32")
+ .Input("s1: int32")
+ .Output("r0: int32")
+ .Output("r1: int32")
+ .Doc(R"doc(
+Return the reduction indices for computing gradients of s0 op s1 with broadcast.
+
+This is typically used by gradient computations for a broadcasting operation.
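+
+A sketch with hypothetical shape vectors, assuming numpy-style broadcasting
+of `s0` and `s1`:
+
+```prettyprint
+# s0 = [2, 3, 5], s1 = [1, 3, 1]  # broadcast result has shape [2, 3, 5]
+r0 ==> []      # no dimension of s0 was broadcast
+r1 ==> [0, 2]  # s1 was broadcast along dimensions 0 and 2
+```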
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Pad")
+ .Input("input: T")
+ .Input("paddings: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Pads a tensor with zeros.
+
+This operation pads `input` with zeros according to the `paddings` you
+specify. `paddings` is an integer tensor with shape `[n, 2]`, where `n` is the
+rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+how many zeros to add before the contents of `input` in that dimension, and
+`paddings[D, 1]` indicates how many zeros to add after the contents of `input`
+in that dimension.
+
+The padded size of each dimension D of the output is:
+
+`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+
+For example:
+
+```prettyprint
+# 't' is [[1, 1], [2, 2]]
+# 'paddings' is [[1, 1], [2, 2]]
+# rank of 't' is 2
+pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+ [0, 0, 1, 1, 0, 0]
+ [0, 0, 2, 2, 0, 0]
+ [0, 0, 0, 0, 0, 0]]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Placeholder")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .Doc(R"doc(
+A placeholder op for a value that will be fed into the computation.
+
+N.B. This operation will fail with an error if it is executed. It is
+intended as a way to represent a value that will always be fed, and to
+provide attrs that enable the fed value to be checked at runtime.
+
+output: A placeholder tensor that must be replaced using the feed mechanism.
+dtype: The type of elements in the tensor.
+shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
+ shape is unconstrained.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ExpandDims")
+ .Input("input: T")
+ .Input("dim: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Inserts a dimension of 1 into a tensor's shape.
+
+Given a tensor `input`, this operation inserts a dimension of 1 at the
+dimension index `dim` of `input`'s shape. The dimension index `dim` starts at
+zero; if you specify a negative number for `dim` it is counted backward from
+the end.
+
+This operation is useful if you want to add a batch dimension to a single
+element. For example, if you have a single image of shape `[height, width,
+channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
+which will make the shape `[1, height, width, channels]`.
+
+Other examples:
+
+```prettyprint
+# 't' is a tensor of shape [2]
+shape(expand_dims(t, 0)) ==> [1, 2]
+shape(expand_dims(t, 1)) ==> [2, 1]
+shape(expand_dims(t, -1)) ==> [2, 1]
+
+# 't2' is a tensor of shape [2, 3, 5]
+shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
+shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
+shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
+```
+
+This operation requires that:
+
+`-1-input.dims() <= dim <= input.dims()`
+
+This operation is related to `squeeze()`, which removes dimensions of
+size 1.
+
+dim: 0-D (scalar). Specifies the dimension index at which to
+ expand the shape of `input`.
+output: Contains the same data as `input`, but its shape has an additional
+ dimension of size 1 added.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Squeeze")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("squeeze_dims: list(int) >= 0 = []")
+ .Doc(R"doc(
+Removes dimensions of size 1 from the shape of a tensor.
+
+Given a tensor `input`, this operation returns a tensor of the same type with
+all dimensions of size 1 removed. If you don't want to remove all size 1
+dimensions, you can remove specific size 1 dimensions by specifying
+`squeeze_dims`.
+
+For example:
+
+```prettyprint
+# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
+shape(squeeze(t)) ==> [2, 3]
+```
+
+Or, to remove specific size 1 dimensions:
+
+```prettyprint
+# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
+shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
+```
+
+input: The `input` to squeeze.
+squeeze_dims: If specified, only squeezes the dimensions listed. The dimension
+ index starts at 0. It is an error to squeeze a dimension that is not 1.
+output: Contains the same data as `input`, but has one or more dimensions of
+ size 1 removed.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ListDiff")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("out: T")
+ .Output("idx: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Computes the difference between two lists of numbers.
+
+Given a list `x` and a list `y`, this operation returns a list `out` that
+represents all numbers that are in `x` but not in `y`. The returned list `out`
+is kept in the same order in which the numbers appear in `x` (duplicates are
+preserved). This operation also returns a list `idx` that represents the
+position of each `out` element in `x`. In other words:
+
+`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]`
+
+For example, given this input:
+
+```prettyprint
+x = [1, 2, 3, 4, 5, 6]
+y = [1, 3, 5]
+```
+
+This operation would return:
+
+```prettyprint
+out ==> [2, 4, 6]
+idx ==> [1, 3, 5]
+```
+
+x: 1-D. Values to keep.
+y: 1-D. Values to remove.
+out: 1-D. Values present in `x` but not in `y`.
+idx: 1-D. Positions of `x` values preserved in `out`.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/attention_ops.cc b/tensorflow/core/ops/attention_ops.cc
new file mode 100644
index 0000000000..6fa9a6e821
--- /dev/null
+++ b/tensorflow/core/ops/attention_ops.cc
@@ -0,0 +1,54 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// Tout = extract_glimpse(Tin, size, offsets) extracts a glimpse of size `size`
+// centered at location `offsets` from the input tensor `Tin`.
+//
+// REQUIRES: Tin.dims() == 4
+//
+REGISTER_OP("ExtractGlimpse")
+ .Input("input: float")
+ .Input("size: int32")
+ .Input("offsets: float")
+ .Output("glimpse: float")
+ .Attr("centered: bool = true")
+ .Attr("normalized: bool = true")
+ .Attr("uniform_noise: bool = true")
+ .Doc(R"doc(
+Extracts a glimpse from the input tensor.
+
+Returns a set of windows called glimpses extracted at location `offsets`
+from the input tensor. If a window only partially overlaps the input, the
+non-overlapping areas are filled with random noise.
+
+The result is a 4-D tensor of shape `[batch_size, glimpse_height,
+glimpse_width, channels]`. The channels and batch dimensions are the same as
+those of the input tensor. The height and width of the output windows are
+specified in the `size` parameter.
+
+The arguments `centered` and `normalized` control how the windows are built:
+* If the coordinates are normalized but not centered, 0.0 and 1.0
+ correspond to the minimum and maximum of each height and width dimension.
+* If the coordinates are both normalized and centered, they range from -1.0 to
+ 1.0. The coordinates (-1.0, -1.0) correspond to the upper left corner, the
+ lower right corner is located at (1.0, 1.0) and the center is at (0, 0).
+* If the coordinates are not normalized they are interpreted as numbers of pixels.
+
+input: A 4-D float tensor of shape `[batch_size, height, width, channels]`.
+size: A 1-D tensor of 2 elements containing the size of the glimpses to extract.
+ The glimpse height must be specified first, followed by the glimpse width.
+offsets: A 2-D float tensor of shape `[batch_size, 2]` containing the x, y
+ locations of the center of each window.
+glimpse: A tensor representing the glimpses `[batch_size, glimpse_height,
+ glimpse_width, channels]`.
+centered: indicates if the offset coordinates are centered relative to
+ the image, in which case the (0, 0) offset is relative to the center of the
+ input images. If false, the (0,0) offset corresponds to the upper left corner
+ of the input images.
+normalized: indicates if the offset coordinates are normalized.
+uniform_noise: indicates if the noise should be generated using a
+ uniform distribution or a Gaussian distribution.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc
new file mode 100644
index 0000000000..a98b0295ee
--- /dev/null
+++ b/tensorflow/core/ops/candidate_sampling_ops.cc
@@ -0,0 +1,351 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("UniformCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a uniform distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
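+An illustrative sketch of the shapes involved (the sampled values below are
+hypothetical, since sampling is random):
+
+```prettyprint
+# num_true = 2, num_sampled = 3, range_max = 10, batch_size = 2
+# true_classes        = [[1, 4], [7, 2]]  # shape [batch_size, num_true]
+# sampled_candidates ==> [8, 0, 3]        # shape [num_sampled], drawn from [0, 10)
+# true_expected_count has shape [batch_size, num_true]
+# sampled_expected_count has shape [num_sampled]
+```
+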
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("LogUniformCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a log-uniform distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("LearnedUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("FixedUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("vocab_file: string = ''")
+ .Attr("distortion: float = 1.0")
+ .Attr("num_reserved_ids: int = 0")
+ .Attr("num_shards: int >= 1 = 1")
+ .Attr("shard: int >= 0 = 0")
+ .Attr("unigrams: list(float) = []")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+A unigram sampler could use a fixed unigram distribution read from a
+file or passed in as an in-memory array instead of building up the distribution
+from data on the fly. There is also an option to skew the distribution by
+applying a distortion power to the weights.
+
+The vocabulary file should be in CSV-like format, with the last field
+being the weight associated with the word.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+vocab_file: Each valid line in this file (which should have a CSV-like format)
+ corresponds to a valid word ID. IDs are in sequential order, starting from
+ num_reserved_ids. The last entry in each line is expected to be a value
+ corresponding to the count or relative probability. Exactly one of vocab_file
+ and unigrams needs to be passed to this op.
+distortion: The distortion is used to skew the unigram probability distribution.
+ Each weight is first raised to the distortion's power before adding to the
+ internal unigram distribution. As a result, distortion = 1.0 gives regular
+ unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
+ a uniform distribution.
+num_reserved_ids: Optionally some reserved IDs can be added in the range [0,
+ ..., num_reserved_ids) by the users. One use case is that a special unknown
+ word token is used as ID 0. These IDs will have a sampling probability of 0.
+num_shards: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This parameter
+ (together with 'shard') indicates the number of partitions that are being
+ used in the overall computation.
+shard: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This parameter
+ (together with 'num_shards') indicates the particular partition number of a
+ sampler op, when partitioning is being used.
+unigrams: A list of unigram counts or probabilities, one per ID in sequential
+ order. Exactly one of vocab_file and unigrams should be passed to this op.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("AllCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling where all classes are candidates.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, representing for each
+ sampled candidate the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to produce per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("ComputeAccidentalHits")
+ .Input("true_classes: int64")
+ .Input("sampled_candidates: int64")
+ .Output("indices: int32")
+ .Output("ids: int64")
+ .Output("weights: float")
+ .Attr("num_true: int")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Computes the ids of the positions in sampled_candidates that match true_labels.
+
+When doing log-odds NCE, the result of this op should be passed through a
+SparseToDense op, then added to the logits of the sampled candidates. This has
+the effect of 'removing' the sampled labels that match the true labels by
+making the classifier sure that they are sampled labels.
+
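+A sketch with hypothetical values and num_true = 1:
+
+```prettyprint
+# true_classes       = [[1], [7]]
+# sampled_candidates = [7, 1, 3]
+indices ==> [0, 1]                    # rows of true_classes with a hit
+ids     ==> [1, 0]                    # matching positions in sampled_candidates
+weights ==> [-FLOAT_MAX, -FLOAT_MAX]
+```
+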
+true_classes: The true_classes output of UnpackSparseLabels.
+sampled_candidates: The sampled_candidates output of CandidateSampler.
+indices: A vector of indices corresponding to rows of true_classes.
+ids: A vector of IDs of positions in sampled_candidates that match a true_label
+ for the row with the corresponding index in indices.
+weights: A vector of the same length as indices and ids, in which each element
+ is -FLOAT_MAX.
+num_true: Number of true labels per context.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
new file mode 100644
index 0000000000..517b2d2742
--- /dev/null
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -0,0 +1,179 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Switch")
+ .Input("data: T")
+ .Input("pred: bool")
+ .Output("output_false: T")
+ .Output("output_true: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Forwards `data` to the output port determined by `pred`.
+
+If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+the data goes to `output_false`.
+
+See also `RefSwitch` and `Merge`.
+
+data: The tensor to be forwarded to the appropriate output.
+pred: A scalar that specifies which output port will receive data.
+output_false: If `pred` is false, data will be forwarded to this output.
+output_true: If `pred` is true, data will be forwarded to this output.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefSwitch")
+ .Input("data: Ref(T)")
+ .Input("pred: bool")
+ .Output("output_false: Ref(T)")
+ .Output("output_true: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"doc(
+Forwards the ref tensor `data` to the output port determined by `pred`.
+
+If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+the data goes to `output_false`.
+
+See also `Switch` and `Merge`.
+
+data: The ref tensor to be forwarded to the appropriate output.
+pred: A scalar that specifies which output port will receive data.
+output_false: If `pred` is false, data will be forwarded to this output.
+output_true: If `pred` is true, data will be forwarded to this output.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefSelect")
+ .Input("index: int32")
+ .Input("inputs: Ref(N * T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Forwards the `index`th element of `inputs` to `output`.
+
+index: A scalar that determines the input that gets selected.
+inputs: A list of ref tensors, one of which will be forwarded to `output`.
+output: The forwarded tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Merge")
+ .Input("inputs: N * T")
+ .Output("output: T")
+ .Output("value_index: int32")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Forwards the value of an available tensor from `inputs` to `output`.
+
+`Merge` waits for at least one of the tensors in `inputs` to become available.
+It is usually combined with `Switch` to implement branching.
+
+`Merge` forwards the first tensor to become available to `output`, and sets
+`value_index` to its index in `inputs`.
+
+It is an error if more than one tensor in `inputs` is available.
+
+inputs: The input tensors, exactly one of which will become available.
+output: Will be set to the available input tensor.
+value_index: The index of the chosen input tensor in `inputs`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Enter")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("frame_name: string")
+ .Attr("is_constant: bool = false")
+ .Attr("parallel_iterations: int = 10")
+ .Doc(R"doc(
+Creates or finds a child frame, and makes `data` available to the child frame.
+
+This op is used together with `Exit` to create loops in the graph.
+The unique `frame_name` is used by the `Executor` to identify frames. If
+`is_constant` is true, `output` is a constant in the child frame; otherwise
+it may be changed in the child frame. At most `parallel_iterations` iterations
+are run in parallel in the child frame.
+
+data: The tensor to be made available to the child frame.
+frame_name: The name of the child frame.
+is_constant: If true, the output is constant within the child frame.
+parallel_iterations: The number of iterations allowed to run in parallel.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefEnter")
+ .Input("data: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Attr("frame_name: string")
+ .Attr("is_constant: bool = false")
+ .Attr("parallel_iterations: int = 10")
+ .Doc(R"doc(
+Creates or finds a child frame, and makes `data` available to the child frame.
+
+The unique `frame_name` is used by the `Executor` to identify frames. If
+`is_constant` is true, `output` is a constant in the child frame; otherwise
+it may be changed in the child frame. At most `parallel_iterations` iterations
+are run in parallel in the child frame.
+
+data: The tensor to be made available to the child frame.
+frame_name: The name of the child frame.
+is_constant: If true, the output is constant within the child frame.
+parallel_iterations: The number of iterations allowed to run in parallel.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Exit")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Exits the current frame to its parent frame.
+
+Exit makes its input `data` available to the parent frame.
+
+data: The tensor to be made available to the parent frame.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("NextIteration")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Makes its input available to the next iteration.
+
+data: The tensor to be made available to the next iteration.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("LoopCond")
+ .Input("input: bool")
+ .Output("output: bool")
+ .Doc(R"doc(
+Forwards the input to the output.
+
+This operator represents the loop termination condition used by the
+"pivot" switches of a loop.
+
+input: A boolean scalar, representing the branch predicate of the Switch op.
+output: The same tensor as `input`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ControlTrigger")
+ .Doc(R"doc(
+Does nothing. Serves as a control trigger for scheduling. Only useful as a
+placeholder for control edges.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
new file mode 100644
index 0000000000..49eba33188
--- /dev/null
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -0,0 +1,357 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("DynamicPartition")
+ .Input("data: T")
+ .Input("partitions: int32")
+ .Output("outputs: num_partitions * T")
+ .Attr("num_partitions: int")
+ .Attr("T: type")
+ .Doc(R"doc(
+Partitions `data` into `num_partitions` tensors using indices from `partitions`.
+
+For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`
+becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`
+are placed in `outputs[i]` in lexicographic order of `js`, and the first
+dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.
+In detail,
+
+ outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]
+
+ outputs[i] = pack([data[js, ...] for js if partitions[js] == i])
+
+`data.shape` must start with `partitions.shape`.
+
+For example:
+
+ # Scalar partitions
+ partitions = 1
+ num_partitions = 2
+ data = [10, 20]
+ outputs[0] = [] # Empty with shape [0, 2]
+ outputs[1] = [[10, 20]]
+
+ # Vector partitions
+ partitions = [0, 0, 1, 1, 0]
+ num_partitions = 2
+ data = [10, 20, 30, 40, 50]
+ outputs[0] = [10, 20, 50]
+ outputs[1] = [30, 40]
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/DynamicPartition.png" alt>
+</div>
+
+partitions: Any shape. Indices in the range `[0, num_partitions)`.
+num_partitions: The number of partitions to output.
+)doc");
+
+REGISTER_OP("DynamicStitch")
+ .Input("indices: N * int32")
+ .Input("data: N * T")
+ .Output("merged: T")
+ .Attr("N : int >= 2")
+ .Attr("T : type")
+ .Doc(R"doc(
+Interleave the values from the `data` tensors into a single tensor.
+
+Builds a merged tensor such that
+
+ merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
+
+For example, if each `indices[m]` is scalar or vector, we have
+
+ # Scalar indices
+ merged[indices[m], ...] = data[m][...]
+
+ # Vector indices
+ merged[indices[m][i], ...] = data[m][i, ...]
+
+Each `data[i].shape` must start with the corresponding `indices[i].shape`,
+and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we
+must have `data[i].shape = indices[i].shape + constant`. In terms of this
+`constant`, the output shape is
+
+ merged.shape = [max(indices) + 1] + constant
+
+Values are merged in order, so if an index appears in both `indices[m][i]` and
+`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the
+merged result.
+
+For example:
+
+ indices[0] = 6
+ indices[1] = [4, 1]
+ indices[2] = [[5, 2], [0, 3]]
+ data[0] = [61, 62]
+ data[1] = [[41, 42], [11, 12]]
+ data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
+ merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
+ [51, 52], [61, 62]]
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/DynamicStitch.png" alt>
+</div>
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("RandomShuffleQueue")
+ .Output("handle: Ref(string)")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("shapes: list(shape) >= 0 = []")
+ .Attr("capacity: int = -1")
+ .Attr("min_after_dequeue: int = 0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A queue that randomizes the order of elements.
+
+handle: The handle to the queue.
+component_types: The type of each component in a value.
+shapes: The shape of each component in a value. The length of this attr must
+ be either 0 or the same as the length of component_types. If the length of
+ this attr is 0, the shapes of queue elements are not constrained, and
+ only one element may be dequeued at a time.
+capacity: The upper bound on the number of elements in this queue.
+ Negative numbers mean no limit.
+min_after_dequeue: Dequeue will block unless there would be this
+ many elements after the dequeue or the queue is closed. This
+ ensures a minimum level of mixing of elements.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, a random seed is used.
+seed2: A second seed to avoid seed collision.
+container: If non-empty, this queue is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this queue will be shared under the given name
+ across multiple sessions.
+)doc");
+
+REGISTER_OP("FIFOQueue")
+ .Output("handle: Ref(string)")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("shapes: list(shape) >= 0 = []")
+ .Attr("capacity: int = -1")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A queue that produces elements in first-in first-out order.
+
+handle: The handle to the queue.
+component_types: The type of each component in a value.
+shapes: The shape of each component in a value. The length of this attr must
+ be either 0 or the same as the length of component_types. If the length of
+ this attr is 0, the shapes of queue elements are not constrained, and
+ only one element may be dequeued at a time.
+capacity: The upper bound on the number of elements in this queue.
+ Negative numbers mean no limit.
+container: If non-empty, this queue is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this queue will be shared under the given name
+ across multiple sessions.
+)doc");
+
+REGISTER_OP("QueueEnqueue")
+ .Input("handle: Ref(string)")
+ .Input("components: Tcomponents")
+ .Attr("Tcomponents: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Enqueues a tuple of one or more tensors in the given queue.
+
+The components input has k elements, which correspond to the components of
+tuples stored in the given queue.
+
+N.B. If the queue is full, this operation will block until the given
+element has been enqueued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors from which the enqueued tensors should be taken.
+timeout_ms: If the queue is full, this operation will block for up to
+ timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueEnqueueMany")
+ .Input("handle: Ref(string)")
+ .Input("components: Tcomponents")
+ .Attr("Tcomponents: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Enqueues zero or more tuples of one or more tensors in the given queue.
+
+This operation slices each component tensor along the 0th dimension to
+make multiple queue elements. All of the tuple components must have the
+same size in the 0th dimension.
+
+The components input has k elements, which correspond to the components of
+tuples stored in the given queue.
+
+N.B. If the queue is full, this operation will block until the given
+elements have been enqueued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors from which the enqueued tensors should
+ be taken.
+timeout_ms: If the queue is too full, this operation will block for up
+ to timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueDequeue")
+ .Input("handle: Ref(string)")
+ .Output("components: component_types")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Dequeues a tuple of one or more tensors from the given queue.
+
+This operation has k outputs, where k is the number of components
+in the tuples stored in the given queue, and output i is the ith
+component of the dequeued tuple.
+
+N.B. If the queue is empty, this operation will block until an element
+has been dequeued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors that were dequeued as a tuple.
+component_types: The type of each component in a tuple.
+timeout_ms: If the queue is empty, this operation will block for up to
+ timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueDequeueMany")
+ .Input("handle: Ref(string)")
+ .Input("n: int32")
+ .Output("components: component_types")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Dequeues n tuples of one or more tensors from the given queue.
+
+This operation concatenates queue-element component tensors along the
+0th dimension to make a single component tensor. All of the components
+in the dequeued tuple will have size n in the 0th dimension.
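+
+For example (an illustrative sketch of the resulting shapes):
+
+```prettyprint
+# the queue holds tuples whose component shapes are [] and [3]
+# after a DequeueMany with n = 4:
+components[0].shape ==> [4]
+components[1].shape ==> [4, 3]
+```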
+
+This operation has k outputs, where k is the number of components in
+the tuples stored in the given queue, and output i is the ith
+component of the dequeued tuple.
+
+N.B. If the queue is empty, this operation will block until n elements
+have been dequeued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+n: The number of tuples to dequeue.
+components: One or more tensors that were dequeued as a tuple.
+component_types: The type of each component in a tuple.
+timeout_ms: If the queue has fewer than n elements, this operation
+ will block for up to timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueClose")
+ .Input("handle: Ref(string)")
+ .Attr("cancel_pending_enqueues: bool = false")
+ .Doc(R"doc(
+Closes the given queue.
+
+This operation signals that no more elements will be enqueued in the
+given queue. Subsequent Enqueue(Many) operations will fail.
+Subsequent Dequeue(Many) operations will continue to succeed if
+sufficient elements remain in the queue. Subsequent Dequeue(Many)
+operations that would block will fail immediately.
+
+handle: The handle to a queue.
+cancel_pending_enqueues: If true, all pending enqueue requests that are
+ blocked on the given queue will be cancelled.
+)doc");
+
+REGISTER_OP("QueueSize")
+ .Input("handle: Ref(string)")
+ .Output("size: int32")
+ .Doc(R"doc(
+Computes the number of elements in the given queue.
+
+handle: The handle to a queue.
+size: The number of elements in the given queue.
+)doc");
+
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LookupTableFind")
+ .Input("table_handle: Ref(string)")
+ .Input("input_values: Tin")
+ .Input("default_value: Tout")
+ .Output("output_values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .Doc(R"doc(
+Maps elements of a tensor into associated values given a lookup table.
+
+If an element of `input_values` is not present in the table, the
+specified `default_value` is used.
+
+The table must be initialized, and the input and output types must
+correspond to the table's key and value types.
+
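+For example (illustrative values, assuming a table mapping strings to int32):
+
+```prettyprint
+# the table contains {'a': 1, 'b': 2}
+# 'input_values' is ['a', 'c'] and 'default_value' is -1
+output_values ==> [1, -1]
+```
+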
+table_handle: A handle for a lookup table.
+input_values: A vector of key values.
+default_value: A scalar to return if the input is not found in the table.
+output_values: A vector of values associated to the inputs.
+)doc");
+
+REGISTER_OP("LookupTableSize")
+ .Input("table_handle: Ref(string)")
+ .Output("size: int64")
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: The handle to a lookup table.
+size: The number of elements in the given table.
+)doc");
+
+REGISTER_OP("HashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Doc(R"doc(
+Creates and holds an immutable hash table.
+
+The key and value types can be specified. After initialization, the table
+becomes immutable.
+
+table_handle: a handle to the lookup table.
+container: If non-empty, this hash table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this hash table is shared under the given name across
+ multiple sessions.
+key_dtype: the type of the table key.
+value_dtype: the type of the table value.
+)doc");
+
+REGISTER_OP("InitializeTable")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: a handle of the lookup table to be initialized.
+keys: a vector of keys of type Tkey.
+values: a vector of values of type Tval.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
new file mode 100644
index 0000000000..88af081893
--- /dev/null
+++ b/tensorflow/core/ops/image_ops.cc
@@ -0,0 +1,273 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeArea")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using area interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeBicubic")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using bicubic interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeBilinear")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using bilinear interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeNearestNeighbor")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: T")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using nearest neighbor interpolation.
+
+The output images have the same type as the input images.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RandomCrop")
+ .Input("image: T")
+ .Input("size: int64")
+ .Output("output: T")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .SetIsStateful()
+ .Doc(R"doc(
+Randomly crop `image`.
+
+`size` is a 1-D int64 tensor with 2 elements representing the crop height and
+width. The values must be non negative.
+
+This Op picks a random location in `image` and crops a `height` by `width`
+rectangle from that location. The random location is picked so the cropped
+area will fit inside the original image.
+
+image: 3-D of shape `[height, width, channels]`.
+size: 1-D of length 2 containing: `crop_height`, `crop_width`.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+output: 3-D of shape `[crop_height, crop_width, channels]`.
+)doc");
+// TODO(shlens): Support variable rank in RandomCrop.
+
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodeJpeg")
+ .Input("contents: string")
+ .Attr("channels: int = 0")
+ .Attr("ratio: int = 1")
+ .Attr("fancy_upscaling: bool = true")
+ .Attr("try_recover_truncated: bool = false")
+ .Attr("acceptable_fraction: float = 1.0")
+ .Output("image: uint8")
+ .Doc(R"doc(
+Decode a JPEG-encoded image to a uint8 tensor.
+
+The attr `channels` indicates the desired number of color channels for the
+decoded image.
+
+Accepted values are:
+
+* 0: Use the number of channels in the JPEG-encoded image.
+* 1: output a grayscale image.
+* 3: output an RGB image.
+
+If needed, the JPEG-encoded image is transformed to match the requested number
+of color channels.
+
+The attr `ratio` allows downscaling the image by an integer factor during
+decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
+downscaling the image later.
+
+contents: 0-D. The JPEG-encoded image.
+channels: Number of color channels for the decoded image.
+ratio: Downscaling ratio.
+fancy_upscaling: If true use a slower but nicer upscaling of the
+ chroma planes (yuv420/422 only).
+try_recover_truncated: If true try to recover an image from truncated input.
+acceptable_fraction: The minimum required fraction of lines before a truncated
+ input is accepted.
+image: 3-D with shape `[height, width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EncodeJpeg")
+ .Input("image: uint8")
+ .Attr("format: {'', 'grayscale', 'rgb'} = ''")
+ .Attr("quality: int = 95")
+ .Attr("progressive: bool = false")
+ .Attr("optimize_size: bool = false")
+ .Attr("chroma_downsampling: bool = true")
+ .Attr("density_unit: {'in', 'cm'} = 'in'")
+ .Attr("x_density: int = 300")
+ .Attr("y_density: int = 300")
+ .Attr("xmp_metadata: string = ''")
+ .Output("contents: string")
+ .Doc(R"doc(
+JPEG-encode an image.
+
+`image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
+
+The attr `format` can be used to override the color format of the encoded
+output. Values can be:
+
+* `''`: Use a default format based on the number of channels in the image.
+* `grayscale`: Output a grayscale JPEG image. The `channels` dimension
+ of `image` must be 1.
+* `rgb`: Output an RGB JPEG image. The `channels` dimension
+ of `image` must be 3.
+
+If `format` is not specified or is the empty string, a default format is picked
+based on the number of channels in `image`:
+
+* 1: Output a grayscale image.
+* 3: Output an RGB image.
+
+image: 3-D with shape `[height, width, channels]`.
+format: Per pixel image format.
+quality: Quality of the compression from 0 to 100 (higher is better and slower).
+progressive: If True, create a JPEG that loads progressively (coarse to fine).
+optimize_size: If True, spend CPU/RAM to reduce size with no quality change.
+chroma_downsampling: See http://en.wikipedia.org/wiki/Chroma_subsampling.
+density_unit: Unit used to specify `x_density` and `y_density`:
+ pixels per inch (`'in'`) or centimeter (`'cm'`).
+x_density: Horizontal pixels per density unit.
+y_density: Vertical pixels per density unit.
+xmp_metadata: If not empty, embed this XMP metadata in the image header.
+contents: 0-D. JPEG-encoded image.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("AdjustContrast")
+ .Input("images: T")
+ .Input("contrast_factor: float")
+ .Input("min_value: float")
+ .Input("max_value: float")
+ .Output("output: float")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+ .Doc(R"Doc(
+Adjust the contrast of one or more images.
+
+`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+interpreted as `[height, width, channels]`. The other dimensions only
+represent a collection of images, such as `[batch, height, width, channels]`.
+
+Contrast is adjusted independently for each channel of each image.
+
+For each channel, the Op first computes the mean of the image pixels in the
+channel and then adjusts each component of each pixel to
+`(x - mean) * contrast_factor + mean`.
+
+These adjusted values are then clipped to fit in the `[min_value, max_value]`
+interval.
+
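+For example (a worked single-channel illustration of the formula above):
+
+```prettyprint
+# one channel holds the pixels [1., 3.], so mean = 2.
+# with contrast_factor = 2., min_value = 0., max_value = 255.:
+(1. - 2.) * 2. + 2. ==> 0.
+(3. - 2.) * 2. + 2. ==> 4.
+# the adjusted channel is [0., 4.]
+```
+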
+images: Images to adjust. At least 3-D.
+contrast_factor: A float multiplier for adjusting contrast.
+min_value: Minimum value for clipping the adjusted pixels.
+max_value: Maximum value for clipping the adjusted pixels.
+output: The contrast-adjusted image or images.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodePng")
+ .Input("contents: string")
+ .Attr("channels: int = 0")
+ .Output("image: uint8")
+ .Doc(R"doc(
+Decode a PNG-encoded image to a uint8 tensor.
+
+The attr `channels` indicates the desired number of color channels for the
+decoded image.
+
+Accepted values are:
+
+* 0: Use the number of channels in the PNG-encoded image.
+* 1: output a grayscale image.
+* 3: output an RGB image.
+* 4: output an RGBA image.
+
+If needed, the PNG-encoded image is transformed to match the requested number
+of color channels.
+
+contents: 0-D. The PNG-encoded image.
+channels: Number of color channels for the decoded image.
+image: 3-D with shape `[height, width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EncodePng")
+ .Input("image: uint8")
+ .Attr("compression: int = -1")
+ .Output("contents: string")
+ .Doc(R"doc(
+PNG-encode an image.
+
+`image` is a 3-D uint8 Tensor of shape `[height, width, channels]` where
+`channels` is:
+
+* 1: for grayscale.
+* 3: for RGB.
+* 4: for RGBA.
+
+The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
+default or a value from 0 to 9. 9 is the highest compression level, generating
+the smallest output, but is slower.
+
+image: 3-D with shape `[height, width, channels]`.
+compression: Compression level.
+contents: 0-D. PNG-encoded image.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
new file mode 100644
index 0000000000..937fedd45d
--- /dev/null
+++ b/tensorflow/core/ops/io_ops.cc
@@ -0,0 +1,332 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Save")
+ .Input("filename: string")
+ .Input("tensor_names: string")
+ .Input("data: T")
+ .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Doc(R"doc(
+Saves the input tensors to disk.
+
+The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
+is written to `filename` with name `tensor_names[i]`.
+
+See also `SaveSlices`.
+
+filename: Must have a single element. The name of the file to which we write
+ the tensor.
+tensor_names: Shape `[N]`. The names of the tensors to be saved.
+data: `N` tensors to save.
+)doc");
+
+REGISTER_OP("SaveSlices")
+ .Input("filename: string")
+ .Input("tensor_names: string")
+ .Input("shapes_and_slices: string")
+ .Input("data: T")
+ .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Doc(R"doc(
+Saves input tensors slices to disk.
+
+This is like `Save` except that tensors can be listed in the saved file as being
+a slice of a larger tensor. `shapes_and_slices` specifies the shape of the
+larger tensor and the slice that this tensor covers. `shapes_and_slices` must
+have as many elements as `tensor_names`.
+
+Elements of the `shapes_and_slices` input must either be:
+
+* The empty string, in which case the corresponding tensor is
+ saved normally.
+* A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the
+ `dimI` are the dimensions of the larger tensor and `slice-spec`
+ specifies what part is covered by the tensor to save.
+
+`slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1`
+where each `sliceI` is either:
+
+* The string `-` meaning that the slice covers all indices of this dimension
+* `start,length` where `start` and `length` are integers. In that
+ case the slice covers `length` indices starting at `start`.
+
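+For example (an illustration of the grammar above), to save rows 0 through 4
+and all columns of a tensor whose larger shape is `[10 20]`, the corresponding
+element of `shapes_and_slices` would be:
+
+```prettyprint
+"10 20 0,5:-"
+```
+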
+See also `Save`.
+
+filename: Must have a single element. The name of the file to which we write the
+ tensor.
+tensor_names: Shape `[N]`. The names of the tensors to be saved.
+shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when
+ saving the tensors.
+data: `N` tensors to save.
+)doc");
+
+REGISTER_OP("Restore")
+ .Input("file_pattern: string")
+ .Input("tensor_name: string")
+ .Output("tensor: dt")
+ .Attr("dt: type")
+ .Attr("preferred_shard: int = -1")
+ .Doc(R"doc(
+Restores a tensor from checkpoint files.
+
+Reads a tensor stored in one or several files. If there are several files (for
+instance because a tensor was saved as slices), `file_pattern` may contain
+wildcard symbols (`*` and `?`) in the filename portion only, not in the
+directory portion.
+
+If a `file_pattern` matches several files, `preferred_shard` can be used to hint
+in which file the requested tensor is likely to be found. This op will first
+open the file at index `preferred_shard` in the list of matching files and try
+to restore tensors from that file. Only if some tensors or tensor slices are
+not found in that first file, then the Op opens all the files. Setting
+`preferred_shard` to match the value passed as the `shard` input
+of a matching `Save` Op may speed up Restore. This attribute only affects
+performance, not correctness. The default value -1 means files are processed in
+order.
+
+See also `RestoreSlice`.
+
+file_pattern: Must have a single element. The pattern of the files from
+ which we read the tensor.
+tensor_name: Must have a single element. The name of the tensor to be
+ restored.
+tensor: The restored tensor.
+dt: The type of the tensor to be restored.
+preferred_shard: Index of file to open first if multiple files match
+ `file_pattern`.
+)doc");
+
+REGISTER_OP("RestoreSlice")
+ .Input("file_pattern: string")
+ .Input("tensor_name: string")
+ .Input("shape_and_slice: string")
+ .Output("tensor: dt")
+ .Attr("dt: type")
+ .Attr("preferred_shard: int = -1")
+ .Doc(R"doc(
+Restores a tensor from checkpoint files.
+
+This is like `Restore` except that restored tensor can be listed as filling
+only a slice of a larger tensor. `shape_and_slice` specifies the shape of the
+larger tensor and the slice that the restored tensor covers.
+
+The `shape_and_slice` input has the same format as the
+elements of the `shapes_and_slices` input of the `SaveSlices` op.
+
+file_pattern: Must have a single element. The pattern of the files from
+ which we read the tensor.
+tensor_name: Must have a single element. The name of the tensor to be
+ restored.
+shape_and_slice: Scalar. The shapes and slice specifications to use when
+ restoring a tensor.
+tensor: The restored tensor.
+dt: The type of the tensor to be restored.
+preferred_shard: Index of file to open first if multiple files match
+ `file_pattern`. See the documentation for `Restore`.
+)doc");
+
+REGISTER_OP("ShardedFilename")
+ .Input("basename: string")
+ .Input("shard: int32")
+ .Input("num_shards: int32")
+ .Output("filename: string")
+ .Doc(R"doc(
+Generate a sharded filename. The filename is printf formatted as
+ %s-%05d-of-%05d, basename, shard, num_shards.
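+
+For example (illustrative values):
+
+```prettyprint
+# basename = 'ckpt', shard = 2, num_shards = 10
+filename ==> 'ckpt-00002-of-00010'
+```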
+)doc");
+
+REGISTER_OP("ShardedFilespec")
+ .Input("basename: string")
+ .Input("num_shards: int32")
+ .Output("filename: string")
+ .Doc(R"doc(
+Generate a glob pattern matching all sharded file names.
+)doc");
+
+// Reader source ops ----------------------------------------------------------
+
+REGISTER_OP("WholeFileReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the entire contents of a file as a value.
+
+To use, enqueue filenames in a Queue. The output of ReaderRead will
+be a filename (key) and the contents of that file (value).
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TextLineReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("skip_header_lines: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the lines of a file delimited by '\n'.
+
+reader_handle: The handle to reference the Reader.
+skip_header_lines: Number of lines to skip from the beginning of every file.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("FixedLengthRecordReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("header_bytes: int = 0")
+ .Attr("record_bytes: int")
+ .Attr("footer_bytes: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs fixed-length records from a file.
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TFRecordReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the records from a TensorFlow Records file.
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("IdentityReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the queued work as both the key and value.
+
+To use, enqueue strings in a Queue. ReaderRead will take the front
+work string and output (work, work).
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+// Ops that operate on Readers ------------------------------------------------
+
+REGISTER_OP("ReaderRead")
+ .Input("reader_handle: Ref(string)")
+ .Input("queue_handle: Ref(string)")
+ .Output("key: string")
+ .Output("value: string")
+ .Doc(R"doc(
+Returns the next record (key, value pair) produced by a Reader.
+
+Will dequeue from the input queue if necessary (e.g. when the
+Reader needs to start reading from a new file since it has finished
+with the previous file).
+
+reader_handle: Handle to a Reader.
+queue_handle: Handle to a Queue, with string work items.
+key: A scalar.
+value: A scalar.
+)doc");
+
+REGISTER_OP("ReaderNumRecordsProduced")
+ .Input("reader_handle: Ref(string)")
+ .Output("records_produced: int64")
+ .Doc(R"doc(
+Returns the number of records this Reader has produced.
+
+This is the same as the number of ReaderRead executions that have
+succeeded.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderNumWorkUnitsCompleted")
+ .Input("reader_handle: Ref(string)")
+ .Output("units_completed: int64")
+ .Doc(R"doc(
+Returns the number of work units this Reader has finished processing.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderSerializeState")
+ .Input("reader_handle: Ref(string)")
+ .Output("state: string")
+ .Doc(R"doc(
+Produce a string tensor that encodes the state of a Reader.
+
+Not all Readers support being serialized, so this can produce an
+Unimplemented error.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderRestoreState")
+ .Input("reader_handle: Ref(string)")
+ .Input("state: string")
+ .Doc(R"doc(
+Restore a reader to a previously saved state.
+
+Not all Readers support being restored, so this can produce an
+Unimplemented error.
+
+reader_handle: Handle to a Reader.
+state: Result of a ReaderSerializeState of a Reader with type
+ matching reader_handle.
+)doc");
+
+REGISTER_OP("ReaderReset")
+ .Input("reader_handle: Ref(string)")
+ .Doc(R"doc(
+Restore a Reader to its initial clean state.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+// Other input Ops ----------------------------------------------------------
+
+REGISTER_OP("ReadFile")
+ .Input("filename: string")
+ .Output("contents: string")
+ .Doc(R"doc(
+Reads and outputs the entire contents of the input filename.
+)doc");
+
+REGISTER_OP("MatchingFiles")
+ .Input("pattern: string")
+ .Output("filenames: string")
+ .Doc(R"doc(
+Returns the set of files matching a pattern.
+
+Note that this routine only supports wildcard characters in the
+basename portion of the pattern, not in the directory portion.
+
+pattern: A (scalar) shell wildcard pattern.
+filenames: A vector of matching filenames.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
new file mode 100644
index 0000000000..a9b940295e
--- /dev/null
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -0,0 +1,97 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("MatrixDeterminant")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the determinant of a square matrix.
+
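+For example (illustrative values):
+
+```prettyprint
+# 'input' is [[1., 2.],
+#             [3., 4.]]
+matrix_determinant(input) ==> -2.
+```
+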
+input: A tensor of shape `[M, M]`.
+output: A scalar, equal to the determinant of the input.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchMatrixDeterminant")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the determinants for a batch of square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a 1-D tensor containing the determinants
+for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[...]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("MatrixInverse")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the inverse of a square invertible matrix. Checks for invertibility.
+
+input: Shape is `[M, M]`.
+output: Shape is `[M, M]` containing the matrix inverse of the input.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchMatrixInverse")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the inverse of square invertible matrices. Checks for invertibility.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of the same shape as the input
+containing the inverse for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("Cholesky")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {double, float}")
+ .Doc(R"doc(
+Calculates the Cholesky decomposition of a square matrix.
+
+The input has to be symmetric and positive definite. Only the lower-triangular
+part of the input will be used for this operation. The upper-triangular part
+will not be read.
+
+The result is the lower-triangular matrix of the Cholesky decomposition of the
+input.
+
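+For example (a minimal illustration; the input is symmetric positive definite):
+
+```prettyprint
+# 'input' is [[4., 2.],
+#             [2., 5.]]
+cholesky(input) ==> [[2., 0.],
+                     [1., 2.]]
+```
+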
+input: Shape is `[M, M]`.
+output: Shape is `[M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchCholesky")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {double, float}")
+ .Doc(R"doc(
+Calculates the Cholesky decomposition of a batch of square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices, with the same constraints as the single matrix Cholesky
+decomposition above. The output is a tensor of the same shape as the input
+containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
new file mode 100644
index 0000000000..28546fe645
--- /dev/null
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Assert")
+ .Input("condition: bool")
+ .Input("data: T")
+ .Attr("T: list(type)")
+ .Attr("summarize: int = 3")
+ .Doc(R"doc(
+Asserts that the given condition is true.
+
+If `condition` evaluates to false, print the list of tensors in `data`.
+`summarize` determines how many entries of the tensors to print.
+
+condition: The condition to evaluate.
+data: The tensors to print out when condition is false.
+summarize: Print this many entries of each tensor.
+)doc");
+
+REGISTER_OP("Print")
+ .Input("input: T")
+ .Input("data: U")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("U: list(type)")
+ .Attr("message: string = ''")
+ .Attr("first_n: int = -1")
+ .Attr("summarize: int = 3")
+ .Doc(R"doc(
+Prints a list of tensors.
+
+Passes `input` through to `output` and prints `data` when evaluating.
+
+input: The tensor passed to `output`.
+data: A list of tensors to print out when op is evaluated.
+output: The unmodified `input` tensor.
+message: A string, prefix of the error message.
+first_n: Only log `first_n` number of times. -1 disables logging.
+summarize: Only print this many entries of each tensor.
+)doc");
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
new file mode 100644
index 0000000000..20e56316ea
--- /dev/null
+++ b/tensorflow/core/ops/math_ops.cc
@@ -0,0 +1,1053 @@
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("AddN")
+ .Input("inputs: N * T")
+ .Output("sum: T")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetIsCommutative()
+ .SetIsAggregate()
+ .Doc(R"doc(
+Add all input tensors element-wise.
+
+inputs: Must all be the same size and shape.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BatchMatMul")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("out: T")
+ .Attr("T: {float, double, int32, complex64}")
+ .Attr("adj_x: bool = false")
+ .Attr("adj_y: bool = false")
+ .Doc(R"doc(
+Multiplies slices of two tensors in batches.
+
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
+
+The input tensors `x` and `y` are 3-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 3-D or higher with shape `[..., r_o, c_o]`, where:
+
+ r_o = c_x if adj_x else r_x
+ c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+ out[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+
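+For example (illustrative shapes):
+
+```prettyprint
+# 'x' has shape [2, 3, 4] and 'y' has shape [2, 4, 5]
+# with adj_x = false and adj_y = false:
+out.shape ==> [2, 3, 5]
+```
+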
+x: 3-D or higher with shape `[..., r_x, c_x]`.
+y: 3-D or higher with shape `[..., r_y, c_y]`.
+out: 3-D or higher with shape `[..., r_o, c_o]`
+adj_x: If `True`, adjoint the slices of `x`. Defaults to `False`.
+adj_y: If `True`, adjoint the slices of `y`. Defaults to `False`.
+)doc");
+
+// --------------------------------------------------------------------------
+// Casting Ops
+//
+// NOTE: Only a smaller number of types are supported by
+// Cast. The exact casting rule is TBD. The current
+// implementation uses C++ static cast rules for numeric
+// types, which may be changed in the future.
+REGISTER_OP("Cast")
+ .Input("x: SrcT")
+ .Output("y: DstT")
+ .Attr("SrcT: type")
+ .Attr("DstT: type")
+ .Doc(R"doc(
+Cast x of type SrcT to y of DstT.
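+
+For example (illustrative values; per the note above, the current
+implementation follows C++ static cast rules, so float-to-int conversion
+truncates toward zero):
+
+```prettyprint
+# 'x' is [1.8, 2.2], of type float
+cast(x, int32) ==> [1, 2]
+```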
+)doc");
+
+REGISTER_OP("_HostCast")
+ .Input("x: SrcT")
+ .Output("y: DstT")
+ .Attr("SrcT: type")
+ .Attr("DstT: type")
+ .Doc(R"doc(
+Cast x of type SrcT to y of DstT.
+
+_HostCast requires its input and produces its output in host memory.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Abs")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Computes the absolute value of a tensor.
+
+Given a tensor `x`, this operation returns a tensor containing the absolute
+value of each element in `x`. For example, if x is an input element and y is
+an output element, this operation computes \\(y = |x|\\).
+)doc");
+
+REGISTER_OP("ComplexAbs")
+ .Input("x: complex64")
+ .Output("y: float")
+ .Doc(R"doc(
+Computes the complex absolute value of a tensor.
+
+Given a tensor `x` of complex numbers, this operation returns a tensor of type
+`float` that is the absolute value of each element in `x`. All elements in `x`
+must be complex numbers of the form \\(a + bj\\). The absolute value is
+computed as \\( \sqrt{a^2 + b^2}\\).
+
+For example:
+
+```
+# tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]]
+tf.complex_abs(x) ==> [5.25594902, 6.60492229]
+```
+)doc");
+
+// Declares cwise unary operations signature: 't -> 't
+#define UNARY() \
+ Input("x: T").Output("y: T").Attr( \
+ "T: {float, double, int32, complex64, int64}")
+
+REGISTER_OP("Neg")
+ .UNARY()
+ .Doc(R"doc(
+Computes numerical negative value element-wise.
+I.e., \\(y = -x\\).
+)doc");
+
+REGISTER_OP("Inv")
+ .UNARY()
+ .Doc(R"doc(
+Computes the reciprocal of x element-wise.
+I.e., \\(y = 1 / x\\).
+)doc");
+
+REGISTER_OP("Square")
+ .UNARY()
+ .Doc(R"doc(
+Computes square of x element-wise.
+I.e., \\(y = x * x = x^2\\).
+)doc");
+
+REGISTER_OP("Sqrt")
+ .UNARY()
+ .Doc(R"doc(
+Computes square root of x element-wise.
+I.e., \\(y = \sqrt{x} = x^{1/2}\\).
+)doc");
+
+REGISTER_OP("Rsqrt")
+ .UNARY()
+ .Doc(R"doc(
+Computes reciprocal of square root of x element-wise.
+I.e., \\(y = 1 / \sqrt{x}\\).
+)doc");
+
+REGISTER_OP("Exp")
+ .UNARY()
+ .Doc(R"doc(
+Computes exponential of x element-wise. \\(y = e^x\\).
+)doc");
+
+REGISTER_OP("Log")
+ .UNARY()
+ .Doc(R"doc(
+Computes natural logarithm of x element-wise.
+I.e., \\(y = \log_e x\\).
+)doc");
+
+REGISTER_OP("Tanh")
+ .UNARY()
+ .Doc(R"doc(
+Computes hyperbolic tangent of `x` element-wise.
+)doc");
+
+REGISTER_OP("Sigmoid")
+ .UNARY()
+ .Doc(R"doc(
+Computes sigmoid of `x` element-wise.
+
+Specifically, `y = 1 / (1 + exp(-x))`.
+)doc");
+
+REGISTER_OP("Sin")
+ .UNARY()
+ .Doc(R"doc(
+Computes sin of x element-wise.
+)doc");
+
+REGISTER_OP("Cos")
+ .UNARY()
+ .Doc(R"doc(
+Computes cos of x element-wise.
+)doc");
+
+#undef UNARY
+
+REGISTER_OP("IsNan")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are NaN.
+)doc");
+
+REGISTER_OP("IsInf")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are Inf.
+)doc");
+
+REGISTER_OP("IsFinite")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are finite.
+)doc");
+
+REGISTER_OP("Sign")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns an element-wise indication of the sign of a number.
+
+y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0.
+)doc");
+
+REGISTER_OP("Floor")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns element-wise largest integer not greater than x.
+)doc");
+
+REGISTER_OP("Ceil")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns element-wise smallest integer not less than x.
+)doc");
+
+// Declares cwise binary operations signature: 't, 't -> 't.
+
+#define BINARY_MORE() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {float, double, int8, int16, int32, complex64, int64}")
+
+#define BINARY_FEWER() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {float, double, int32, complex64, int64}")
+
+REGISTER_OP("Add")
+ .BINARY_MORE()
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns x + y element-wise.
+
+*NOTE*: Add supports broadcasting. AddN does not.
+)doc");
+
+REGISTER_OP("Sub")
+ .BINARY_FEWER()
+ .Doc(R"doc(
+Returns x - y element-wise.
+)doc");
+
+REGISTER_OP("Mul")
+ .BINARY_MORE()
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns x * y element-wise.
+)doc");
+
+REGISTER_OP("Div")
+ .BINARY_FEWER()
+ .Doc(R"doc(
+Returns x / y element-wise.
+)doc");
+
+#undef BINARY_FEWER
+#undef BINARY_MORE
+
+REGISTER_OP("Maximum")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, int64}")
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
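+
+For example (illustrative values):
+
+```prettyprint
+# 'x' is [1, 4] and 'y' is [3, 2]
+maximum(x, y) ==> [3, 4]
+```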
+)doc");
+
+REGISTER_OP("Minimum")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, int64}")
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
+)doc");
+
+REGISTER_OP("Mod")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {int32, int64, float, double}")
+ .Doc(R"doc(
+Returns element-wise remainder of division.
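+
+For example (illustrative values with non-negative operands):
+
+```prettyprint
+# 'x' is [7, 8] and 'y' is [3, 5]
+mod(x, y) ==> [1, 3]
+```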
+)doc");
+
+REGISTER_OP("Pow")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, complex64, int64}")
+ .Doc(R"doc(
+Computes the power of one value to another.
+
+Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+corresponding elements in `x` and `y`. For example:
+
+```
+# tensor 'x' is [[2, 2], [3, 3]]
+# tensor 'y' is [[8, 16], [2, 3]]
+tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+```
+)doc");
+
+// --------------------------------------------------------------------------
+
+// Declares cwise binary comparison operations signature: 't, 't -> bool,
+// where 't has a natural total order.
+#define COMPARISON() \
+ Input("x: T").Input("y: T").Output("z: bool").Attr( \
+ "T: {float, double, int32, int64}")
+
+REGISTER_OP("Less")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x < y) element-wise.
+)doc");
+
+REGISTER_OP("LessEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x <= y) element-wise.
+)doc");
+
+REGISTER_OP("Greater")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x > y) element-wise.
+)doc");
+
+REGISTER_OP("GreaterEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x >= y) element-wise.
+)doc");
+
+#undef COMPARISON
+
+// --------------------------------------------------------------------------
+
+#define COMPARISON() \
+ Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
+ "T: {float, double, int32, int64, complex64, quint8, qint8, qint32}")
+
+REGISTER_OP("Equal")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x == y) element-wise.
+)doc");
+
+REGISTER_OP("NotEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x != y) element-wise.
+)doc");
+
+#undef COMPARISON
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LogicalNot")
+ .Input("x: bool")
+ .Output("y: bool")
+ .Doc(R"doc(
+Returns the truth value of NOT x element-wise.
+)doc");
+
+#define BINARY_LOGICAL() \
+ Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
+
+REGISTER_OP("LogicalAnd")
+ .BINARY_LOGICAL()
+ .Doc(R"doc(
+Returns the truth value of x AND y element-wise.
+)doc");
+
+REGISTER_OP("LogicalOr")
+ .BINARY_LOGICAL()
+ .Doc(R"doc(
+Returns the truth value of x OR y element-wise.
+)doc");
+
+#undef BINARY_LOGICAL
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Select")
+ .Input("condition: bool")
+ .Input("t: T")
+ .Input("e: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Selects elements from `t` or `e`, depending on `condition`.
+
+The `condition`, `t`, and `e` tensors must all have the same shape,
+and the output will also have that shape. The `condition` tensor acts
+as an element-wise mask that chooses, based on the value at each
+element, whether the corresponding element in the output should be
+taken from `t` (if true) or `e` (if false). For example:
+
+```prettyprint
+# 'condition' tensor is [[True, False]
+# [True, False]]
+# 't' is [[1, 1],
+# [1, 1]]
+# 'e' is [[2, 2],
+# [2, 2]]
+select(condition, t, e) ==> [[1, 2],
+ [1, 2]]
+```
+
+t: A `Tensor` with the same shape as `condition`.
+e: A `Tensor` with the same type and shape as `t`.
+out: A `Tensor` with the same type and shape as `t` and `e`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("MatMul")
+ .Input("a: T")
+ .Input("b: T")
+ .Output("product: T")
+ .Attr("transpose_a: bool = false")
+ .Attr("transpose_b: bool = false")
+ .Attr("T: {float, double, int32, complex64}")
+ .Doc(R"doc(
+Multiply the matrix "a" by the matrix "b".
+
+The inputs must be two-dimensional matrices and the inner dimension of
+"a" (after being transposed if transpose_a is true) must match the
+outer dimension of "b" (after being transposed if transpose_b is
+true).
+
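+For example (illustrative values, without transposition):
+
+```prettyprint
+# 'a' is [[1, 2],
+#         [3, 4]]
+# 'b' is [[5, 6],
+#         [7, 8]]
+matmul(a, b) ==> [[19, 22],
+                  [43, 50]]
+```
+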
+*Note*: The default kernel implementation for MatMul on GPUs uses
+cublas.
+
+transpose_a: If true, "a" is transposed before multiplication.
+transpose_b: If true, "b" is transposed before multiplication.
+)doc");
+
+REGISTER_OP("SparseMatMul")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("product: float")
+ .Attr("transpose_a: bool = false")
+ .Attr("transpose_b: bool = false")
+ .Attr("a_is_sparse: bool = false")
+ .Attr("b_is_sparse: bool = false")
+ .Doc(R"doc(
+Multiply matrix "a" by matrix "b".
+
+The inputs must be two-dimensional matrices and the inner dimension of "a" must
+match the outer dimension of "b". This op is optimized for the case where at
+least one of "a" or "b" is sparse. On one platform, the break-even point for
+using this instead of a dense matrix multiply was 30% zero values in the
+sparse matrix.
+)doc");
+
+// --------------------------------------------------------------------------
+
+// For operations where the output is a reduction function along some
+// dimensions of the input.
+REGISTER_OP("Sum")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the sum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
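+For example (illustrative values):
+
+```prettyprint
+# 'input' is [[1, 1, 1],
+#             [1, 1, 1]]
+sum(input, [1]) ==> [3, 3]
+sum(input, [1], keep_dims=true) ==> [[3], [3]]
+```
+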
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Mean")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the mean of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Prod")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the product of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Min")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the minimum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Max")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the maximum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("ArgMax")
+ .Input("input: T")
+ .Input("dimension: int32")
+ .Output("output: int64")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Returns the index with the largest value across dimensions of a tensor.
+
+dimension: int32, 0 <= dimension < rank(input). Describes which dimension
+ of the input Tensor to reduce across. For vectors, use dimension = 0.
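+
+For example (illustrative values):
+
+```prettyprint
+# 'input' is [[1, 3],
+#             [4, 2]]
+argmax(input, 0) ==> [1, 0]
+argmax(input, 1) ==> [1, 0]
+```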
+)doc");
+
+REGISTER_OP("ArgMin")
+ .Input("input: T")
+ .Input("dimension: int32")
+ .Output("output: int64")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Returns the index with the smallest value across dimensions of a tensor.
+
+dimension: int32, 0 <= dimension < rank(input). Describes which dimension
+ of the input Tensor to reduce across. For vectors, use dimension = 0.
+)doc");
+
+REGISTER_OP("SegmentSum")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the sum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \sum_j data_j\\) where the sum is over `j` such
+that `segment_ids[j] == i`.
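+
+For example (illustrative values, written with the `tf.segment_sum` wrapper
+used in the `SparseSegmentSum` example below):
+
+```prettyprint
+c = tf.constant([[1,2,3,4], [4,3,2,1], [5,6,7,8]])
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+  ==> [[5 5 5 5]
+       [5 6 7 8]]
+```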
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentSum.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMean")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the mean along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \frac{\sum_j data_j}{N}\\) where the mean is
+over `j` such that `segment_ids[j] == i` and `N` is the total number of
+values summed.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMean.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentProd")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the product along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \prod_j data_j\\) where the product is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentProd.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMin")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the minimum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \min_j(data_j)\\) where the minimum is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMin.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMax")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the maximum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \max_j(data_j)\\) where `max` is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMax.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("UnsortedSegmentSum")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the sum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \sum_j data_j\\) where the sum is over `j` such
+that `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`
+need not be sorted and need not cover all values in the full
+range of valid values.
+
+If the sum is empty for a given segment ID `i`, `output[i] = 0`.
+
+`num_segments` should equal the number of distinct segment IDs.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/UnsortedSegmentSum.png" alt>
+</div>
+
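+For example, with unsorted ids (a sketch; `tf.unsorted_segment_sum` is the
+assumed Python wrapper):
+
+```prettyprint
+c = tf.constant([[1,2], [3,4], [5,6]])
+tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), 2)
+  ==> [[6 8]
+       [3 4]]
+```
+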
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension.
+
+output: Has same shape as data, except for dimension_0 which
+has size `num_segments`.
+
+)doc");
+
+REGISTER_OP("SparseSegmentSum")
+ .Input("data: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
+dimension, selecting a subset of dimension_0, specified by `indices`.
+
+For example:
+
+```prettyprint
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+# Select two rows, one segment.
+tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+ ==> [[0 0 0 0]]
+
+# Select two rows, two segments.
+tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+ ==> [[ 1 2 3 4]
+ [-1 -2 -3 -4]]
+
+# Select all rows, two segments.
+tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+ ==> [[0 0 0 0]
+ [5 6 7 8]]
+
+# Which is equivalent to:
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+```
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SparseSegmentMean")
+ .Input("data: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes the mean along sparse segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+dimension, selecting a subset of dimension_0, specified by `indices`.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+
+)doc");
+
+REGISTER_OP("SparseSegmentMeanGrad")
+ .Input("grad: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Input("output_dim0: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes gradients for SparseSegmentMean.
+
+Returns tensor "output" with same shape as grad, except for dimension_0 whose
+value is output_dim0.
+
+grad: gradient propagated to the SparseSegmentMean op.
+indices: indices passed to the corresponding SparseSegmentMean op.
+segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+output_dim0: dimension_0 of "data" passed to SparseSegmentMean op.
+)doc");
+
+REGISTER_OP("All")
+ .Input("input: bool")
+ .Input("reduction_indices: int32")
+ .Output("output: bool")
+ .Attr("keep_dims: bool = false")
+ .Doc(R"doc(
+Computes the "logical and" of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
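+For example (a sketch; `tf.reduce_all` is the assumed Python wrapper):
+
+```prettyprint
+# 'x' is [[True,  True]
+#         [False, False]]
+tf.reduce_all(x) ==> False
+tf.reduce_all(x, 0) ==> [False, False]
+tf.reduce_all(x, 1) ==> [True, False]
+```
+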
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Any")
+ .Input("input: bool")
+ .Input("reduction_indices: int32")
+ .Attr("keep_dims: bool = false")
+ .Output("output: bool")
+ .Doc(R"doc(
+Computes the "logical or" of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
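+For example (a sketch; `tf.reduce_any` is the assumed Python wrapper):
+
+```prettyprint
+# 'x' is [[True,  True]
+#         [False, False]]
+tf.reduce_any(x) ==> True
+tf.reduce_any(x, 0) ==> [True, True]
+tf.reduce_any(x, 1) ==> [True, False]
+```
+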
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Range")
+ .Input("start: int32")
+ .Input("limit: int32")
+ .Input("delta: int32")
+ .Output("output: int32")
+ .Doc(R"doc(
+Creates a sequence of integers.
+
+This operation creates a sequence of integers that begins at `start` and
+extends by increments of `delta` up to but not including `limit`.
+
+For example:
+
+```
+# 'start' is 3
+# 'limit' is 18
+# 'delta' is 3
+tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+```
+
+start: 0-D (scalar). First entry in the sequence.
+limit: 0-D (scalar). Upper limit of sequence, exclusive.
+delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+output: 1-D.
+)doc");
+
+REGISTER_OP("LinSpace")
+ .Input("start: T")
+ .Input("stop: T")
+ .Input("num: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Generates values in an interval.
+
+A sequence of `num` evenly-spaced values is generated beginning at `start`.
+If `num > 1`, the values in the sequence increase by `(stop - start) / (num - 1)`,
+so that the last one is exactly `stop`.
+
+For example:
+
+```
+tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
+```
+
+start: First entry in the range.
+stop: Last entry in the range.
+num: Number of values to generate.
+output: 1-D. The generated values.
+)doc");
+
+REGISTER_OP("Complex")
+ .Input("real: float")
+ .Input("imag: float")
+ .Output("out: complex64")
+ .Doc(R"doc(
+Converts two real numbers to a complex number.
+
+Given a tensor `real` representing the real part of a complex number, and a
+tensor `imag` representing the imaginary part of a complex number, this
+operation returns complex numbers elementwise of the form \\(a + bj\\), where
+*a* represents the `real` part and *b* represents the `imag` part.
+
+The input tensors `real` and `imag` must have the same shape.
+
+For example:
+
+```
+# tensor 'real' is [2.25, 3.25]
+# tensor `imag` is [4.75, 5.75]
+tf.complex(real, imag) ==> [2.25 + 4.75j, 3.25 + 5.75j]
+```
+)doc");
+
+REGISTER_OP("Real")
+ .Input("in: complex64")
+ .Output("out: float")
+ .Doc(R"doc(
+Returns the real part of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of type
+`float` that is the real part of each element in `in`. All elements in `in`
+must be complex numbers of the form \\(a + bj\\), where *a* is the real part
+returned by this operation and *b* is the imaginary part.
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.real(in) ==> [-2.25, 3.25]
+```
+)doc");
+
+REGISTER_OP("Imag")
+ .Input("in: complex64")
+ .Output("out: float")
+ .Doc(R"doc(
+Returns the imaginary part of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of type
+`float` that is the imaginary part of each element in `in`. All elements in `in`
+must be complex numbers of the form \\(a + bj\\), where *a* is the real part
+and *b* is the imaginary part returned by this operation.
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.imag(in) ==> [4.75, 5.75]
+```
+)doc");
+
+REGISTER_OP("Conj")
+ .Input("in: complex64")
+ .Output("out: complex64")
+ .Doc(R"doc(
+Returns the complex conjugate of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of
+complex numbers that are the complex conjugate of each element in `in`. The
+complex numbers in `in` must be of the form \\(a + bj\\), where *a* is the real
+part and *b* is the imaginary part.
+
+The complex conjugate returned by this operation is of the form \\(a - bj\\).
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.conj(in) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+```
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
new file mode 100644
index 0000000000..03ba49d5cd
--- /dev/null
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -0,0 +1,543 @@
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/util/padding.h"
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("AvgPool")
+ .Input("value: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Performs average pooling on the input.
+
+Each entry in `output` is the mean of the corresponding size `ksize`
+window in `value`.
+
+value: 4-D with shape `[batch, height, width, channels]`.
+ksize: The size of the sliding window for each dimension of `value`.
+strides: The stride of the sliding window for each dimension of `value`.
+padding: The type of padding algorithm to use.
+output: The average pooled output tensor.
+)doc");
+
+REGISTER_OP("AvgPoolGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes gradients of the average pooling function.
+
+orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
+ the output of `avg_pool`.
+ksize: The size of the sliding window for each dimension of the input.
+strides: The stride of the sliding window for each dimension of the input.
+padding: The type of padding algorithm to use.
+output: 4-D. Gradients w.r.t. the input of `avg_pool`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BatchNormWithGlobalNormalization")
+ .Input("t: T")
+ .Input("m: T")
+ .Input("v: T")
+ .Input("beta: T")
+ .Input("gamma: T")
+ .Output("result: T")
+ .Attr("T: numbertype")
+ .Attr("variance_epsilon: float")
+ .Attr("scale_after_normalization: bool")
+ .Doc(R"doc(
+Batch normalization.
+
+t: A 4D input Tensor.
+m: A 1D mean Tensor with size matching the last dimension of t.
+ This is the first output from MovingMoments.
+v: A 1D variance Tensor with size matching the last dimension of t.
+ This is the second output from MovingMoments.
+beta: A 1D beta Tensor with size matching the last dimension of t.
+ An offset to be added to the normalized tensor.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+ If "scale_after_normalization" is true, this tensor will be multiplied
+ with the normalized tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulting tensor
+ needs to be multiplied by gamma.
+)doc");
+
+REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
+ .Input("t: T")
+ .Input("m: T")
+ .Input("v: T")
+ .Input("gamma: T")
+ .Input("backprop: T")
+ .Output("dx: T")
+ .Output("dm: T")
+ .Output("dv: T")
+ .Output("db: T")
+ .Output("dg: T")
+ .Attr("T: numbertype")
+ .Attr("variance_epsilon: float")
+ .Attr("scale_after_normalization: bool")
+ .Doc(R"doc(
+Gradients for batch normalization.
+
+t: A 4D input Tensor.
+m: A 1D mean Tensor with size matching the last dimension of t.
+ This is the first output from MovingMoments.
+v: A 1D variance Tensor with size matching the last dimension of t.
+ This is the second output from MovingMoments.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+ If "scale_after_normalization" is true, this Tensor will be multiplied
+ with the normalized Tensor.
+backprop: 4D backprop Tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulting tensor
+ needs to be multiplied by gamma.
+
+dx: 4D backprop tensor for input.
+dm: 1D backprop tensor for mean.
+dv: 1D backprop tensor for variance.
+db: 1D backprop tensor for beta.
+dg: 1D backprop tensor for gamma.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BiasAdd")
+ .Attr("T: numbertype")
+ .Input("value: T")
+ .Input("bias: T")
+ .Output("output: T")
+ .Doc(R"doc(
+Adds `bias` to `value`.
+
+This is a special case of `tf.add` where `bias` is restricted to be 1-D.
+Broadcasting is supported, so `value` may have any number of dimensions.
+
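+For example (illustrative values; `tf.nn.bias_add` is the assumed wrapper):
+
+```prettyprint
+# 'value' is [[1, 2], [3, 4]], 'bias' is [10, 20]
+tf.nn.bias_add(value, bias) ==> [[11, 22], [13, 24]]
+```
+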
+value: Any number of dimensions.
+bias: 1-D with size the last dimension of `value`.
+output: Broadcasted sum of `value` and `bias`.
+)doc");
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Conv2D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes a 2-D convolution given 4-D `input` and `filter` tensors.
+
+Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+and a filter / kernel tensor of shape
+`[filter_height, filter_width, in_channels, out_channels]`, this op
+performs the following:
+
+1. Flattens the filter to a 2-D matrix with shape
+ `[filter_height * filter_width * in_channels, output_channels]`.
+2. Extracts image patches from the input tensor to form a *virtual*
+ tensor of shape `[batch, out_height, out_width,
+ filter_height * filter_width * in_channels]`.
+3. For each patch, right-multiplies the filter matrix and the image patch
+ vector.
+
+In detail,
+
+ output[b, i, j, k] =
+ sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
+ filter[di, dj, q, k]
+
+Must have `strides[0] = strides[3] = 1`. For the most common case of the same
+horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
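+As a usage sketch (the shapes and the `tf.nn.conv2d` wrapper are assumptions):
+
+```prettyprint
+# input: [1, 5, 5, 3], filter: [3, 3, 3, 8]
+output = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
+# With 'SAME' padding and unit strides, output has shape [1, 5, 5, 8].
+```
+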
+strides: 1-D of length 4. The stride of the sliding window for each dimension
+ of `input`.
+padding: The type of padding algorithm to use.
+)doc");
+
+REGISTER_OP("Conv2DBackpropInput")
+ .Input("input_sizes: int32")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of convolution with respect to the input.
+
+input_sizes: An integer vector representing the shape of `input`,
+ where `input` is a 4-D `[batch, height, width, channels]` tensor.
+filter: 4-D with shape
+ `[filter_height, filter_width, in_channels, out_channels]`.
+out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
+ Gradients w.r.t. the output of the convolution.
+strides: The stride of the sliding window for each dimension of the input
+ of the convolution.
+padding: The type of padding algorithm to use.
+output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient
+ w.r.t. the input of the convolution.
+)doc");
+
+// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
+// more general string attribute ('kernel_impl'?) that can be used to
+// select among several possible implementations.
+REGISTER_OP("Conv2DBackpropFilter")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Output("output: T")
+ .Input("out_backprop: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of convolution with respect to the filter.
+
+input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+filter_sizes: An integer vector representing the tensor shape of `filter`,
+ where `filter` is a 4-D
+ `[filter_height, filter_width, in_channels, out_channels]` tensor.
+out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
+ Gradients w.r.t. the output of the convolution.
+strides: The stride of the sliding window for each dimension of the input
+ of the convolution.
+padding: The type of padding algorithm to use.
+output: 4-D with shape
+ `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
+ the `filter` input of the convolution.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("L2Loss")
+ .Input("t: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+L2 Loss.
+
+Computes half the L2 norm of a tensor without the `sqrt`:
+
+ output = sum(t ** 2) / 2
+
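+For example, a worked case (`tf.nn.l2_loss` is the assumed wrapper):
+
+```prettyprint
+tf.nn.l2_loss([1.0, 2.0, 3.0]) ==> 7.0  # (1 + 4 + 9) / 2
+```
+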
+t: Typically 2-D, but may have any dimensions.
+output: 0-D.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LRN")
+ .Input("input: float")
+ .Output("output: float")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Doc(R"doc(
+Local Response Normalization.
+
+The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
+dimension), and each vector is normalized independently. Within a given vector,
+each component is divided by the weighted, squared sum of inputs within
+`depth_radius`. In detail,
+
+ sqr_sum[a, b, c, d] =
+ sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
+ output = input / (bias + alpha * sqr_sum ** beta)
+
+For details, see [Krizhevsky et al., ImageNet classification with deep
+convolutional neural networks (NIPS 2012)]
+(http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
+
+input: 4-D.
+depth_radius: 0-D. Half-width of the 1-D normalization window.
+bias: An offset (usually positive to avoid dividing by 0).
+alpha: A scale factor, usually positive.
+beta: An exponent.
+)doc");
+
+REGISTER_OP("LRNGrad")
+ .Input("input_grads: float")
+ .Input("input_image: float")
+ .Input("output_image: float")
+ .Output("output: float")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Doc(R"doc(
+Gradients for Local Response Normalization.
+
+input_grads: 4-D with shape `[batch, height, width, channels]`.
+input_image: 4-D with shape `[batch, height, width, channels]`.
+output_image: 4-D with shape `[batch, height, width, channels]`.
+depth_radius: A depth radius.
+bias: An offset (usually > 0 to avoid dividing by 0).
+alpha: A scale factor, usually positive.
+beta: An exponent.
+output: The gradients for LRN.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("MaxPool")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Input("input: float")
+ .Output("output: float")
+ .Doc(R"doc(
+Performs max pooling on the input.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: 4-D input to pool over.
+output: The max pooled output tensor.
+)doc");
+
+REGISTER_OP("MaxPoolGrad")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Input("orig_input: float")
+ .Input("orig_output: float")
+ .Input("grad: float")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients w.r.t. the output of `max_pool`.
+output: Gradients w.r.t. the input to `max_pool`.
+)doc");
+
+REGISTER_OP("MaxPoolWithArgmax")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr("Targmax: {int32, int64} = DT_INT64")
+ .Attr(GetPaddingAttrString())
+ .Input("input: float")
+ .Output("output: float")
+ .Output("argmax: Targmax")
+ .Doc(R"doc(
+Performs max pooling on the input and outputs both max values and indices.
+
+The indices in `argmax` are flattened, so that a maximum value at position
+`[b, y, x, c]` becomes flattened index
+`((b * height + y) * width + x) * channels + c`.
+
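+For example, with `height = 5`, `width = 4`, and `channels = 3`, a maximum at
+position `[0, 1, 2, 0]` flattens to `((0 * 5 + 1) * 4 + 2) * 3 + 0 = 18`.
+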
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: 4-D with shape `[batch, height, width, channels]`. Input to pool over.
+output: The max pooled output tensor.
+argmax: 4-D. The flattened indices of the max values chosen for each output.
+)doc");
+
+REGISTER_OP("MaxPoolGradWithArgmax")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("Targmax: {int32, int64}")
+ .Input("input: float")
+ .Input("grad: float")
+ .Input("argmax: Targmax")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: The original input.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the
+ output of `max_pool`.
+argmax: The indices of the maximum values chosen for each output of `max_pool`.
+output: Gradients w.r.t. the input of `max_pool`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Relu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear: `max(features, 0)`.
+)doc");
+
+REGISTER_OP("ReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear gradients for a Relu operation.
+
+gradients: The backpropagated gradients to the corresponding Relu operation.
+features: The features passed as input to the corresponding Relu operation.
+backprops: The gradients: `gradients * (features > 0)`.
+)doc");
+
+REGISTER_OP("Relu6")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear 6: `min(max(features, 0), 6)`.
+)doc");
+
+REGISTER_OP("Relu6Grad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear 6 gradients for a Relu6 operation.
+
+gradients: The backpropagated gradients to the corresponding Relu6 operation.
+features: The features passed as input to the corresponding Relu6 operation.
+backprops: The gradients:
+ `gradients * (features > 0) * (features < 6)`.
+)doc");
+
+REGISTER_OP("Softplus")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes softplus: `log(exp(features) + 1)`.
+)doc");
+
+REGISTER_OP("SoftplusGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes softplus gradients for a softplus operation.
+
+gradients: The backpropagated gradients to the corresponding softplus operation.
+features: The features passed as input to the corresponding softplus operation.
+backprops: The gradients: `gradients / (1 + exp(-features))`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Softmax")
+ .Input("logits: T")
+ .Output("softmax: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes softmax activations.
+
+For each batch `i` and class `j` we have
+
+ softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
+
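+For example, one row worked by hand (values rounded; `tf.nn.softmax` is the
+assumed wrapper):
+
+```prettyprint
+tf.nn.softmax([[1.0, 2.0]]) ==> [[0.26894, 0.73106]]
+# exp(1) / (exp(1) + exp(2)) = 2.71828 / 10.10734 ~= 0.26894
+```
+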
+logits: 2-D with shape `[batch_size, num_classes]`.
+softmax: Same shape as `logits`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("SoftmaxCrossEntropyWithLogits")
+ .Input("features: T")
+ .Input("labels: T")
+ .Output("loss: T")
+ .Output("backprop: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes softmax cross entropy cost and gradients to backpropagate.
+
+Inputs are the logits, not probabilities.
+
+features: batch_size x num_classes matrix
+labels: batch_size x num_classes matrix
+ The caller must ensure that each batch of labels represents a valid
+ probability distribution.
+loss: Per example loss (batch_size vector).
+backprop: backpropagated gradients (batch_size x num_classes matrix).
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("InTopK")
+ .Attr("k: int")
+ .Input("predictions: float")
+ .Input("targets: int32")
+ .Output("precision: bool")
+ .Doc(R"doc(
+Says whether the targets are in the top K predictions.
+
+This outputs a batch_size bool array; an entry out[i] is true if the
+prediction for the target class is among the top k predictions among
+all predictions for example i. Note that the behavior of InTopK differs
+from the TopK op in its handling of ties; if multiple classes have the
+same prediction value and straddle the top-k boundary, all of those
+classes are considered to be in the top k.
+
+More formally, let
+
+ \\(predictions_i\\) be the predictions for all classes for example i,
+ \\(targets_i\\) be the target class for example i,
+ \\(out_i\\) be the output for example i,
+
+$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+
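+For example (a sketch; `tf.nn.in_top_k` is the assumed wrapper):
+
+```prettyprint
+# predictions is [[0.1, 0.8, 0.05, 0.05]], targets is [1]
+tf.nn.in_top_k(predictions, targets, k=1) ==> [True]
+```
+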
+predictions: A batch_size x classes tensor
+targets: A batch_size vector of class ids
+k: Number of top elements to look at for computing precision
+precision: Computed Precision at k as a bool Tensor
+
+)doc");
+
+REGISTER_OP("TopK")
+ .Attr("k: int >= 1")
+ .Input("input: T")
+ .Output("values: T")
+ .Output("indices: int32")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Returns the values and indices of the k largest elements for each row.
+
+\\(values_{i, j}\\) represents the j-th largest element in \\(input_i\\).
+
+\\(indices_{i, j}\\) gives the column index of the corresponding element,
+such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two
+elements are equal, the lower-index element appears first.
+
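+For example (a sketch; `tf.nn.top_k` is the assumed wrapper):
+
+```prettyprint
+values, indices = tf.nn.top_k([[1.0, 3.0, 2.0]], k=2)
+# values  ==> [[3.0, 2.0]]
+# indices ==> [[1, 2]]
+```
+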
+k: Number of top elements to look for within each row
+input: A batch_size x classes tensor
+values: A batch_size x k tensor with the k largest elements for each row,
+ sorted in descending order
+indices: A batch_size x k tensor with the index of each value within each row
+
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/no_op.cc b/tensorflow/core/ops/no_op.cc
new file mode 100644
index 0000000000..52778917cb
--- /dev/null
+++ b/tensorflow/core/ops/no_op.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("NoOp")
+ .Doc(R"doc(
+Does nothing. Only useful as a placeholder for control edges.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
new file mode 100644
index 0000000000..7fcaa3abf1
--- /dev/null
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -0,0 +1,104 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("DecodeRaw")
+ .Input("bytes: string")
+ .Output("output: out_type")
+ .Attr("out_type: {float,double,int32,uint8,int16,int8,int64}")
+ .Attr("little_endian: bool = true")
+ .Doc(R"doc(
+Reinterpret the bytes of a string as a vector of numbers.
+
+bytes: All the elements must have the same length.
+little_endian: Whether the input bytes are in little-endian order.
+ Ignored for out_types that are stored in a single byte like uint8.
+output: A Tensor with one more dimension than the input bytes. The
+ added dimension will have size equal to the length of the elements
+ of bytes divided by the number of bytes to represent out_type.
+)doc");
+
+REGISTER_OP("ParseExample")
+ .Input("serialized: string")
+ .Input("names: string")
+ .Input("sparse_keys: Nsparse * string")
+ .Input("dense_keys: Ndense * string")
+ .Input("dense_defaults: Tdense")
+ .Output("sparse_indices: Nsparse * int64")
+ .Output("sparse_values: sparse_types")
+ .Output("sparse_shapes: Nsparse * int64")
+ .Output("dense_values: Tdense")
+ .Attr("Nsparse: int >= 0") // Inferred from sparse_keys
+ .Attr("Ndense: int >= 0") // Inferred from dense_keys
+ .Attr("sparse_types: list({float,int64,string}) >= 0")
+ .Attr("Tdense: list({float,int64,string}) >= 0")
+ .Attr("dense_shapes: list(shape) >= 0")
+ .Doc(R"doc(
+Transforms a vector of brain.Example protos (as strings) into typed tensors.
+
+serialized: A vector containing a batch of binary serialized Example protos.
+names: A vector containing the names of the serialized protos.
+ May contain, for example, table key (descriptive) names for the
+ corresponding serialized protos. These are purely useful for debugging
+ purposes, and the presence of values here has no effect on the output.
+ May also be an empty vector if no names are available.
+ If non-empty, this vector must be the same length as "serialized".
+dense_keys: A list of Ndense string Tensors (scalars).
+ The keys expected in the Examples' features associated with dense values.
+dense_defaults: A list of Ndense Tensors (some may be empty).
+ dense_defaults[j] provides default values
+ when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+ provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+ The input type is inferred from dense_defaults[j], even when it's empty.
+ If dense_defaults[j] is not empty, its shape must match dense_shapes[j].
+dense_shapes: A list of Ndense shapes; the shapes of data in each Feature
+ given in dense_keys.
+ The number of elements in the Feature corresponding to dense_key[j]
+ must always equal dense_shapes[j].NumEntries().
+  If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output
+ Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):
+ The dense outputs are just the inputs row-stacked by batch.
+sparse_keys: A list of Nsparse string Tensors (scalars).
+ The keys expected in the Examples' features associated with sparse values.
+sparse_types: A list of Nsparse types; the data types of data in each Feature
+ given in sparse_keys.
+ Currently the ParseExample supports DT_FLOAT (FloatList),
+ DT_INT64 (Int64List), and DT_STRING (BytesList).
+)doc");
+
+REGISTER_OP("DecodeCSV")
+ .Input("records: string")
+ .Input("record_defaults: OUT_TYPE")
+ .Output("output: OUT_TYPE")
+ .Attr("OUT_TYPE: list({float,int32,int64,string})")
+ .Attr("field_delim: string = ','")
+ .Doc(R"doc(
+Convert CSV records to tensors. Each column maps to one tensor.
+
+RFC 4180 format is expected for the CSV records.
+(https://tools.ietf.org/html/rfc4180)
+Note that we allow leading and trailing spaces in int and float fields.
+
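+For example (illustrative records; `tf.decode_csv` is the assumed wrapper):
+
+```prettyprint
+# records is ["1,2.5,a", "3,4.5,b"]
+col1, col2, col3 = tf.decode_csv(records, record_defaults=[[0], [0.0], [""]])
+# col1 ==> [1, 3], col2 ==> [2.5, 4.5], col3 ==> ["a", "b"]
+```
+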
+records: Each string is a record/row in the csv and all records should have
+ the same format.
+record_defaults: One tensor per column of the input record, with either a
+ scalar default value for that column or empty if the column is required.
+field_delim: delimiter to separate fields in a record.
+output: Each tensor will have the same shape as records.
+)doc");
+
+REGISTER_OP("StringToNumber")
+ .Input("string_tensor: string")
+ .Output("output: out_type")
+ .Attr("out_type: {float, int32} = DT_FLOAT")
+ .Doc(R"doc(
+Converts each string in the input Tensor to the specified numeric type.
+
+(Note that int32 overflow results in an error while float overflow
+results in a rounded value.)
+
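+For example (a sketch; `tf.string_to_number` is the assumed wrapper):
+
+```prettyprint
+tf.string_to_number(["1.5", "2", "-3.25"]) ==> [1.5, 2.0, -3.25]
+```
+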
+out_type: The numeric type to interpret each string in string_tensor as.
+output: A Tensor of the same shape as the input string_tensor.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
new file mode 100644
index 0000000000..4be4354b85
--- /dev/null
+++ b/tensorflow/core/ops/random_ops.cc
@@ -0,0 +1,108 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("RandomUniform")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a uniform distribution.
+
+The generated values follow a uniform distribution in the range `[0, 1)`. The
+lower bound 0 is included in the range, while the upper bound 1 is excluded.
+
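+For example (a sketch; `tf.random_uniform` is the assumed wrapper):
+
+```prettyprint
+samples = tf.random_uniform([2, 3], seed=42)
+# A 2x3 float tensor; every entry lies in [0, 1).
+```
+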
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with uniform random values.
+)doc");
+
+REGISTER_OP("RandomStandardNormal")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a normal distribution.
+
+The generated values will have mean 0 and standard deviation 1.
+
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with random normal values.
+)doc");
+
+REGISTER_OP("TruncatedNormal")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a truncated normal distribution.
+
+The generated values follow a normal distribution with mean 0 and standard
+deviation 1, except that values whose magnitude is more than 2 standard
+deviations from the mean are dropped and re-picked.
+
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with random truncated normal
+ values.
+)doc");
+
+REGISTER_OP("RandomShuffle")
+ .Input("value: T")
+ .SetIsStateful()
+ .Output("output: T")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("T: type")
+ .Doc(R"doc(
+Randomly shuffles a tensor along its first dimension.
+
+The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
+to one and only one `output[i]`. For example, a mapping that might occur for a
+3x2 tensor is:
+
+```prettyprint
+[[1, 2], [[5, 6],
+ [3, 4], ==> [1, 2],
+ [5, 6]] [3, 4]]
+```
+
+value: The tensor to be shuffled.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of same shape and type as `value`, shuffled along its first
+ dimension.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/sendrecv_ops.cc b/tensorflow/core/ops/sendrecv_ops.cc
new file mode 100644
index 0000000000..51158263c1
--- /dev/null
+++ b/tensorflow/core/ops/sendrecv_ops.cc
@@ -0,0 +1,99 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_Send")
+ .Input("tensor: T")
+ .Attr("T: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Sends the named tensor from send_device to recv_device.
+
+tensor: The tensor to send.
+tensor_name: The name of the tensor to send.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_Recv")
+ .Output("tensor: tensor_type")
+ .Attr("tensor_type: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Receives the named tensor from send_device on recv_device.
+
+tensor: The tensor to receive.
+tensor_name: The name of the tensor to receive.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_HostSend")
+ .Input("tensor: T")
+ .Attr("T: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Sends the named tensor from send_device to recv_device.
+
+_HostSend requires its input on host memory whereas _Send requires its
+input on device memory.
+
+tensor: The tensor to send.
+tensor_name: The name of the tensor to send.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_HostRecv")
+ .Output("tensor: tensor_type")
+ .Attr("tensor_type: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Receives the named tensor from send_device on recv_device.
+
+_HostRecv produces its output on host memory whereas _Recv produces its
+output on device memory.
+
+tensor: The tensor to receive.
+tensor_name: The name of the tensor to receive.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
new file mode 100644
index 0000000000..51262373d5
--- /dev/null
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -0,0 +1,134 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("SparseToDense")
+ .Input("sparse_indices: Tindices")
+ .Input("output_shape: Tindices")
+ .Input("sparse_values: T")
+ .Input("default_value: T")
+ .Output("dense: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Doc(R"doc(
+Converts a sparse representation into a dense tensor.
+
+Builds an array `dense` with shape `output_shape` such that
+
+```prettyprint
+# If sparse_indices is scalar
+dense[i] = (i == sparse_indices ? sparse_values : default_value)
+
+# If sparse_indices is a vector, then for each i
+dense[sparse_indices[i]] = sparse_values[i]
+
+# If sparse_indices is an n by d matrix, then for each i in [0, n)
+dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
+```
+
+All other values in `dense` are set to `default_value`. If `sparse_values` is a
+scalar, all sparse indices are set to this single value.
+
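+For example (illustrative values; `tf.sparse_to_dense` is the assumed wrapper):
+
+```prettyprint
+tf.sparse_to_dense([[0, 0], [1, 2]], [2, 3], [1, 2], default_value=0)
+  ==> [[1 0 0]
+       [0 0 2]]
+```
+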
+sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
+ index where `sparse_values[i]` will be placed.
+output_shape: 1-D. Shape of the dense output tensor.
+sparse_values: 1-D. Values corresponding to each row of `sparse_indices`,
+ or a scalar value to be used for all sparse indices.
+default_value: Scalar value to set for indices not specified in
+ `sparse_indices`.
+dense: Dense output tensor of shape `output_shape`.
+)doc");
+
+REGISTER_OP("SparseConcat")
+ .Input("indices: N * int64")
+ .Input("values: N * T")
+ .Input("shapes: N * int64")
+ .Output("output_indices: int64")
+ .Output("output_values: T")
+ .Output("output_shape: int64")
+ .Attr("concat_dim: int >= 0")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Doc(R"doc(
+Concatenates a list of `SparseTensor` along the specified dimension.
+
+Concatenation is with respect to the dense versions of these sparse tensors.
+It is assumed that each input is a `SparseTensor` whose elements are ordered
+along increasing dimension number.
+
+All inputs' shapes must match, except for the concat dimension. The
+`indices`, `values`, and `shapes` lists must have the same length.
+
+The output shape is identical to the inputs', except along the concat
+dimension, where it is the sum of the inputs' sizes along that dimension.
+
+The output elements will be resorted to preserve the sort order along
+increasing dimension number.
+
+This op runs in `O(M log M)` time, where `M` is the total number of non-empty
+values across all inputs. This is due to the need for an internal sort in
+order to concatenate efficiently across an arbitrary dimension.
+
+For example, if `concat_dim = 1` and the inputs are
+
+ sp_inputs[0]: shape = [2, 3]
+ [0, 2]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ sp_inputs[1]: shape = [2, 4]
+ [0, 1]: "d"
+ [0, 2]: "e"
+
+then the output will be
+
+ shape = [2, 7]
+ [0, 2]: "a"
+ [0, 4]: "d"
+ [0, 5]: "e"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+Graphically this is equivalent to doing
+
+ [ a] concat [ d e ] = [ a d e ]
+ [b c ] [ ] [b c ]
+
+indices: 2-D. Indices of each input `SparseTensor`.
+values: 1-D. Non-empty values of each `SparseTensor`.
+shapes: 1-D. Shapes of each `SparseTensor`.
+output_indices: 2-D. Indices of the concatenated `SparseTensor`.
+output_values: 1-D. Non-empty values of the concatenated `SparseTensor`.
+output_shape: 1-D. Shape of the concatenated `SparseTensor`.
+concat_dim: Dimension to concatenate along.
+)doc");
+
+REGISTER_OP("SparseReorder")
+ .Input("input_indices: int64")
+ .Input("input_values: T")
+ .Input("input_shape: int64")
+ .Output("output_indices: int64")
+ .Output("output_values: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Reorders a SparseTensor into the canonical, row-major ordering.
+
+Note that by convention, all sparse ops preserve the canonical ordering along
+increasing dimension number. The only time ordering can be violated is during
+manual manipulation of the indices and values vectors to add entries.
+
+Reordering does not affect the shape of the SparseTensor.
+
+If the tensor has rank `R` and `N` non-empty values, `input_indices` has
+shape `[N, R]`, `input_values` has length `N`, and `input_shape` has length `R`.
+
+input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+ SparseTensor, possibly not in canonical ordering.
+input_values: 1-D. `N` non-empty values corresponding to `input_indices`.
+input_shape: 1-D. Shape of the input SparseTensor.
+output_indices: 2-D. `N x R` matrix with the same indices as input_indices, but
+ in canonical row-major ordering.
+output_values: 1-D. `N` non-empty values corresponding to `output_indices`.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
new file mode 100644
index 0000000000..da9fd4ad08
--- /dev/null
+++ b/tensorflow/core/ops/state_ops.cc
@@ -0,0 +1,290 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Variable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Holds state in the form of a tensor that persists across steps.
+
+Outputs a ref to the tensor state so it may be read or modified.
+TODO(zhifengc/mrry): Add a pointer to a more detailed document
+about sharing states in tensorflow.
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+container: If non-empty, this variable is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this variable is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TemporaryVariable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("var_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Returns a tensor that may be mutated, but only persists within a single step.
+
+This is an experimental op for internal use only and it is possible to use this
+op in unsafe ways. DO NOT USE unless you fully understand the risks.
+
+It is the caller's responsibility to ensure that 'ref' is eventually passed to a
+matching 'DestroyTemporaryVariable' op after all other uses have completed.
+
+Outputs a ref to the tensor state so it may be read or modified.
+
+For example:
+
+```prettyprint
+var = state_ops._temporary_variable([1, 2], types.float_)
+var_name = var.op.name
+var = state_ops.assign(var, [[4.0, 5.0]])
+var = state_ops.assign_add(var, [[6.0, 7.0]])
+final = state_ops._destroy_temporary_variable(var, var_name=var_name)
+```
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+var_name: Overrides the name used for the temporary variable resource. Default
+  value is the name of the 'TemporaryVariable' op (which is guaranteed unique).
+)doc");
+
+REGISTER_OP("DestroyTemporaryVariable")
+ .Input("ref: Ref(T)")
+ .Output("value: T")
+ .Attr("T: type")
+ .Attr("var_name: string")
+ .Doc(R"doc(
+Destroys the temporary variable and returns its final value.
+
+Sets output to the value of the Tensor pointed to by 'ref', then destroys
+the temporary variable called 'var_name'.
+All other uses of 'ref' *must* have executed before this op.
+This is typically achieved by chaining the ref through each assign op, or by
+using control dependencies.
+
+Outputs the final value of the tensor pointed to by 'ref'.
+
+ref: A reference to the temporary variable tensor.
+var_name: Name of the temporary variable, usually the name of the matching
+  'TemporaryVariable' op.
+)doc");
+
+REGISTER_OP("Assign")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("validate_shape: bool = true")
+ .Attr("use_locking: bool = true")
+ .SetAllowsUninitializedInput()
+ .Doc(R"doc(
+Update 'ref' by assigning 'value' to it.
+
+This operation outputs "ref" after the assignment is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node. May be uninitialized.
+value: The value to be assigned to the variable.
+validate_shape: If true, the operation will validate that the shape
+ of 'value' matches the shape of the Tensor being assigned to. If false,
+ 'ref' will take on the shape of 'value'.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been reset.
+)doc");
+
+REGISTER_OP("AssignAdd")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by adding 'value' to it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be added to the variable.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("AssignSub")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by subtracting 'value' from it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be subtracted from the variable.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("ScatterUpdate")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .Doc(R"doc(
+Applies sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+If `indices` contains duplicate entries, lexicographically later entries
+override earlier entries.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterUpdate.png" alt>
+</div>
+
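+For example, with duplicate indices the later entry wins (a sketch;
+`tf.scatter_update` is the assumed wrapper):
+
+```prettyprint
+# 'ref' is a variable holding [1, 2, 3]
+tf.scatter_update(ref, [0, 0], [10, 20])
+# 'ref' now holds [20, 2, 3]
+```
+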
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to store in `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterAdd")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Adds sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] += updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] += updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterAdd.png" alt>
+</div>
+
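+For example, duplicate indices accumulate (a sketch; `tf.scatter_add` is the
+assumed wrapper):
+
+```prettyprint
+# 'ref' is a variable holding [1, 2, 3, 4]
+tf.scatter_add(ref, [0, 0, 2], [10, 20, 30])
+# 'ref' now holds [31, 2, 33, 4]
+```
+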
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to add to `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterSub")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Subtracts sparse updates to a variable reference.
+
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their (negated) contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterSub.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to subtract from `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("CountUpTo")
+ .Input("ref: Ref(T)")
+ .Output("output: T")
+ .Attr("limit: int")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Increments 'ref' until it reaches 'limit'.
+
+This operation outputs "ref" after the update is done. This makes it
+easier to chain operations that need to use the updated value.
+
+ref: Should be from a scalar `Variable` node.
+limit: If incrementing ref would bring it above limit, instead generates an
+ 'OutOfRange' error.
+output: A copy of the input before increment. If nothing else modifies the
+ input, the values produced will all be distinct.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
new file mode 100644
index 0000000000..57b471074c
--- /dev/null
+++ b/tensorflow/core/ops/string_ops.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("StringToHashBucket")
+ .Input("string_tensor: string")
+ .Output("output: int64")
+ .Attr("num_buckets: int >= 1")
+ .Doc(R"doc(
+Converts each string in the input Tensor to its hash modulo the number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process.
+
+Note that the hash function may change from time to time.
+
+num_buckets: The number of buckets.
+output: A Tensor of the same shape as the input string_tensor.
+)doc");
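+// Illustrative sketch of the contract (the hash function itself is
+// unspecified and may change): each output element lies in [0, num_buckets),
+// and within one process equal strings always map to the same bucket, e.g.
+// with num_buckets = 10 every string hashes to one of 0..9.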
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
new file mode 100644
index 0000000000..5f46c871b6
--- /dev/null
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
+// inputs or outputs in various ways.
+
+REGISTER_OP("ScalarSummary")
+ .Input("tags: string")
+ .Input("values: T")
+ .Output("summary: string")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with scalar values.
+
+The input `tags` and `values` must have the same shape. The generated summary
+has a summary value for each tag-value pair in `tags` and `values`.
+
+tags: 1-D. Tags for the summary.
+values: 1-D, same size as `tags`. Values for the summary.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
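+// Illustrative sketch (hypothetical values): tags = ["loss", "accuracy"]
+// and values = [0.35, 0.91] produce a single serialized Summary proto
+// holding the two tag-value pairs.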
+
+REGISTER_OP("HistogramSummary")
+ .Input("tag: string")
+ .Input("values: float")
+ .Output("summary: string")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with a histogram.
+
+The generated
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+has one summary value containing a histogram for `values`.
+
+This op reports an `OutOfRange` error if any value is not finite.
+
+tag: Scalar. Tag to use for the `Summary.Value`.
+values: Any shape. Values to use to build the histogram.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
+
+REGISTER_OP("ImageSummary")
+ .Input("tag: string")
+ .Input("tensor: float")
+ .Output("summary: string")
+ .Attr("max_images: int >= 1 = 3")
+ .Attr(
+ "bad_color: tensor = { dtype: DT_UINT8 "
+ "tensor_shape: { dim { size: 4 } } "
+ "int_val: 255 int_val: 0 int_val: 0 int_val: 255 }")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with images.
+
+The summary has up to `max_images` summary values containing images. The
+images are built from `tensor` which must be 4-D with shape `[batch_size,
+height, width, channels]` and where `channels` can be:
+
+* 1: `tensor` is interpreted as Grayscale.
+* 3: `tensor` is interpreted as RGB.
+* 4: `tensor` is interpreted as RGBA.
+
+The images have the same number of channels as the input tensor. Their values
+are normalized, one image at a time, to fit in the range `[0, 255]`. The
+op uses two different normalization algorithms:
+
+* If the input values are all positive, they are rescaled so the largest one
+ is 255.
+
+* If any input value is negative, the values are shifted so input value 0.0
+ is at 127. They are then rescaled so that either the smallest value is 0,
+ or the largest one is 255.
+
+The `tag` argument is a scalar `Tensor` of type `string`. It is used to
+build the `tag` of the summary values:
+
+* If `max_images` is 1, the summary value tag is '*tag*/image'.
+* If `max_images` is greater than 1, the summary value tags are
+ generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
+
+The `bad_color` argument is the color to use in the generated images for
+non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
+Each element must be in the range `[0, 255]` (it represents the value of a
+pixel in the output image). Non-finite values in the input tensor are
+replaced by this tensor in the output image. The default value is the color
+red.
+
+tag: Scalar. Used to build the `tag` attribute of the summary values.
+tensor: 4-D of shape `[batch_size, height, width, channels]` where
+ `channels` is 1, 3, or 4.
+max_images: Max number of batch elements to generate images for.
+bad_color: Color to use for pixels with non-finite values.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
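+// Illustrative sketch (hypothetical values) of the two normalization modes
+// described above:
+// * an all-positive input such as [0.5, 1.0, 2.0] is rescaled so its
+//   maximum maps to 255, giving roughly [64, 128, 255];
+// * a mixed-sign input such as [-1.0, 0.0, 1.0] is shifted so 0.0 sits at
+//   127 and rescaled so one endpoint hits 0 or 255, giving [0, 127, 254].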
+
+REGISTER_OP("MergeSummary")
+ .Input("inputs: N * string")
+ .Output("summary: string")
+ .Attr("N : int >= 1")
+ .Doc(R"doc(
+Merges summaries.
+
+This op creates a
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+protocol buffer that contains the union of all the values in the input
+summaries.
+
+When the Op is run, it reports an `InvalidArgument` error if multiple values
+in the summaries to merge use the same tag.
+
+inputs: Can be of any shape. Each must contain serialized `Summary` protocol
+ buffers.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
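+// Illustrative sketch (hypothetical values): merging a summary holding the
+// tag "loss" with one holding the tag "accuracy" yields a single Summary
+// proto holding both; merging two summaries that both hold "loss" fails
+// with InvalidArgument.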
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
new file mode 100644
index 0000000000..e7b4e92fd5
--- /dev/null
+++ b/tensorflow/core/ops/training_ops.cc
@@ -0,0 +1,199 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ApplyGradientDescent")
+ .Input("var: Ref(T)")
+ .Input("alpha: T")
+ .Input("delta: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' by subtracting 'alpha' * 'delta' from it.
+
+var: Should be from a Variable().
+alpha: Scaling factor. Must be a scalar.
+delta: The change.
+out: Same as "var".
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
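+// A minimal scalar sketch of the update above (illustrative only; the
+// kernel operates on whole tensors):
+//
+//   float GradientDescentStep(float var, float alpha, float delta) {
+//     return var - alpha * delta;  // var -= alpha * delta
+//   }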
+
+REGISTER_OP("ApplyAdagrad")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the adagrad scheme.
+
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
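+// A minimal scalar sketch of the adagrad update above (illustrative only;
+// the kernel operates on whole tensors; std::sqrt is from <cmath>):
+//
+//   float AdagradStep(float var, float* accum, float lr, float grad) {
+//     *accum += grad * grad;                        // accum += grad * grad
+//     return var - lr * grad / std::sqrt(*accum);   // var -= lr * grad / sqrt(accum)
+//   }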
+
+REGISTER_OP("SparseApplyAdagrad")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update relevant entries in '*var' and '*accum' according to the adagrad scheme.
+
+That is, for the rows for which we have grad, we update var and accum as follows:
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Learning rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var and accum.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
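+// Illustrative sketch (hypothetical values): with indices = [0, 2], only
+// rows 0 and 2 of var and accum are updated, e.g.
+//   accum[0] += grad[0] * grad[0]
+//   var[0]   -= lr * grad[0] / sqrt(accum[0])
+// and likewise for row 2; row 1 is left untouched.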
+
+REGISTER_OP("ApplyMomentum")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("momentum: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the momentum scheme.
+
+accum = accum * momentum + grad
+var -= lr * accum
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+grad: The gradient.
+momentum: Momentum. Must be a scalar.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
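+// A minimal scalar sketch of the momentum update above (illustrative only;
+// the kernel operates on whole tensors):
+//
+//   float MomentumStep(float var, float* accum, float lr, float grad,
+//                      float momentum) {
+//     *accum = *accum * momentum + grad;   // accum = accum * momentum + grad
+//     return var - lr * *accum;            // var -= lr * accum
+//   }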
+
+REGISTER_OP("SparseApplyMomentum")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Input("momentum: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+
+That is, for the rows for which we have grad, we update var and accum as follows:
+
+accum = accum * momentum + grad
+var -= lr * accum
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Learning rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var and accum.
+momentum: Momentum. Must be a scalar.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ApplyAdam")
+ .Input("var: Ref(T)")
+ .Input("m: Ref(T)")
+ .Input("v: Ref(T)")
+ .Input("beta1_power: T")
+ .Input("beta2_power: T")
+ .Input("lr: T")
+ .Input("beta1: T")
+ .Input("beta2: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the Adam algorithm.
+
+lr_t <- lr * sqrt(1 - beta2^t) / (1 - beta1^t)
+m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
+v_t <- beta2 * v_{t-1} + (1 - beta2) * grad * grad
+var <- var - lr_t * m_t / (sqrt(v_t) + epsilon)
+
+var: Should be from a Variable().
+m: Should be from a Variable().
+v: Should be from a Variable().
+beta1_power: Must be a scalar.
+beta2_power: Must be a scalar.
+lr: Scaling factor. Must be a scalar.
+beta1: Momentum factor. Must be a scalar.
+beta2: Momentum factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var, m, and v tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
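+// A minimal scalar sketch of the Adam update above (illustrative only; the
+// kernel operates on whole tensors, and beta1_power / beta2_power carry
+// beta1^t / beta2^t; std::sqrt is from <cmath>):
+//
+//   float AdamStep(float var, float* m, float* v, float grad, float lr,
+//                  float beta1, float beta2, float beta1_power,
+//                  float beta2_power, float epsilon) {
+//     float lr_t = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
+//     *m = beta1 * *m + (1 - beta1) * grad;
+//     *v = beta2 * *v + (1 - beta2) * grad * grad;
+//     return var - lr_t * *m / (std::sqrt(*v) + epsilon);
+//   }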
+
+REGISTER_OP("ApplyRMSProp")
+ .Input("var: Ref(T)")
+ .Input("ms: Ref(T)")
+ .Input("mom: Ref(T)")
+ .Input("lr: T")
+ .Input("rho: T")
+ .Input("momentum: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the RMSProp algorithm.
+
+The generic RMSProp update keeps a moving average of the squared gradient:
+
+mean_square = decay * mean_square + (1-decay) * gradient ** 2
+delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+
+In terms of this op's inputs (with momentum folded in), the update is:
+
+ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+var <- var - mom
+
+var: Should be from a Variable().
+ms: Should be from a Variable().
+mom: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+rho: Decay rate. Must be a scalar.
+momentum: Momentum. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var, ms, and mom tensors will be protected
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
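+// A minimal scalar sketch of the RMSProp update above (illustrative only;
+// the kernel operates on whole tensors; std::sqrt is from <cmath>):
+//
+//   float RMSPropStep(float var, float* ms, float* mom, float lr, float rho,
+//                     float momentum, float epsilon, float grad) {
+//     *ms = rho * *ms + (1 - rho) * grad * grad;
+//     *mom = momentum * *mom + lr * grad / std::sqrt(*ms + epsilon);
+//     return var - *mom;
+//   }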
+
+} // namespace tensorflow