aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default
diff options
context:
space:
mode:
author: Antonio Sanchez <cantonios@google.com> 2020-10-12 12:24:08 +0100
committer: Rasmus Munk Larsen <rmlarsen@google.com> 2021-02-10 22:45:41 +0000
commit: 4cb563a01e0619ea1798c7927f1909755ead2dd8 (patch)
tree: f1a1c213a13ad6320fa86ebb144af777568eeeea /Eigen/src/Core/arch/Default
parent: 7eb07da538ecc1b8937bfb5dac0d071067728397 (diff)
Fix ldexp implementations.
The previous implementations produced garbage values if the exponent did not fit within the exponent bits. See #2131 for a complete discussion, and !375 for other possible implementations. Here we implement the 4-factor version. See `pldexp_impl` in `GenericPacketMathFunctions.h` for a full description. The SSE `pcmp*` methods were moved down since `pcmp_le<Packet4i>` requires `por`. Left as a "TODO" is to delegate to a faster version if we know the exponent does fit within the exponent bits. Fixes #2131.
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r-- Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h 110
-rw-r--r-- Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h 21
2 files changed, 112 insertions, 19 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index b4fa0489b..09146f496 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -40,30 +40,99 @@ pfrexp_double(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
const Packet cst_1022d = pset1<Packet>(1022.0);
const Packet cst_half = pset1<Packet>(0.5);
- const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull));
+ const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
}
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_float(Packet a, Packet exponent)
-{
+// Safely applies ldexp, correctly handles overflows, underflows and denormals.
+// Assumes IEEE floating point format.
+template<typename Packet>
+struct pldexp_impl {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- const Packet cst_127 = pset1<Packet>(127.f);
- // return a * 2^exponent
- PacketI ei = pcast<Packet,PacketI>(padd(exponent, cst_127));
- return pmul(a, preinterpret<Packet>(plogical_shift_left<23>(ei)));
-}
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = std::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+ Packet run(const Packet& a, const Packet& exponent) {
+ // We want to return a * 2^exponent, allowing for all possible integer
+ // exponents without overflowing or underflowing in intermediate
+ // computations.
+ //
+ // Since 'a' and the output can be denormal, the maximum range of 'exponent'
+ // to consider for a float is:
+ // -255-23 -> 255+23
+ // Below -278 any finite float 'a' will become zero, and above +278 any
+ // finite float will become inf, including when 'a' is the smallest possible
+ // denormal.
+ //
+ // Unfortunately, 2^(278) cannot be represented using either one or two
+ // finite normal floats, so we must split the scale factor into at least
+ // three parts. It turns out to be faster to split 'exponent' into four
+ // factors, since [exponent>>2] is much faster to compute than [exponent/3].
+ //
+ // Set e = min(max(exponent, -278), 278);
+ // b = floor(e/4);
+ // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
+ //
+ // This will avoid any intermediate overflows and correctly handle 0, inf,
+ // NaN cases.
+ const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+ PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
+ Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
+ Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
+ out = pmul(out, c);
+ return out;
+ }
+};
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_double(Packet a, Packet exponent)
-{
+// Explicitly multiplies
+// a * (2^e)
+// clamping e to the range
+// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent]
+//
+// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
+// if 2^e doesn't fit into a normal floating-point Scalar.
+//
+// Assumes IEEE floating point format
+template<typename Packet>
+struct pldexp_fast_impl {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- const Packet cst_1023 = pset1<Packet>(1023.0);
- // return a * 2^exponent
- PacketI ei = pcast<Packet,PacketI>(padd(exponent, cst_1023));
- return pmul(a, preinterpret<Packet>(plogical_shift_left<52>(ei)));
-}
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename unpacket_traits<PacketI>::type ScalarI;
+ enum {
+ TotalBits = sizeof(Scalar) * CHAR_BIT,
+ MantissaBits = std::numeric_limits<Scalar>::digits - 1,
+ ExponentBits = int(TotalBits) - int(MantissaBits) - 1
+ };
+
+ static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+ Packet run(const Packet& a, const Packet& exponent) {
+ const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127
+ const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255
+ // restrict biased exponent between 0 and 255 for float.
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127
+ // return a * (2^e)
+ return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e)));
+ }
+};
+
+template<typename Packet> EIGEN_STRONG_INLINE Packet
+pldexp_float(const Packet& a, const Packet& exponent)
+{ return pldexp_impl<Packet>::run(a, exponent); }
+
+template<typename Packet> EIGEN_STRONG_INLINE Packet
+pldexp_double(const Packet& a, const Packet& exponent)
+{ return pldexp_impl<Packet>::run(a, exponent); }
// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
@@ -394,6 +463,7 @@ Packet pexp_float(const Packet _x)
y = pmadd(y, r2, y2);
// Return 2^m * exp(r).
+ // TODO: replace pldexp with faster implementation since y in [-1, 1).
return pmax(pldexp(y,m), _x);
}
@@ -462,6 +532,7 @@ Packet pexp_double(const Packet _x)
// Construct the result 2^n * exp(g) = e * x. The max is used to catch
// non-finite values in the input.
+ // TODO: replace pldexp with faster implementation since x in [-1, 1).
return pmax(pldexp(x,fx), _x);
}
@@ -897,6 +968,8 @@ Packet generic_pow_impl(const Packet& x, const Packet& y) {
// Note: I experimented with using Dekker's algorithms for the
// multiplication by ln(2) here, but did not see any difference.
Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z));
+ // TODO: investigate bounds of e_r and n_z, potentially using faster
+ // implementation of ldexp.
return pldexp(e_r, n_z);
}
@@ -909,6 +982,7 @@ Packet generic_pow(const Packet& x, const Packet& y) {
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
const Packet cst_zero = pset1<Packet>(Scalar(0));
const Packet cst_one = pset1<Packet>(Scalar(1));
+ const Packet cst_half = pset1<Packet>(Scalar(0.5));
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
Packet abs_x = pabs(x);
@@ -937,7 +1011,7 @@ Packet generic_pow(const Packet& x, const Packet& y) {
// Predicates for whether y is integer and/or even.
Packet y_is_int = pcmp_eq(pfloor(y), y);
- Packet y_div_2 = pldexp(y, pset1<Packet>(Scalar(-1)));
+ Packet y_div_2 = pmul(y, cst_half);
Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
// Predicates encoding special cases for the value of pow(x,y)
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
index a623f54cb..96c572fd3 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -21,14 +21,33 @@ namespace internal {
template<typename Packet, int N> EIGEN_DEVICE_FUNC inline Packet
pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */);
+/***************************************************************************
+ * Some generic implementations to be used by implementors
+***************************************************************************/
+
+/** Default implementation of pfrexp for float.
+ * It is expected to be called by implementers of template<> pfrexp.
+ */
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_float(const Packet& a, Packet& exponent);
+/** Default implementation of pfrexp for double.
+ * It is expected to be called by implementers of template<> pfrexp.
+ */
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_double(const Packet& a, Packet& exponent);
+/** Default implementation of pldexp for float.
+ * It is expected to be called by implementers of template<> pldexp.
+ */
+template<typename Packet> EIGEN_STRONG_INLINE Packet
+pldexp_float(const Packet& a, const Packet& exponent);
+
+/** Default implementation of pldexp for double.
+ * It is expected to be called by implementers of template<> pldexp.
+ */
template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_float(Packet a, Packet exponent);
+pldexp_double(const Packet& a, const Packet& exponent);
/** \internal \returns log(x) for single precision float */
template <typename Packet>