aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/unary_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc90
1 files changed, 55 insertions, 35 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 116a020437..e6ec794cfd 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -51,43 +51,18 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x));
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, xla::Abs(x));
-
-// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
-XLAJIT_MAKE_UNARY(Acos,
- xla::ScalarLike(x, 2.0) *
- xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x),
- xla::ScalarLike(x, 1.0) + x));
-
-// acosh(x) = log(x + sqrt(x^2 - 1))
-// = log(x + sqrt((x+1)*(x-1)))
-XLAJIT_MAKE_UNARY(Acosh,
- xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) *
- (x - xla::ScalarLike(x, 1.0)))));
-
-// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
-XLAJIT_MAKE_UNARY(
- Asin, xla::ScalarLike(x, 2.0) *
- xla::Atan2(x, xla::ScalarLike(x, 1.0) +
- xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x)));
-
-// asinh(x) = log(x + sqrt(x^2 + 1))
-XLAJIT_MAKE_UNARY(Asinh,
- xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0))));
-
-XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0)));
-
-// atanh(x) = 0.5 * log((1 + x) / (1 - x))
-XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) /
- (xla::ScalarLike(x, 1.0) - x)) *
- xla::ScalarLike(x, 0.5));
+XLAJIT_MAKE_UNARY(Acos, xla::Acos(x));
+XLAJIT_MAKE_UNARY(Acosh, xla::Acosh(x));
+XLAJIT_MAKE_UNARY(Asin, xla::Asin(x))
+XLAJIT_MAKE_UNARY(Asinh, xla::Asinh(x));
+XLAJIT_MAKE_UNARY(Atan, xla::Atan(x));
+XLAJIT_MAKE_UNARY(Atanh, xla::Atanh(x));
XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x));
XLAJIT_MAKE_UNARY(Cos, xla::Cos(x));
-XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
+XLAJIT_MAKE_UNARY(Cosh, xla::Cosh(x));
XLAJIT_MAKE_UNARY(Sin, xla::Sin(x));
XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
-
XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
-
XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
XLAJIT_MAKE_UNARY(
@@ -99,7 +74,6 @@ XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
XLAJIT_MAKE_UNARY(Log, xla::Log(x));
-
XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
@@ -136,7 +110,7 @@ XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
-XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
+XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
// softplus(x) = log(1 + exp(x))
//
@@ -153,7 +127,7 @@ XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));
XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x));
XLAJIT_MAKE_UNARY(Square, x* x);
-XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x));
+XLAJIT_MAKE_UNARY(Tan, xla::Tan(x));
XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x));
XLAJIT_MAKE_UNARY(Real, xla::Real(x));
@@ -189,5 +163,51 @@ class ErfcOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("Erfc"), ErfcOp);
+class LgammaOp : public XlaOpKernel {
+ public:
+ explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Calculate lgamma using the Lanczos approximation
+ // (https://en.wikipedia.org/wiki/Lanczos_approximation).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp input = ctx->Input(0);
+ xla::PrimitiveType input_type = ctx->input_xla_type(0);
+
+ if (input_type == xla::F16 || input_type == xla::BF16) {
+ // The approximation works better with at least 32-bits of accuracy.
+ xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32);
+ xla::XlaOp result_f32 = xla::Lgamma(input_f32);
+ xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type);
+ ctx->SetOutput(0, result_x16);
+ } else {
+ xla::XlaOp result = xla::Lgamma(input);
+ ctx->SetOutput(0, result);
+ }
+ }
+}; // namespace
+REGISTER_XLA_OP(Name("Lgamma"), LgammaOp);
+
+class DigammaOp : public XlaOpKernel {
+ public:
+ explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Calculate lgamma using the Lanczos approximation
+ // (https://en.wikipedia.org/wiki/Lanczos_approximation).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp input = ctx->Input(0);
+ xla::PrimitiveType input_type = ctx->input_xla_type(0);
+
+ if (input_type == xla::F16 || input_type == xla::BF16) {
+ // The approximation works better with at least 32-bits of accuracy.
+ xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32);
+ xla::XlaOp result_f32 = xla::Digamma(input_f32);
+ xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type);
+ ctx->SetOutput(0, result_x16);
+ } else {
+ xla::XlaOp result = xla::Digamma(input);
+ ctx->SetOutput(0, result);
+ }
+ }
+}; // namespace
+REGISTER_XLA_OP(Name("Digamma"), DigammaOp);
+
} // namespace
} // namespace tensorflow