diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/elu_op.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 62a5e1bd42..2fd27c5ca7 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -61,5 +61,49 @@ class EluGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Elu"), EluOp); REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); +class SeluOp : public XlaOpKernel { + public: + explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), + 1.0507009873554804934193349852946); + const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), + 1.7580993408473768599402175208123); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), + b->Mul(scale_alpha, expm1))); + } +}; + +class SeluGradOp : public XlaOpKernel { + public: + explicit SeluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return lhs * (1 + rhs). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), + 1.0507009873554804934193349852946); + const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), + 1.7580993408473768599402175208123); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto lin_grad = b->Mul(grad, scale); + const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Selu"), SeluOp); +REGISTER_XLA_OP(Name("SeluGrad"), SeluGradOp); + } // namespace } // namespace tensorflow |