diff options
-rw-r--r-- | tensorflow/compiler/tests/unary_ops_test.py | 32 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/math.cc | 52 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/math.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/math_test.cc | 31 |
5 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 0419419ea5..5f25ff9002 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -406,6 +406,38 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.digamma, + np.array( + [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0], + [6.0, 8.0, 9.0]], + dtype=dtype), + expected=np.array( + [ + [ + -np.euler_gamma, -2 * np.log(2) - np.euler_gamma, + -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 - + np.euler_gamma + ], + [ + -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, + -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) - + 3 * np.log(3) / 2 - np.euler_gamma, + -np.pi / 2 - 4 * np.log(2) - + (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2))) + / np.sqrt(2) - np.euler_gamma + ], + [ + 1 - np.euler_gamma, 1.5 - np.euler_gamma, + 11 / 6.0 - np.euler_gamma + ], + [ + 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma, + 761 / 280.0 - np.euler_gamma + ], + ], + dtype=dtype)) + def quantize_and_dequantize_v2(x): return array_ops.quantize_and_dequantize_v2( x, -127, 127, signed_input=True, num_bits=8) diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 76ab8b4c00..4bb31f4117 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -212,5 +212,28 @@ class LgammaOp : public XlaOpKernel { }; // namespace REGISTER_XLA_OP(Name("Lgamma"), LgammaOp); +class DigammaOp : public XlaOpKernel { + public: + explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate digamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). 
+ void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Digamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Digamma(input); + ctx->SetOutput(0, result); + } + } +}; // class DigammaOp +REGISTER_XLA_OP(Name("Digamma"), DigammaOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index fdc7057de3..2a7ac1d716 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -217,4 +217,56 @@ xla::XlaOp Lgamma(xla::XlaOp input) { return result; } +// Compute the Digamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 
1: +// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k)) +// A'(z) = -sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k)) +xla::XlaOp Digamma(xla::XlaOp input) { + xla::XlaOp zero = xla::ScalarLike(input, 0); + xla::XlaOp one_half = xla::ScalarLike(input, 0.5); + xla::XlaOp one = xla::ScalarLike(input, 1); + + xla::XlaOp pi = xla::ScalarLike(input, M_PI); + + xla::XlaOp lanczos_gamma = xla::ScalarLike(input, kLanczosGamma); + xla::XlaOp lanczos_gamma_plus_one_half = + xla::ScalarLike(input, kLanczosGamma + 0.5); + xla::XlaOp log_lanczos_gamma_plus_one_half = + xla::ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + xla::XlaOp base_lanczos_coeff = xla::ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Gauss's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + xla::XlaOp need_to_reflect = xla::Lt(xla::Real(input), one_half); + xla::XlaOp z = xla::Select(need_to_reflect, -input, input - one); + + xla::XlaOp num = zero; + xla::XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + xla::XlaOp lanczos_coefficient = + xla::ScalarLike(input, kLanczosCoefficients[i]); + xla::XlaOp index = xla::ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. 
+ // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + xla::XlaOp t = lanczos_gamma_plus_one_half + z; + xla::XlaOp log_t = log_lanczos_gamma_plus_one_half + + xla::Log1p(z / lanczos_gamma_plus_one_half); + + xla::XlaOp y = log_t + num / denom - lanczos_gamma / t; + xla::XlaOp reflection = y - pi * xla::Cos(pi * input) / xla::Sin(pi * input); + xla::XlaOp result = xla::Select(need_to_reflect, reflection, y); + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index c89c351cfc..e4c79b5f52 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -49,6 +49,9 @@ XlaOp ErfInv(XlaOp x); // Computes an approximation of the lgamma function. XlaOp Lgamma(XlaOp input); +// Computes an approximation of the digamma function. +XlaOp Digamma(XlaOp input); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 86e195a8c6..1df287d7db 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -105,5 +105,36 @@ XLA_TEST_F(MathTest, Lgamma) { error_spec_ = ErrorSpec{0.001}; ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(MathTest, Digamma) { + XlaBuilder builder(TestName()); + auto x = ConstantR1<float>(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, + 2.0, 3.0, 4.0, 6.0, 8.0, 9.0}); + Digamma(x); + + constexpr double euler_mascheroni = + 0.57721566490153286060651209008240243104215933593992; + std::vector<float> expected = { + static_cast<float>(-euler_mascheroni), + static_cast<float>(-2 * std::log(2) - euler_mascheroni), + static_cast<float>(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 - + euler_mascheroni), + static_cast<float>(-M_PI / 2 - 3 * std::log(2) - 
euler_mascheroni), + static_cast<float>(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) - + 3 * std::log(3) / 2 - euler_mascheroni), + static_cast<float>( + -M_PI / 2 - 4 * std::log(2) - + (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) / + std::sqrt(2) - + euler_mascheroni), + static_cast<float>(1 - euler_mascheroni), + static_cast<float>(1.5 - euler_mascheroni), + static_cast<float>(11 / 6.0 - euler_mascheroni), + static_cast<float>(137 / 60.0 - euler_mascheroni), + static_cast<float>(363 / 140.0 - euler_mascheroni), + static_cast<float>(761 / 280.0 - euler_mascheroni)}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla |