diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-02-12 11:32:29 -0800 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-02-17 02:23:24 +0000 |
commit | 7ff0b7a980ceffe7d0e72ebac924f514f7874e9b (patch) | |
tree | 4022a25c9e1c909aff3f43b978a4a5d5ce0a1c47 /Eigen/src/Core/arch/Default | |
parent | 9ad4096ccb75dd5c5dd882576d49d48475afa300 (diff) |
Updated pfrexp implementation.
The original implementation fails for 0, denormals, inf, and NaN.
See #2150
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 164 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h | 28 |
2 files changed, 106 insertions, 86 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 452019ecc..42d310ab2 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -25,80 +25,114 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */) { return pload<Packet>(a); } -template<typename Packet> EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent) { +// Creates a Scalar integer type with same bit-width. +template<typename T> struct make_integer; +template<> struct make_integer<float> { typedef numext::int32_t type; }; +template<> struct make_integer<double> { typedef numext::int64_t type; }; +template<> struct make_integer<half> { typedef numext::int16_t type; }; +template<> struct make_integer<bfloat16> { typedef numext::int16_t type; }; + +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& a) { + typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::integer_packet PacketI; - const Packet cst_126f = pset1<Packet>(126.0f); - const Packet cst_half = pset1<Packet>(0.5f); - const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u); - exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f); - return por(pand(a, cst_inv_mant_mask), cst_half); + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; + return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a)))); } -template<typename Packet> EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent) { - typedef typename unpacket_traits<Packet>::integer_packet PacketI; - const Packet cst_1022d = pset1<Packet>(1022.0); - const Packet cst_half = pset1<Packet>(0.5); - const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull)); - exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); +// Safely applies frexp, correctly handles denormals. +// Assumes IEEE floating point format. +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent) { + typedef typename unpacket_traits<Packet>::type Scalar; + typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI; + + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = + ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000 + const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask)); + const Packet half = pset1<Packet>(Scalar(0.5)); + const Packet zero = pzero(a); + const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126 + + // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1). + const Packet is_denormal = pcmp_lt(pabs(a), normal_min); + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24 + // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. + const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 + const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor); + const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); + + // Determine exponent offset: -126 if normal, -126-24 if denormal + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126 + Packet exponent_offset = pset1<Packet>(scalar_exponent_offset); + const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24 + exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); + + // Determine exponent and mantissa from normalized_a. + exponent = pfrexp_generic_get_biased_exponent(normalized_a); + // Zero, Inf and NaN return 'a' unmodified, exponent is zero + // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255 + const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent); + const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); + const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); + exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); + return m; } // Safely applies ldexp, correctly handles overflows, underflows and denormals. // Assumes IEEE floating point format. -template<typename Packet> -struct pldexp_impl { +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent) { + // We want to return a * 2^exponent, allowing for all possible integer + // exponents without overflowing or underflowing in intermediate + // computations. + // + // Since 'a' and the output can be denormal, the maximum range of 'exponent' + // to consider for a float is: + // -255-23 -> 255+23 + // Below -278 any finite float 'a' will become zero, and above +278 any + // finite float will become inf, including when 'a' is the smallest possible + // denormal. + // + // Unfortunately, 2^(278) cannot be represented using either one or two + // finite normal floats, so we must split the scale factor into at least + // three parts. It turns out to be faster to split 'exponent' into four + // factors, since [exponent>>2] is much faster to compute that [exponent/3]. + // + // Set e = min(max(exponent, -278), 278); + // b = floor(e/4); + // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) + // + // This will avoid any intermediate overflows and correctly handle 0, inf, + // NaN cases. typedef typename unpacket_traits<Packet>::integer_packet PacketI; typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<PacketI>::type ScalarI; - enum { - TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits<Scalar>::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; - - static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC - Packet run(const Packet& a, const Packet& exponent) { - // We want to return a * 2^exponent, allowing for all possible integer - // exponents without overflowing or underflowing in intermediate - // computations. - // - // Since 'a' and the output can be denormal, the maximum range of 'exponent' - // to consider for a float is: - // -255-23 -> 255+23 - // Below -278 any finite float 'a' will become zero, and above +278 any - // finite float will become inf, including when 'a' is the smallest possible - // denormal. - // - // Unfortunately, 2^(278) cannot be represented using either one or two - // finite normal floats, so we must split the scale factor into at least - // three parts. It turns out to be faster to split 'exponent' into four - // factors, since [exponent>>2] is much faster to compute that [exponent/3]. - // - // Set e = min(max(exponent, -278), 278); - // b = floor(e/4); - // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) - // - // This will avoid any intermediate overflows and correctly handle 0, inf, - // NaN cases. - const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278 - const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 - const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); - PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); - Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b - Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) - b = psub(psub(psub(e, b), b), b); // e - 3b - c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b) - out = pmul(out, c); - return out; - } -}; + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278 + const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127 + const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); + Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b + Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b) + out = pmul(out, c); + return out; +} // Explicitly multiplies // a * (2^e) // clamping e to the range -// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent] +// [numeric_limits<Scalar>::min_exponent-2, numeric_limits<Scalar>::max_exponent] // // This is approx 7x faster than pldexp_impl, but will prematurely over/underflow // if 2^e doesn't fit into a normal floating-point Scalar. @@ -111,7 +145,7 @@ struct pldexp_fast_impl { typedef typename unpacket_traits<PacketI>::type ScalarI; enum { TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits<Scalar>::digits - 1, + MantissaBits = numext::numeric_limits<Scalar>::digits - 1, ExponentBits = int(TotalBits) - int(MantissaBits) - 1 }; @@ -126,14 +160,6 @@ struct pldexp_fast_impl { } }; -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent) -{ return pldexp_impl<Packet>::run(a, exponent); } - -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent) -{ return pldexp_impl<Packet>::run(a, exponent); } - // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 96c572fd3..637e5f4af 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -25,29 +25,23 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */); * Some generic implementations to be used by implementors ***************************************************************************/ -/** Default implementation of pfrexp for float. +/** Default implementation of pfrexp. * It is expected to be called by implementers of template<> pfrexp. */ -template<typename Packet> EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent); +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent); -/** Default implementation of pfrexp for double. - * It is expected to be called by implementers of template<> pfrexp. - */ -template<typename Packet> EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent); - -/** Default implementation of pldexp for float. - * It is expected to be called by implementers of template<> pldexp. - */ -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent); +// Extracts the biased exponent value from Packet p, and casts the results to +// a floating-point Packet type. Used by pfrexp_generic. Override this if +// there is no unpacket_traits<Packet>::integer_packet. +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& p); -/** Default implementation of pldexp for double. +/** Default implementation of pldexp. * It is expected to be called by implementers of template<> pldexp. */ -template<typename Packet> EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent); +template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent); /** \internal \returns log(x) for single precision float */ template <typename Packet> |