aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc23
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc52
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h3
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc31
5 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 0419419ea5..5f25ff9002 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -406,6 +406,38 @@ class UnaryOpsTest(xla_test.XLATestCase):
],
dtype=dtype))
+ self._assertOpOutputMatchesExpected(
+ math_ops.digamma,
+ np.array(
+ [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0],
+ [6.0, 8.0, 9.0]],
+ dtype=dtype),
+ expected=np.array(
+ [
+ [
+ -np.euler_gamma, -2 * np.log(2) - np.euler_gamma,
+ -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 -
+ np.euler_gamma
+ ],
+ [
+ -np.pi / 2 - 3 * np.log(2) - np.euler_gamma,
+ -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) -
+ 3 * np.log(3) / 2 - np.euler_gamma,
+ -np.pi / 2 - 4 * np.log(2) -
+ (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2)))
+ / np.sqrt(2) - np.euler_gamma
+ ],
+ [
+ 1 - np.euler_gamma, 1.5 - np.euler_gamma,
+ 11 / 6.0 - np.euler_gamma
+ ],
+ [
+ 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma,
+ 761 / 280.0 - np.euler_gamma
+ ],
+ ],
+ dtype=dtype))
+
def quantize_and_dequantize_v2(x):
return array_ops.quantize_and_dequantize_v2(
x, -127, 127, signed_input=True, num_bits=8)
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 76ab8b4c00..4bb31f4117 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -212,5 +212,28 @@ class LgammaOp : public XlaOpKernel {
}; // namespace
REGISTER_XLA_OP(Name("Lgamma"), LgammaOp);
+class DigammaOp : public XlaOpKernel {
+ public:
+ explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ // Calculate lgamma using the Lanczos approximation
+ // (https://en.wikipedia.org/wiki/Lanczos_approximation).
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp input = ctx->Input(0);
+ xla::PrimitiveType input_type = ctx->input_xla_type(0);
+
+ if (input_type == xla::F16 || input_type == xla::BF16) {
+ // The approximation works better with at least 32-bits of accuracy.
+ xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32);
+ xla::XlaOp result_f32 = xla::Digamma(input_f32);
+ xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type);
+ ctx->SetOutput(0, result_x16);
+ } else {
+ xla::XlaOp result = xla::Digamma(input);
+ ctx->SetOutput(0, result);
+ }
+ }
+}; // namespace
+REGISTER_XLA_OP(Name("Digamma"), DigammaOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index fdc7057de3..2a7ac1d716 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -217,4 +217,56 @@ xla::XlaOp Lgamma(xla::XlaOp input) {
return result;
}
+// Compute the Digamma function using Lanczos' approximation from "A Precision
+// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
+// series B. Vol. 1:
+// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
+// t(z) = z + kLanczosGamma + 1/2
+// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
+// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
+xla::XlaOp Digamma(xla::XlaOp input) {
+ xla::XlaOp zero = xla::ScalarLike(input, 0);
+ xla::XlaOp one_half = xla::ScalarLike(input, 0.5);
+ xla::XlaOp one = xla::ScalarLike(input, 1);
+
+ xla::XlaOp pi = xla::ScalarLike(input, M_PI);
+
+ xla::XlaOp lanczos_gamma = xla::ScalarLike(input, kLanczosGamma);
+ xla::XlaOp lanczos_gamma_plus_one_half =
+ xla::ScalarLike(input, kLanczosGamma + 0.5);
+ xla::XlaOp log_lanczos_gamma_plus_one_half =
+ xla::ScalarLike(input, std::log(kLanczosGamma + 0.5));
+
+ xla::XlaOp base_lanczos_coeff = xla::ScalarLike(input, kBaseLanczosCoeff);
+
+ // If the input is less than 0.5 use Gauss's reflection formula:
+ // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
+ xla::XlaOp need_to_reflect = xla::Lt(xla::Real(input), one_half);
+ xla::XlaOp z = xla::Select(need_to_reflect, -input, input - one);
+
+ xla::XlaOp num = zero;
+ xla::XlaOp denom = base_lanczos_coeff;
+ for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
+ xla::XlaOp lanczos_coefficient =
+ xla::ScalarLike(input, kLanczosCoefficients[i]);
+ xla::XlaOp index = xla::ScalarLike(input, i);
+ num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
+ denom = denom + lanczos_coefficient / (z + index + one);
+ }
+
+ // To improve accuracy on platforms with less-precise log implementations,
+ // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
+ // the device.
+ // log(t) = log(kLanczosGamma + 0.5 + z)
+ // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
+ xla::XlaOp t = lanczos_gamma_plus_one_half + z;
+ xla::XlaOp log_t = log_lanczos_gamma_plus_one_half +
+ xla::Log1p(z / lanczos_gamma_plus_one_half);
+
+ xla::XlaOp y = log_t + num / denom - lanczos_gamma / t;
+ xla::XlaOp reflection = y - pi * xla::Cos(pi * input) / xla::Sin(pi * input);
+ xla::XlaOp result = xla::Select(need_to_reflect, reflection, y);
+ return result;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index c89c351cfc..e4c79b5f52 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -49,6 +49,9 @@ XlaOp ErfInv(XlaOp x);
// Computes an approximation of the lgamma function.
XlaOp Lgamma(XlaOp input);
+// Computes an approximation of the digamma function.
+XlaOp Digamma(XlaOp input);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index 86e195a8c6..1df287d7db 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -105,5 +105,36 @@ XLA_TEST_F(MathTest, Lgamma) {
error_spec_ = ErrorSpec{0.001};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
+
+XLA_TEST_F(MathTest, Digamma) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125,
+ 2.0, 3.0, 4.0, 6.0, 8.0, 9.0});
+ Digamma(x);
+
+ constexpr double euler_mascheroni =
+ 0.57721566490153286060651209008240243104215933593992;
+ std::vector<float> expected = {
+ static_cast<float>(-euler_mascheroni),
+ static_cast<float>(-2 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 -
+ euler_mascheroni),
+ static_cast<float>(-M_PI / 2 - 3 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) -
+ 3 * std::log(3) / 2 - euler_mascheroni),
+ static_cast<float>(
+ -M_PI / 2 - 4 * std::log(2) -
+ (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) /
+ std::sqrt(2) -
+ euler_mascheroni),
+ static_cast<float>(1 - euler_mascheroni),
+ static_cast<float>(1.5 - euler_mascheroni),
+ static_cast<float>(11 / 6.0 - euler_mascheroni),
+ static_cast<float>(137 / 60.0 - euler_mascheroni),
+ static_cast<float>(363 / 140.0 - euler_mascheroni),
+ static_cast<float>(761 / 280.0 - euler_mascheroni)};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
} // namespace
} // namespace xla