diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-08-28 12:20:21 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-08-28 12:20:21 -0700 |
commit | 1187bb65ad196161a07f4e0125e478d022ea1b08 (patch) | |
tree | 86fa97b5b8f1f7377d51cdd202f30a4293b6b9ae | |
parent | 6e77f9bef35012f160b307bdeae73194fde91e51 (diff) |
Add more tests for corner cases of log1p and expm1. Add handling of infinite arguments to log1p such that log1p(inf) = inf.
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 3 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 3 | ||||
-rw-r--r-- | test/packetmath.cpp | 14 |
3 files changed, 17 insertions, 3 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index fcf62011e..fbec39d83 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -551,7 +551,8 @@ namespace std_fallback { Scalar x1p = RealScalar(1) + x; Scalar log_1p = log(x1p); const bool is_small = numext::equal_strict(x1p, Scalar(1)); - return is_small ? x : x * (log_1p / (x1p - RealScalar(1))); + const bool is_inf = numext::equal_strict(x1p, log_1p); + return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1))); } } diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 505a0eec8..0fc673e12 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -137,8 +137,9 @@ Packet generic_plog1p(const Packet& x) Packet xp1 = padd(x, one); Packet small_mask = pcmp_eq(xp1, one); Packet log1 = plog(xp1); + Packet inf_mask = pcmp_eq(xp1, log1); Packet log_large = pmul(x, pdiv(log1, psub(xp1, one))); - return pselect(small_mask, x, log_large); + return pselect(por(small_mask, inf_mask), x, log_large); } /** \internal \returns exp(x)-1 computed using W. Kahan's formula. diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 28768b18d..67ff6dc5b 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -607,8 +607,12 @@ template<typename Scalar,typename Packet> void packetmath_real() CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma); CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf); CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc); - CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1); + data1[0] = std::numeric_limits<Scalar>::infinity(); + data1[1] = Scalar(-1); CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p); + data1[0] = std::numeric_limits<Scalar>::infinity(); + data1[1] = -std::numeric_limits<Scalar>::infinity(); + CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1); #endif if(PacketSize>=2) @@ -648,6 +652,14 @@ template<typename Scalar,typename Packet> void packetmath_real() h.store(data2, internal::plog(h.load(data1))); VERIFY((numext::isinf)(data2[0])); } + if(PacketTraits::HasLog1p) { + packet_helper<PacketTraits::HasLog1p,Packet> h; + data1[0] = Scalar(-2); + data1[1] = -std::numeric_limits<Scalar>::infinity(); + h.store(data2, internal::plog1p(h.load(data1))); + VERIFY((numext::isnan)(data2[0])); + VERIFY((numext::isnan)(data2[1])); + } if(PacketTraits::HasSqrt) { packet_helper<PacketTraits::HasSqrt,Packet> h; |