aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Srinivas Vasudevan <srvasude@gmail.com>2020-01-14 21:32:48 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-01-14 21:32:48 +0000
commitf6c6de5d63a0c68e71d846604779867ce126d91b (patch)
tree7e3e16c7a7aeccfa4587d4f751cc7319a6233ee6 /unsupported
parent6601abce868e3284b4829a4fbf91eefaa0d704af (diff)
Ensure Igamma does not NaN or Inf for large values.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h38
-rw-r--r--unsupported/test/special_functions.cpp15
2 files changed, 48 insertions, 5 deletions
diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
index fe3f6d710..425231aa6 100644
--- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
+++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h
@@ -713,6 +713,18 @@ struct cephes_helper<double> {
enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
+template <typename Scalar>
+static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
+ /* Compute x**a * exp(-x) / gamma(a) */
+ Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
+ if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
+ // Assuming x and a aren't Nan.
+ (numext::isnan)(logax)) {
+ return Scalar(0);
+ }
+ return numext::exp(logax);
+}
+
template <typename Scalar, IgammaComputationMode mode>
EIGEN_DEVICE_FUNC
int igamma_num_iterations() {
@@ -755,6 +767,15 @@ struct igammac_cf_impl {
return zero;
}
+ Scalar ax = main_igamma_term<Scalar>(a, x);
+ // This is independent of mode. If this value is zero,
+ // then the function value is zero. If the function value is zero,
+ // then we are in a neighborhood where the function value evalutes to zero,
+ // so the derivative is zero.
+ if (ax == zero) {
+ return zero;
+ }
+
// continued fraction
Scalar y = one - a;
Scalar z = x + y + one;
@@ -825,9 +846,7 @@ struct igammac_cf_impl {
}
/* Compute x**a * exp(-x) / gamma(a) */
- Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
- Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {
@@ -858,6 +877,18 @@ struct igamma_series_impl {
const Scalar one = 1;
const Scalar machep = cephes_helper<Scalar>::machep();
+ Scalar ax = main_igamma_term<Scalar>(a, x);
+
+ // This is independent of mode. If this value is zero,
+ // then the function value is zero. If the function value is zero,
+ // then we are in a neighborhood where the function value evalutes to zero,
+ // so the derivative is zero.
+ if (ax == zero) {
+ return zero;
+ }
+
+ ax /= a;
+
/* power series */
Scalar r = a;
Scalar c = one;
@@ -886,10 +917,7 @@ struct igamma_series_impl {
}
}
- /* Compute x**a * exp(-x) / gamma(a + 1) */
- Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a + one);
Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
- Scalar ax = numext::exp(logax);
Scalar dax_da = ax * dlogax_da;
switch (mode) {
diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp
index c104ac3c5..1027272d5 100644
--- a/unsupported/test/special_functions.cpp
+++ b/unsupported/test/special_functions.cpp
@@ -7,6 +7,7 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+#include <limits.h>
#include "main.h"
#include "../Eigen/SpecialFunctions"
@@ -74,6 +75,7 @@ template<typename ArrayType> void array_special_functions()
ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
+
// Gamma(a, 0) == Gamma(a)
VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
@@ -86,6 +88,19 @@ template<typename ArrayType> void array_special_functions()
// gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
}
+ {
+ // Verify for large a and x that values are between 0 and 1.
+ ArrayType m1 = ArrayType::Random(rows,cols);
+ ArrayType m2 = ArrayType::Random(rows,cols);
+ Scalar max_exponent = std::numeric_limits<Scalar>::max_exponent10;
+ ArrayType a = m1.abs() * pow(10., max_exponent - 1);
+ ArrayType x = m2.abs() * pow(10., max_exponent - 1);
+ for (int i = 0; i < a.size(); ++i) {
+ Scalar igam = numext::igamma(a(i), x(i));
+ VERIFY(0 <= igam);
+ VERIFY(igam <= 1);
+ }
+ }
{
// Check exact values of igamma and igammac against a third party calculation.