diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-04-06 07:14:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-06 08:29:11 -0700 |
commit | a291c8ab3232af1c2458699101d8dbb6ae4661c6 (patch) | |
tree | 0abc7f2436f931a1914044edaa422f291e455e0e | |
parent | b0989a0280971034bd838fa176ae3f92b210e8dc (diff) |
Add Elu ops in XLA.
Change: 152383201
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 6 | ||||
-rw-r--r-- | tensorflow/compiler/tests/randomized_tests.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/tests/unary_ops_test.py | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/elu_op.cc | 65 |
5 files changed, 94 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9efdaee7ab..7221a0a3c7 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -108,6 +108,12 @@ class BinaryOpsTest(XLATestCase): expected=np.array([-75, -48, -21, 0], dtype=dtype)) self._testBinary( + gen_nn_ops._elu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), + expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) + + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index a0cd905f17..18c4e3dcb1 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -1214,6 +1214,23 @@ TEST_F(OpTest, DynamicStitch) { }); } +TEST_F(OpTest, Elu) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, EluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose<DataType>({DT_INT32, DT_FLOAT}); diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 1e85d3a2c8..3f324d1071 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -210,6 +210,11 @@ class UnaryOpsTest(XLATestCase): dtype=dtype)) self._assertOpOutputMatchesExpected( + nn_ops.elu, + np.array([[-1, 0, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( nn_ops.relu, np.array([[-1, 1]], dtype=dtype), expected=np.array([[0, 1]], dtype=dtype)) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 2ee80a41e8..e4f73b529f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -26,6 +26,7 @@ tf_kernel_library( "depthwise_conv_ops.cc", "diag_op.cc", "dynamic_stitch_op.cc", + "elu_op.cc", "fill_op.cc", "function_ops.cc", "identity_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc new file mode 100644 index 0000000000..62a5e1bd42 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of XLA Elu Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class EluOp : public XlaOpKernel { + public: + explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + } +}; + +class EluGradOp : public XlaOpKernel { + public: + explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return lhs * (1 + rhs). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto exp_grad = b->Mul(grad, b->Add(activation, one)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Elu"), EluOp); +REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); + +} // namespace +} // namespace tensorflow |