From f6c6de5d63a0c68e71d846604779867ce126d91b Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Tue, 14 Jan 2020 21:32:48 +0000 Subject: Ensure Igamma does not NaN or Inf for large values. --- .../src/SpecialFunctions/SpecialFunctionsImpl.h | 38 +++++++++++++++++++--- unsupported/test/special_functions.cpp | 15 +++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) (limited to 'unsupported') diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index fe3f6d710..425231aa6 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -713,6 +713,18 @@ struct cephes_helper { enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE }; +template +static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) { + /* Compute x**a * exp(-x) / gamma(a) */ + Scalar logax = a * numext::log(x) - x - lgamma_impl::run(a); + if (logax < -numext::log(NumTraits::highest()) || + // Assuming x and a aren't Nan. + (numext::isnan)(logax)) { + return Scalar(0); + } + return numext::exp(logax); +} + template EIGEN_DEVICE_FUNC int igamma_num_iterations() { @@ -755,6 +767,15 @@ struct igammac_cf_impl { return zero; } + Scalar ax = main_igamma_term(a, x); + // This is independent of mode. If this value is zero, + // then the function value is zero. If the function value is zero, + // then we are in a neighborhood where the function value evalutes to zero, + // so the derivative is zero. + if (ax == zero) { + return zero; + } + // continued fraction Scalar y = one - a; Scalar z = x + y + one; @@ -825,9 +846,7 @@ struct igammac_cf_impl { } /* Compute x**a * exp(-x) / gamma(a) */ - Scalar logax = a * numext::log(x) - x - lgamma_impl::run(a); Scalar dlogax_da = numext::log(x) - digamma_impl::run(a); - Scalar ax = numext::exp(logax); Scalar dax_da = ax * dlogax_da; switch (mode) { @@ -858,6 +877,18 @@ struct igamma_series_impl { const Scalar one = 1; const Scalar machep = cephes_helper::machep(); + Scalar ax = main_igamma_term(a, x); + + // This is independent of mode. If this value is zero, + // then the function value is zero. If the function value is zero, + // then we are in a neighborhood where the function value evalutes to zero, + // so the derivative is zero. + if (ax == zero) { + return zero; + } + + ax /= a; + /* power series */ Scalar r = a; Scalar c = one; @@ -886,10 +917,7 @@ struct igamma_series_impl { } } - /* Compute x**a * exp(-x) / gamma(a + 1) */ - Scalar logax = a * numext::log(x) - x - lgamma_impl::run(a + one); Scalar dlogax_da = numext::log(x) - digamma_impl::run(a + one); - Scalar ax = numext::exp(logax); Scalar dax_da = ax * dlogax_da; switch (mode) { diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index c104ac3c5..1027272d5 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -7,6 +7,7 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +#include #include "main.h" #include "../Eigen/SpecialFunctions" @@ -74,6 +75,7 @@ template void array_special_functions() ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp(); ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp(); + // Gamma(a, 0) == Gamma(a) VERIFY_IS_APPROX(Eigen::igammac(a, zero), one); @@ -86,6 +88,19 @@ template void array_special_functions() // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x) VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp()); } + { + // Verify for large a and x that values are between 0 and 1. + ArrayType m1 = ArrayType::Random(rows,cols); + ArrayType m2 = ArrayType::Random(rows,cols); + Scalar max_exponent = std::numeric_limits::max_exponent10; + ArrayType a = m1.abs() * pow(10., max_exponent - 1); + ArrayType x = m2.abs() * pow(10., max_exponent - 1); + for (int i = 0; i < a.size(); ++i) { + Scalar igam = numext::igamma(a(i), x(i)); + VERIFY(0 <= igam); + VERIFY(igam <= 1); + } + } { // Check exact values of igamma and igammac against a third party calculation. -- cgit v1.2.3