From 33e0af0130f0086ff82ba924c6a6ec09a144ff20 Mon Sep 17 00:00:00 2001 From: frgossen Date: Fri, 19 Feb 2021 16:35:11 +0000 Subject: Return nan at poles of polygamma, digamma, and zeta if limit is not defined --- .../src/SpecialFunctions/SpecialFunctionsImpl.h | 17 +++++++++++------ unsupported/test/special_functions.cpp | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 17 deletions(-) (limited to 'unsupported') diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index cfc13aff7..f1c260e29 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -241,7 +241,7 @@ struct digamma_impl { Scalar p, q, nz, s, w, y; bool negative = false; - const Scalar maxnum = NumTraits::infinity(); + const Scalar nan = NumTraits::quiet_NaN(); const Scalar m_pi = Scalar(EIGEN_PI); const Scalar zero = Scalar(0); @@ -254,7 +254,7 @@ struct digamma_impl { q = x; p = numext::floor(q); if (p == q) { - return maxnum; + return nan; } /* Remove the zeros of tan(m_pi x) * by subtracting the nearest integer from x @@ -1403,7 +1403,12 @@ struct zeta_impl { { if(q == numext::floor(q)) { - return maxnum; + if (x == numext::floor(x) && long(x) % 2 == 0) { + return maxnum; + } + else { + return nan; + } } p = x; r = numext::floor(p); @@ -1479,11 +1484,11 @@ struct polygamma_impl { Scalar nplus = n + one; const Scalar nan = NumTraits::quiet_NaN(); - // Check that n is an integer - if (numext::floor(n) != n) { + // Check that n is a non-negative integer + if (numext::floor(n) != n || n < zero) { return nan; } - // Just return the digamma function for n = 1 + // Just return the digamma function for n = 0 else if (n == zero) { return digamma_impl::run(x); } diff --git a/unsupported/test/special_functions.cpp b/unsupported/test/special_functions.cpp index 56848daa5..589bb76e1 100644 --- a/unsupported/test/special_functions.cpp +++ b/unsupported/test/special_functions.cpp @@ -191,10 +191,10 @@ template void array_special_functions() // Check the zeta function against scipy.special.zeta { - ArrayType x(7), q(7), res(7), ref(7); - x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9; - q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345; - ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan; + ArrayType x(10), q(10), res(10), ref(10); + x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4; + q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3; + ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf; CALL_SUBTEST( verify_component_wise(ref, ref); ); CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); ); CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); ); @@ -202,9 +202,9 @@ template void array_special_functions() // digamma { - ArrayType x(7), res(7), ref(7); - x << 1, 1.5, 4, -10.5, 10000.5, 0, -1; - ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf; + ArrayType x(9), res(9), ref(9); + x << 1, 1.5, 4, -10.5, 10000.5, 0, -1, -2, -3; + ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, nan, nan, nan, nan; CALL_SUBTEST( verify_component_wise(ref, ref); ); CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); ); @@ -213,10 +213,10 @@ template void array_special_functions() #if EIGEN_HAS_C99_MATH { - ArrayType n(11), x(11), res(11), ref(11); - n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170; - x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64; - ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927; + ArrayType n(16), x(16), res(16), ref(16); + n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170, -1, 0, 1, 2, 3; + x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64, -1, -2, -3, -4, -5; + ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927, nan, nan, plusinf, nan, plusinf; CALL_SUBTEST( verify_component_wise(ref, ref); ); if(sizeof(RealScalar)>=8) { // double -- cgit v1.2.3