aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2018-12-23 15:40:52 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2018-12-23 15:40:52 +0100
commit5713fb7febf24140bfe748d8b868391f01828992 (patch)
treef3acd9a13d0c898817cbde424beae4c7eae890b2
parent6dd93f7e3b92be11991049605655e0bb84ad7a13 (diff)
Fix plog(+INF): it returned ~87 instead of +INF
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h15
-rw-r--r--test/packetmath.cpp7
2 files changed, 16 insertions, 6 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 9481850c6..83fed95de 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -54,6 +54,7 @@ Packet plog_float(const Packet _x)
// The smallest non denormalized float number.
const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
+ const Packet cst_pos_inf = pset1frombits<Packet>( 0x7f800000u);
// Polynomial coefficients.
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
@@ -69,9 +70,6 @@ Packet plog_float(const Packet _x)
const Packet cst_cephes_log_q1 = pset1<Packet>(-2.12194440e-4f);
const Packet cst_cephes_log_q2 = pset1<Packet>(0.693359375f);
- Packet invalid_mask = pcmp_lt_or_nan(x, pzero(x));
- Packet iszero_mask = pcmp_eq(x,pzero(x));
-
// Truncate input values to the minimum positive normal.
x = pmax(x, cst_min_norm_pos);
@@ -117,8 +115,15 @@ Packet plog_float(const Packet _x)
x = padd(x, y);
x = padd(x, y2);
- // Filter out invalid inputs, i.e. negative arg will be NAN, 0 will be -INF.
- return pselect(iszero_mask, cst_minus_inf, por(x, invalid_mask));
+ Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
+ Packet iszero_mask = pcmp_eq(_x,pzero(_x));
+ Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
+ // Filter out invalid inputs, i.e.:
+ // - negative arg will be NAN
+ // - 0 will be -INF
+ // - +INF will be +INF
+ return pselect(iszero_mask, cst_minus_inf,
+ por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
// Exponential function. Works by writing "x = m*log(2) + r" where
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index 916b37bef..7e46b01de 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -520,10 +520,11 @@ template<typename Scalar,typename Packet> void packetmath_real()
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc);
#endif
- if(PacketTraits::HasLog && PacketSize>=2)
+ if(PacketSize>=2)
{
data1[0] = std::numeric_limits<Scalar>::quiet_NaN();
data1[1] = std::numeric_limits<Scalar>::epsilon();
+ if(PacketTraits::HasLog)
{
packet_helper<PacketTraits::HasLog,Packet> h;
h.store(data2, internal::plog(h.load(data1)));
@@ -551,6 +552,10 @@ template<typename Scalar,typename Packet> void packetmath_real()
data1[0] = Scalar(-1.0f);
h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isnan)(data2[0]));
+
+ data1[0] = std::numeric_limits<Scalar>::infinity();
+ h.store(data2, internal::plog(h.load(data1)));
+ VERIFY((numext::isinf)(data2[0]));
}
{
packet_helper<PacketTraits::HasSqrt,Packet> h;