aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/nn_ops.cc70
-rw-r--r--tensorflow/core/ops/nn_ops_test.cc3
-rw-r--r--tensorflow/core/ops/ops.pbtxt54
3 files changed, 126 insertions, 1 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 3a25fd15da..1018742521 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1779,6 +1779,33 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
`gradients` otherwise.
)doc");
+REGISTER_OP("Selu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: {half, float, double}")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
+if < 0, `scale * features` otherwise.
+
+See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+)doc");
+
+REGISTER_OP("SeluGrad")
+ .Input("gradients: T")
+ .Input("outputs: T")
+ .Output("backprops: T")
+ .Attr("T: {half, float, double}")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
+ .Doc(R"doc(
+Computes gradients for the scaled exponential linear (Selu) operation.
+
+gradients: The backpropagated gradients to the corresponding Selu operation.
+outputs: The outputs of the corresponding Selu operation.
+backprops: The gradients: `gradients * (outputs + scale * alpha)`
+if outputs < 0, `scale * gradients` otherwise.
+)doc");
+
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
@@ -1979,6 +2006,49 @@ precision: Computed Precision at `k` as a `bool Tensor`.
)doc");
+// This is the same as `InTopK`, but takes `k` as in input rather than an attr.
+REGISTER_OP("InTopKV2")
+ .Input("predictions: float")
+ .Input("targets: T")
+ .Input("k: T")
+ .Output("precision: bool")
+ .Attr("T: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle predictions;
+ ShapeHandle targets;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
+ DimensionHandle batch_size;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
+ c->set_output(0, c->Vector(batch_size));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Says whether the targets are in the top `K` predictions.
+
+This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
+prediction for the target class is among the top `k` predictions among
+all predictions for example `i`. Note that the behavior of `InTopK` differs
+from the `TopK` op in its handling of ties; if multiple classes have the
+same prediction value and straddle the top-`k` boundary, all of those
+classes are considered to be in the top `k`.
+
+More formally, let
+
+ \\(predictions_i\\) be the predictions for all classes for example `i`,
+ \\(targets_i\\) be the target class for example `i`,
+ \\(out_i\\) be the output for example `i`,
+
+$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+
+predictions: A `batch_size` x `classes` tensor.
+targets: A `batch_size` vector of class ids.
+k: Number of top elements to look at for computing precision.
+precision: Computed precision at `k` as a `bool Tensor`.
+
+)doc");
+
namespace {
Status TopKShapeFn(InferenceContext* c) {
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index a60b1c3788..51e4f8bffe 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -412,7 +412,8 @@ TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {
TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
for (const char* op_name :
- {"ReluGrad", "Relu6Grad", "EluGrad", "SoftplusGrad", "SoftsignGrad"}) {
+ {"ReluGrad", "Relu6Grad", "EluGrad", "SeluGrad", "SoftplusGrad",
+ "SoftsignGrad"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "in0|in1");
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 468434bd28..2839575ec7 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -23384,6 +23384,60 @@ op {
description: "Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in\n`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`.\n\n```python\n# a is a tensor.\n# e is a tensor of eigenvalues.\n# v is a tensor of eigenvectors.\ne, v = self_adjoint_eig(a)\ne = self_adjoint_eig(a, compute_v=False)\n```"
}
op {
+ name: "Selu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` if < 0, `scale * features` otherwise."
+ description: "See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)"
+}
+op {
+ name: "SeluGrad"
+ input_arg {
+ name: "gradients"
+ description: "The backpropagated gradients to the corresponding Selu operation."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "outputs"
+ description: "The outputs of the corresponding Selu operation."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ description: "The gradients: `gradients * (outputs + scale * alpha)` if outputs < 0,\n`scale * gradients` otherwise."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Computes gradients for the scaled exponential linear (Selu) operation."
+}
+op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"