diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/unary_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 90 |
1 files changed, 55 insertions, 35 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 116a020437..e6ec794cfd 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -51,43 +51,18 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); - -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) -XLAJIT_MAKE_UNARY(Acos, - xla::ScalarLike(x, 2.0) * - xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x), - xla::ScalarLike(x, 1.0) + x)); - -// acosh(x) = log(x + sqrt(x^2 - 1)) -// = log(x + sqrt((x+1)*(x-1))) -XLAJIT_MAKE_UNARY(Acosh, - xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) * - (x - xla::ScalarLike(x, 1.0))))); - -// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -XLAJIT_MAKE_UNARY( - Asin, xla::ScalarLike(x, 2.0) * - xla::Atan2(x, xla::ScalarLike(x, 1.0) + - xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x))); - -// asinh(x) = log(x + sqrt(x^2 + 1)) -XLAJIT_MAKE_UNARY(Asinh, - xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0)))); - -XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0))); - -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) -XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) / - (xla::ScalarLike(x, 1.0) - x)) * - xla::ScalarLike(x, 0.5)); +XLAJIT_MAKE_UNARY(Acos, xla::Acos(x)); +XLAJIT_MAKE_UNARY(Acosh, xla::Acosh(x)); +XLAJIT_MAKE_UNARY(Asin, xla::Asin(x)); +XLAJIT_MAKE_UNARY(Asinh, xla::Asinh(x)); +XLAJIT_MAKE_UNARY(Atan, xla::Atan(x)); +XLAJIT_MAKE_UNARY(Atanh, xla::Atanh(x)); XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); -XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); +XLAJIT_MAKE_UNARY(Cosh, xla::Cosh(x)); XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); - XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); - XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); 
XLAJIT_MAKE_UNARY( @@ -99,7 +74,6 @@ XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Log, xla::Log(x)); - XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); @@ -136,7 +110,7 @@ XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); -XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5)); +XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) // @@ -153,7 +127,7 @@ XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) + XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x)); XLAJIT_MAKE_UNARY(Square, x* x); -XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x)); +XLAJIT_MAKE_UNARY(Tan, xla::Tan(x)); XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); @@ -189,5 +163,51 @@ class ErfcOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("Erfc"), ErfcOp); +class LgammaOp : public XlaOpKernel { + public: + explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate lgamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. 
+ xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Lgamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Lgamma(input); + ctx->SetOutput(0, result); + } + } +}; // class LgammaOp +REGISTER_XLA_OP(Name("Lgamma"), LgammaOp); + +class DigammaOp : public XlaOpKernel { + public: + explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate digamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Digamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Digamma(input); + ctx->SetOutput(0, result); + } + } +}; // class DigammaOp +REGISTER_XLA_OP(Name("Digamma"), DigammaOp); + } // namespace } // namespace tensorflow |