diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-10-12 12:24:08 +0100 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-02-10 22:45:41 +0000 |
commit | 4cb563a01e0619ea1798c7927f1909755ead2dd8 (patch) | |
tree | f1a1c213a13ad6320fa86ebb144af777568eeeea /Eigen/src/Core/arch/Default | |
parent | 7eb07da538ecc1b8937bfb5dac0d071067728397 (diff) |
Fix ldexp implementations.
The previous implementations produced garbage values if the exponent did
not fit within the exponent bits. See #2131 for a complete discussion,
and !375 for other possible implementations.
Here we implement the 4-factor version. See `pldexp_impl` in
`GenericPacketMathFunctions.h` for a full description.
The SSE `pcmp*` methods were moved down since `pcmp_le<Packet4i>`
requires `por`.
Delegating to a faster version when we know the exponent fits within the
exponent bits is left as a "TODO".
Fixes #2131.
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 110 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h | 21 |
2 files changed, 112 insertions, 19 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index b4fa0489b..09146f496 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -40,30 +40,99 @@ pfrexp_double(const Packet& a, Packet& exponent) { typedef typename unpacket_traits<Packet>::integer_packet PacketI; const Packet cst_1022d = pset1<Packet>(1022.0); const Packet cst_half = pset1<Packet>(0.5); - const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull)); + const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull)); exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d); return por(pand(a, cst_inv_mant_mask), cst_half); } -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_float(Packet a, Packet exponent) -{ +// Safely applies ldexp, correctly handles overflows, underflows and denormals. +// Assumes IEEE floating point format. +template<typename Packet> +struct pldexp_impl { typedef typename unpacket_traits<Packet>::integer_packet PacketI; - const Packet cst_127 = pset1<Packet>(127.f); - // return a * 2^exponent - PacketI ei = pcast<Packet,PacketI>(padd(exponent, cst_127)); - return pmul(a, preinterpret<Packet>(plogical_shift_left<23>(ei))); -} + typedef typename unpacket_traits<Packet>::type Scalar; + typedef typename unpacket_traits<PacketI>::type ScalarI; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = std::numeric_limits<Scalar>::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + + static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC + Packet run(const Packet& a, const Packet& exponent) { + // We want to return a * 2^exponent, allowing for all possible integer + // exponents without overflowing or underflowing in intermediate + // computations. 
+ // + // Since 'a' and the output can be denormal, the maximum range of 'exponent' + // to consider for a float is: + // -255-23 -> 255+23 + // Below -278 any finite float 'a' will become zero, and above +278 any + // finite float will become inf, including when 'a' is the smallest possible + // denormal. + // + // Unfortunately, 2^(278) cannot be represented using either one or two + // finite normal floats, so we must split the scale factor into at least + // three parts. It turns out to be faster to split 'exponent' into four + // factors, since [exponent>>2] is much faster to compute than [exponent/3]. + // + // Set e = min(max(exponent, -278), 278); + // b = floor(e/4); + // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) + // + // This will avoid any intermediate overflows and correctly handle 0, inf, + // NaN cases. + const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278 + const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 + const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); + Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b + Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b) + out = pmul(out, c); + return out; + } +}; -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_double(Packet a, Packet exponent) -{ +// Explicitly multiplies +// a * (2^e) +// clamping e to the range +// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent] +// +// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow +// if 2^e doesn't fit into a normal floating-point Scalar. 
+// +// Assumes IEEE floating point format +template<typename Packet> +struct pldexp_fast_impl { typedef typename unpacket_traits<Packet>::integer_packet PacketI; - const Packet cst_1023 = pset1<Packet>(1023.0); - // return a * 2^exponent - PacketI ei = pcast<Packet,PacketI>(padd(exponent, cst_1023)); - return pmul(a, preinterpret<Packet>(plogical_shift_left<52>(ei))); -} + typedef typename unpacket_traits<Packet>::type Scalar; + typedef typename unpacket_traits<PacketI>::type ScalarI; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = std::numeric_limits<Scalar>::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + + static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC + Packet run(const Packet& a, const Packet& exponent) { + const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127 + const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255 + // restrict biased exponent between 0 and 255 for float. + const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 + // return a * (2^e) + return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e))); + } +}; + +template<typename Packet> EIGEN_STRONG_INLINE Packet +pldexp_float(const Packet& a, const Packet& exponent) +{ return pldexp_impl<Packet>::run(a, exponent); } + +template<typename Packet> EIGEN_STRONG_INLINE Packet +pldexp_double(const Packet& a, const Packet& exponent) +{ return pldexp_impl<Packet>::run(a, exponent); } // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) @@ -394,6 +463,7 @@ Packet pexp_float(const Packet _x) y = pmadd(y, r2, y2); // Return 2^m * exp(r). + // TODO: replace pldexp with faster implementation since y in [-1, 1). return pmax(pldexp(y,m), _x); } @@ -462,6 +532,7 @@ Packet pexp_double(const Packet _x) // Construct the result 2^n * exp(g) = e * x. 
The max is used to catch // non-finite values in the input. + // TODO: replace pldexp with faster implementation since x in [-1, 1). return pmax(pldexp(x,fx), _x); } @@ -897,6 +968,8 @@ Packet generic_pow_impl(const Packet& x, const Packet& y) { // Note: I experimented with using Dekker's algorithms for the // multiplication by ln(2) here, but did not see any difference. Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z)); + // TODO: investigate bounds of e_r and n_z, potentially using faster + // implementation of ldexp. return pldexp(e_r, n_z); } @@ -909,6 +982,7 @@ Packet generic_pow(const Packet& x, const Packet& y) { const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity()); const Packet cst_zero = pset1<Packet>(Scalar(0)); const Packet cst_one = pset1<Packet>(Scalar(1)); + const Packet cst_half = pset1<Packet>(Scalar(0.5)); const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN()); Packet abs_x = pabs(x); @@ -937,7 +1011,7 @@ Packet generic_pow(const Packet& x, const Packet& y) { // Predicates for whether y is integer and/or even. 
Packet y_is_int = pcmp_eq(pfloor(y), y); - Packet y_div_2 = pldexp(y, pset1<Packet>(Scalar(-1))); + Packet y_div_2 = pmul(y, cst_half); Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); // Predicates encoding special cases for the value of pow(x,y) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index a623f54cb..96c572fd3 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -21,14 +21,33 @@ namespace internal { template<typename Packet, int N> EIGEN_DEVICE_FUNC inline Packet pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */); +/*************************************************************************** + * Some generic implementations to be used by implementors +***************************************************************************/ + +/** Default implementation of pfrexp for float. + * It is expected to be called by implementers of template<> pfrexp. + */ template<typename Packet> EIGEN_STRONG_INLINE Packet pfrexp_float(const Packet& a, Packet& exponent); +/** Default implementation of pfrexp for double. + * It is expected to be called by implementers of template<> pfrexp. + */ template<typename Packet> EIGEN_STRONG_INLINE Packet pfrexp_double(const Packet& a, Packet& exponent); +/** Default implementation of pldexp for float. + * It is expected to be called by implementers of template<> pldexp. + */ +template<typename Packet> EIGEN_STRONG_INLINE Packet +pldexp_float(const Packet& a, const Packet& exponent); + +/** Default implementation of pldexp for double. + * It is expected to be called by implementers of template<> pldexp. 
+ */ template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_float(Packet a, Packet exponent); +pldexp_double(const Packet& a, const Packet& exponent); /** \internal \returns log(x) for single precision float */ template <typename Packet> |