From 7ff0b7a980ceffe7d0e72ebac924f514f7874e9b Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 12 Feb 2021 11:32:29 -0800 Subject: Updated pfrexp implementation. The original implementation fails for 0, denormals, inf, and NaN. See #2150 --- .../Core/arch/Default/GenericPacketMathFunctions.h | 164 ++++++++++++--------- .../arch/Default/GenericPacketMathFunctionsFwd.h | 28 ++-- 2 files changed, 106 insertions(+), 86 deletions(-) (limited to 'Eigen/src/Core/arch/Default') diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 452019ecc..42d310ab2 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -25,80 +25,114 @@ pset(const typename unpacket_traits::type (&a)[N] /* a */) { return pload(a); } -template EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent) { +// Creates a Scalar integer type with same bit-width. 
+template struct make_integer; +template<> struct make_integer { typedef numext::int32_t type; }; +template<> struct make_integer { typedef numext::int64_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; + +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& a) { + typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::integer_packet PacketI; - const Packet cst_126f = pset1(126.0f); - const Packet cst_half = pset1(0.5f); - const Packet cst_inv_mant_mask = pset1frombits(~0x7f800000u); - exponent = psub(pcast(plogical_shift_right<23>(preinterpret(pabs(a)))), cst_126f); - return por(pand(a, cst_inv_mant_mask), cst_half); + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + return pcast(plogical_shift_right(preinterpret(pabs(a)))); } -template EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent) { - typedef typename unpacket_traits::integer_packet PacketI; - const Packet cst_1022d = pset1(1022.0); - const Packet cst_half = pset1(0.5); - const Packet cst_inv_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - exponent = psub(pcast(plogical_shift_right<52>(preinterpret(pabs(a)))), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); +// Safely applies frexp, correctly handles denormals. +// Assumes IEEE floating point format. 
+template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent) { + typedef typename unpacket_traits::type Scalar; + typedef typename make_unsigned::type>::type ScalarUI; + + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = + ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000 + const Packet sign_mantissa_mask = pset1frombits(static_cast(scalar_sign_mantissa_mask)); + const Packet half = pset1(Scalar(0.5)); + const Packet zero = pzero(a); + const Packet normal_min = pset1((numext::numeric_limits::min)()); // Minimum normal value, 2^-126 + + // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1). + const Packet is_denormal = pcmp_lt(pabs(a), normal_min); + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24 + // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. + const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 + const Packet normalization_factor = pset1(scalar_normalization_factor); + const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); + + // Determine exponent offset: -126 if normal, -126-24 if denormal + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126 + Packet exponent_offset = pset1(scalar_exponent_offset); + const Packet normalization_offset = pset1(-Scalar(scalar_normalization_offset)); // -24 + exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); + + // Determine exponent and mantissa from normalized_a. 
+ exponent = pfrexp_generic_get_biased_exponent(normalized_a); + // Zero, Inf and NaN return 'a' unmodified, exponent is zero + // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255 + const Packet non_finite_exponent = pset1(scalar_non_finite_exponent); + const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); + const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); + exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); + return m; } // Safely applies ldexp, correctly handles overflows, underflows and denormals. // Assumes IEEE floating point format. -template -struct pldexp_impl { +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent) { + // We want to return a * 2^exponent, allowing for all possible integer + // exponents without overflowing or underflowing in intermediate + // computations. + // + // Since 'a' and the output can be denormal, the maximum range of 'exponent' + // to consider for a float is: + // -255-23 -> 255+23 + // Below -278 any finite float 'a' will become zero, and above +278 any + // finite float will become inf, including when 'a' is the smallest possible + // denormal. + // + // Unfortunately, 2^(278) cannot be represented using either one or two + // finite normal floats, so we must split the scale factor into at least + // three parts. It turns out to be faster to split 'exponent' into four + // factors, since [exponent>>2] is much faster to compute than [exponent/3]. + // + // Set e = min(max(exponent, -278), 278); + // b = floor(e/4); + // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) + // + // This will avoid any intermediate overflows and correctly handle 0, inf, + // NaN cases. 
typedef typename unpacket_traits::integer_packet PacketI; typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::type ScalarI; - enum { - TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; - - static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC - Packet run(const Packet& a, const Packet& exponent) { - // We want to return a * 2^exponent, allowing for all possible integer - // exponents without overflowing or underflowing in intermediate - // computations. - // - // Since 'a' and the output can be denormal, the maximum range of 'exponent' - // to consider for a float is: - // -255-23 -> 255+23 - // Below -278 any finite float 'a' will become zero, and above +278 any - // finite float will become inf, including when 'a' is the smallest possible - // denormal. - // - // Unfortunately, 2^(278) cannot be represented using either one or two - // finite normal floats, so we must split the scale factor into at least - // three parts. It turns out to be faster to split 'exponent' into four - // factors, since [exponent>>2] is much faster to compute that [exponent/3]. - // - // Set e = min(max(exponent, -278), 278); - // b = floor(e/4); - // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) - // - // This will avoid any intermediate overflows and correctly handle 0, inf, - // NaN cases. 
- const Packet max_exponent = pset1(Scalar( (ScalarI(1)<((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 - const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); - PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); - Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b - Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) - b = psub(psub(psub(e, b), b), b); // e - 3b - c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) - out = pmul(out, c); - return out; - } -}; + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + const Packet max_exponent = pset1(Scalar((ScalarI(1)<((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127 + const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); + Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b + Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) + out = pmul(out, c); + return out; +} // Explicitly multiplies // a * (2^e) // clamping e to the range -// [std::numeric_limits::min_exponent-2, std::numeric_limits::max_exponent] +// [numeric_limits::min_exponent-2, numeric_limits::max_exponent] // // This is approx 7x faster than pldexp_impl, but will prematurely over/underflow // if 2^e doesn't fit into a normal floating-point Scalar. 
@@ -111,7 +145,7 @@ struct pldexp_fast_impl { typedef typename unpacket_traits::type ScalarI; enum { TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits::digits - 1, + MantissaBits = numext::numeric_limits::digits - 1, ExponentBits = int(TotalBits) - int(MantissaBits) - 1 }; @@ -126,14 +160,6 @@ struct pldexp_fast_impl { } }; -template EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent) -{ return pldexp_impl::run(a, exponent); } - -template EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent) -{ return pldexp_impl::run(a, exponent); } - // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 96c572fd3..637e5f4af 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -25,29 +25,23 @@ pset(const typename unpacket_traits::type (&a)[N] /* a */); * Some generic implementations to be used by implementors ***************************************************************************/ -/** Default implementation of pfrexp for float. +/** Default implementation of pfrexp. * It is expected to be called by implementers of template<> pfrexp. */ -template EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent); +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent); -/** Default implementation of pfrexp for double. - * It is expected to be called by implementers of template<> pfrexp. - */ -template EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent); - -/** Default implementation of pldexp for float. 
- * It is expected to be called by implementers of template<> pldexp. - */ -template EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent); +// Extracts the biased exponent value from Packet p, and casts the results to +// a floating-point Packet type. Used by pfrexp_generic. Override this if +// there is no unpacket_traits::integer_packet. +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& p); -/** Default implementation of pldexp for double. +/** Default implementation of pldexp. * It is expected to be called by implementers of template<> pldexp. */ -template EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent); +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent); /** \internal \returns log(x) for single precision float */ template -- cgit v1.2.3