Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/elu_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/elu_op.cc  48
1 file changed, 48 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 62a5e1bd42..2fd27c5ca7 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -61,5 +61,53 @@ class EluGradOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Elu"), EluOp);
REGISTER_XLA_OP(Name("EluGrad"), EluGradOp);
+class SeluOp : public XlaOpKernel {
+ public:
+  explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Computes scale * x if x > 0, and scale_alpha * (exp(x) - 1) otherwise.
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
+                                                1.0507009873554804934193349852946);
+    const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
+                                                      1.7580993408473768599402175208123);
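+    // Constants from Klambauer et al., "Self-Normalizing Neural Networks"
+    // (2017): scale_alpha == scale * alpha, with alpha ~= 1.6732632423543773.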
+    const auto pred = b->Gt(ctx->Input(0), zero);
+    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+    ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
+                                b->Mul(scale_alpha, expm1)));
+  }
+};
+
+class SeluGradOp : public XlaOpKernel {
+ public:
+  explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Return scale * lhs if the rhs (input feature) > 0, otherwise return
+  // lhs * (rhs + scale_alpha), where lhs is the incoming gradient.
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::ComputationBuilder* b = ctx->builder();
+    const auto zero = XlaHelpers::Zero(b, input_type(0));
+    const auto one = XlaHelpers::One(b, input_type(0));
+    const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
+                                                1.0507009873554804934193349852946);
+    const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
+                                                      1.7580993408473768599402175208123);
+    const auto grad = ctx->Input(0);
+    const auto activation = ctx->Input(1);
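+    // For x <= 0, activation == scale_alpha * (exp(x) - 1), so the local
+    // derivative scale_alpha * exp(x) equals activation + scale_alpha.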
+    const auto lin_grad = b->Mul(grad, scale);
+    const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
+    const auto pred = b->Gt(activation, zero);
+    ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
+  }
+};
+
+REGISTER_XLA_OP(Name("Selu"), SeluOp);
+REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp);
+
} // namespace
} // namespace tensorflow