aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-16 17:13:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 17:17:04 -0700
commit2c442d26f36a0f167685fd31b9ecdb4e290c2b29 (patch)
tree8d8952c4ca9c02e83990a501bc86cf4c21ce2c13
parent6bfa38ef2963f0062fbe12d532ab188c7d5ea8dd (diff)
Implement digamma for XLA
Compute the Lgamma function using Lanczos' approximation from "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical Analysis series B. Vol. 1: digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) t(z) = z + kLanczosGamma + 1/2 A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) PiperOrigin-RevId: 204834091
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc23
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc52
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h3
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc31
5 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 0419419ea5..5f25ff9002 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -406,6 +406,38 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ math_ops.digamma,
+ np.array(
+ [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0],
+ [6.0, 8.0, 9.0]],
+ dtype=dtype),
+ expected=np.array(
+ [
+ [
+ -np.euler_gamma, -2 * np.log(2) - np.euler_gamma,
+ -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 -
+ np.euler_gamma
+ ],
+ [
+ -np.pi / 2 - 3 * np.log(2) - np.euler_gamma,
+ -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) -
+ 3 * np.log(3) / 2 - np.euler_gamma,
+ -np.pi / 2 - 4 * np.log(2) -
+ (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2)))
+ / np.sqrt(2) - np.euler_gamma
+ ],
+ [
+ 1 - np.euler_gamma, 1.5 - np.euler_gamma,
+ 11 / 6.0 - np.euler_gamma
+ ],
+ [
+ 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma,
+ 761 / 280.0 - np.euler_gamma
+ ],
+ ],
+ dtype=dtype))
+
def quantize_and_dequantize_v2(x):
return array_ops.quantize_and_dequantize_v2(
x, -127, 127, signed_input=True, num_bits=8)
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 76ab8b4c00..4bb31f4117 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -212,5 +212,28 @@ class LgammaOp : public XlaOpKernel {
}; // namespace
REGISTER_XLA_OP(Name("Lgamma"), LgammaOp);
+class DigammaOp : public XlaOpKernel {
+ public:
+ explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Calculate lgamma using the Lanczos approximation
+ // (https://en.wikipedia.org/wiki/Lanczos_approximation).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp input = ctx->Input(0);
+ xla::PrimitiveType input_type = ctx->input_xla_type(0);
+
+ if (input_type == xla::F16 || input_type == xla::BF16) {
+ // The approximation works better with at least 32-bits of accuracy.
+ xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32);
+ xla::XlaOp result_f32 = xla::Digamma(input_f32);
+ xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type);
+ ctx->SetOutput(0, result_x16);
+ } else {
+ xla::XlaOp result = xla::Digamma(input);
+ ctx->SetOutput(0, result);
+ }
+ }
+}; // namespace
+REGISTER_XLA_OP(Name("Digamma"), DigammaOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index fdc7057de3..2a7ac1d716 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -217,4 +217,56 @@ xla::XlaOp Lgamma(xla::XlaOp input) {
return result;
}
+// Compute the Digamma function using Lanczos' approximation from "A Precision
+// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
+// series B. Vol. 1:
+// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
+// t(z) = z + kLanczosGamma + 1/2
+// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
+// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
+xla::XlaOp Digamma(xla::XlaOp input) {
+ xla::XlaOp zero = xla::ScalarLike(input, 0);
+ xla::XlaOp one_half = xla::ScalarLike(input, 0.5);
+ xla::XlaOp one = xla::ScalarLike(input, 1);
+
+ xla::XlaOp pi = xla::ScalarLike(input, M_PI);
+
+ xla::XlaOp lanczos_gamma = xla::ScalarLike(input, kLanczosGamma);
+ xla::XlaOp lanczos_gamma_plus_one_half =
+ xla::ScalarLike(input, kLanczosGamma + 0.5);
+ xla::XlaOp log_lanczos_gamma_plus_one_half =
+ xla::ScalarLike(input, std::log(kLanczosGamma + 0.5));
+
+ xla::XlaOp base_lanczos_coeff = xla::ScalarLike(input, kBaseLanczosCoeff);
+
+ // If the input is less than 0.5 use Gauss's reflection formula:
+ // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
+ xla::XlaOp need_to_reflect = xla::Lt(xla::Real(input), one_half);
+ xla::XlaOp z = xla::Select(need_to_reflect, -input, input - one);
+
+ xla::XlaOp num = zero;
+ xla::XlaOp denom = base_lanczos_coeff;
+ for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
+ xla::XlaOp lanczos_coefficient =
+ xla::ScalarLike(input, kLanczosCoefficients[i]);
+ xla::XlaOp index = xla::ScalarLike(input, i);
+ num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
+ denom = denom + lanczos_coefficient / (z + index + one);
+ }
+
+ // To improve accuracy on platforms with less-precise log implementations,
+ // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
+ // the device.
+ // log(t) = log(kLanczosGamma + 0.5 + z)
+ // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
+ xla::XlaOp t = lanczos_gamma_plus_one_half + z;
+ xla::XlaOp log_t = log_lanczos_gamma_plus_one_half +
+ xla::Log1p(z / lanczos_gamma_plus_one_half);
+
+ xla::XlaOp y = log_t + num / denom - lanczos_gamma / t;
+ xla::XlaOp reflection = y - pi * xla::Cos(pi * input) / xla::Sin(pi * input);
+ xla::XlaOp result = xla::Select(need_to_reflect, reflection, y);
+ return result;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index c89c351cfc..e4c79b5f52 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -49,6 +49,9 @@ XlaOp ErfInv(XlaOp x);
// Computes an approximation of the lgamma function.
XlaOp Lgamma(XlaOp input);
+// Computes an approximation of the digamma function.
+XlaOp Digamma(XlaOp input);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index 86e195a8c6..1df287d7db 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -105,5 +105,36 @@ XLA_TEST_F(MathTest, Lgamma) {
error_spec_ = ErrorSpec{0.001};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
+
+XLA_TEST_F(MathTest, Digamma) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125,
+ 2.0, 3.0, 4.0, 6.0, 8.0, 9.0});
+ Digamma(x);
+
+ constexpr double euler_mascheroni =
+ 0.57721566490153286060651209008240243104215933593992;
+ std::vector<float> expected = {
+ static_cast<float>(-euler_mascheroni),
+ static_cast<float>(-2 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 -
+ euler_mascheroni),
+ static_cast<float>(-M_PI / 2 - 3 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) -
+ 3 * std::log(3) / 2 - euler_mascheroni),
+ static_cast<float>(
+ -M_PI / 2 - 4 * std::log(2) -
+ (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) /
+ std::sqrt(2) -
+ euler_mascheroni),
+ static_cast<float>(1 - euler_mascheroni),
+ static_cast<float>(1.5 - euler_mascheroni),
+ static_cast<float>(11 / 6.0 - euler_mascheroni),
+ static_cast<float>(137 / 60.0 - euler_mascheroni),
+ static_cast<float>(363 / 140.0 - euler_mascheroni),
+ static_cast<float>(761 / 280.0 - euler_mascheroni)};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
} // namespace
} // namespace xla