aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-06 07:14:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-06 08:29:11 -0700
commita291c8ab3232af1c2458699101d8dbb6ae4661c6 (patch)
tree0abc7f2436f931a1914044edaa422f291e455e0e
parentb0989a0280971034bd838fa176ae3f92b210e8dc (diff)
Add Elu ops in XLA.
Change: 152383201
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py6
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc17
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/elu_op.cc65
5 files changed, 94 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9efdaee7ab..7221a0a3c7 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -108,6 +108,12 @@ class BinaryOpsTest(XLATestCase):
expected=np.array([-75, -48, -21, 0], dtype=dtype))
self._testBinary(
+ gen_nn_ops._elu_grad,
+ np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
+ expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
+
+ self._testBinary(
gen_nn_ops._relu_grad,
np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index a0cd905f17..18c4e3dcb1 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1214,6 +1214,23 @@ TEST_F(OpTest, DynamicStitch) {
});
}
+TEST_F(OpTest, Elu) {  // Randomized agreement check: TF vs. XLA Elu forward op.
+  Repeatedly([this]() {
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+  });
+}
+
+TEST_F(OpTest, EluGrad) {  // Randomized agreement check: TF vs. XLA EluGrad.
+  Repeatedly([this]() {
+    auto dims = RandomDims();  // shared dims: both inputs must have one shape
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Input(RandomTensor(DT_FLOAT, dims))
+                                      .Attr("T", DT_FLOAT));
+  });
+}
+
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 1e85d3a2c8..3f324d1071 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -210,6 +210,11 @@ class UnaryOpsTest(XLATestCase):
dtype=dtype))
self._assertOpOutputMatchesExpected(
+ nn_ops.elu,
+ np.array([[-1, 0, 1]], dtype=dtype),
+ expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0, 1]], dtype=dtype))
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 2ee80a41e8..e4f73b529f 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -26,6 +26,7 @@ tf_kernel_library(
"depthwise_conv_ops.cc",
"diag_op.cc",
"dynamic_stitch_op.cc",
+ "elu_op.cc",
"fill_op.cc",
"function_ops.cc",
"identity_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
new file mode 100644
index 0000000000..62a5e1bd42
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of XLA Elu Ops
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+namespace {
+
+class EluOp : public XlaOpKernel {
+ public:
+  explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Computes elu(x) elementwise: x if x > 0, otherwise exp(x) - 1.
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto pred = b->Gt(ctx->Input(0), zero);
+    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+    ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
+  }
+};
+
+class EluGradOp : public XlaOpKernel {
+ public:
+  explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Returns the lhs (incoming gradient) where rhs (the Elu output) > 0,
+  // otherwise lhs * (1 + rhs), since d/dx elu(x) = exp(x) = elu(x) + 1 there.
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto grad = ctx->Input(0);
+    const auto activation = ctx->Input(1);
+    const auto exp_grad = b->Mul(grad, b->Add(activation, one));
+    const auto pred = b->Gt(activation, zero);
+    ctx->SetOutput(0, b->Select(pred, grad, exp_grad));
+  }
+};
+
+REGISTER_XLA_OP(Name("Elu"), EluOp);
+REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
+
+} // namespace
+} // namespace tensorflow