aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-08-28 12:20:21 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-08-28 12:20:21 -0700
commit1187bb65ad196161a07f4e0125e478d022ea1b08 (patch)
tree86fa97b5b8f1f7377d51cdd202f30a4293b6b9ae
parent6e77f9bef35012f160b307bdeae73194fde91e51 (diff)
Add more tests for corner cases of log1p and expm1. Add handling of infinite arguments to log1p such that log1p(inf) = inf.
-rw-r--r--Eigen/src/Core/MathFunctions.h3
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h3
-rw-r--r--test/packetmath.cpp14
3 files changed, 17 insertions, 3 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index fcf62011e..fbec39d83 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -551,7 +551,8 @@ namespace std_fallback {
Scalar x1p = RealScalar(1) + x;
Scalar log_1p = log(x1p);
const bool is_small = numext::equal_strict(x1p, Scalar(1));
- return is_small ? x : x * (log_1p / (x1p - RealScalar(1)));
+ const bool is_inf = numext::equal_strict(x1p, log_1p);
+ return (is_small || is_inf) ? x : x * (log_1p / (x1p - RealScalar(1)));
}
}
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 505a0eec8..0fc673e12 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -137,8 +137,9 @@ Packet generic_plog1p(const Packet& x)
Packet xp1 = padd(x, one);
Packet small_mask = pcmp_eq(xp1, one);
Packet log1 = plog(xp1);
+ Packet inf_mask = pcmp_eq(xp1, log1);
Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
- return pselect(small_mask, x, log_large);
+ return pselect(por(small_mask, inf_mask), x, log_large);
}
/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 28768b18d..67ff6dc5b 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -607,8 +607,12 @@ template<typename Scalar,typename Packet> void packetmath_real()
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc);
- CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
+ data1[0] = std::numeric_limits<Scalar>::infinity();
+ data1[1] = Scalar(-1);
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
+ data1[0] = std::numeric_limits<Scalar>::infinity();
+ data1[1] = -std::numeric_limits<Scalar>::infinity();
+ CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
#endif
if(PacketSize>=2)
@@ -648,6 +652,14 @@ template<typename Scalar,typename Packet> void packetmath_real()
h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isinf)(data2[0]));
}
+ if(PacketTraits::HasLog1p) {
+ packet_helper<PacketTraits::HasLog1p,Packet> h;
+ data1[0] = Scalar(-2);
+ data1[1] = -std::numeric_limits<Scalar>::infinity();
+ h.store(data2, internal::plog1p(h.load(data1)));
+ VERIFY((numext::isnan)(data2[0]));
+ VERIFY((numext::isnan)(data2[1]));
+ }
if(PacketTraits::HasSqrt)
{
packet_helper<PacketTraits::HasSqrt,Packet> h;