aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-02-12 11:32:29 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-17 02:23:24 +0000
commit7ff0b7a980ceffe7d0e72ebac924f514f7874e9b (patch)
tree4022a25c9e1c909aff3f43b978a4a5d5ce0a1c47 /Eigen/src/Core/arch/Default
parent9ad4096ccb75dd5c5dd882576d49d48475afa300 (diff)
Updated pfrexp implementation.
The original implementation fails for 0, denormals, inf, and NaN. See #2150
Diffstat (limited to 'Eigen/src/Core/arch/Default')
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h164
-rw-r--r--Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h28
2 files changed, 106 insertions, 86 deletions
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 452019ecc..42d310ab2 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -25,80 +25,114 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */) {
return pload<Packet>(a);
}
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pfrexp_float(const Packet& a, Packet& exponent) {
+// Creates a Scalar integer type with same bit-width.
+template<typename T> struct make_integer;
+template<> struct make_integer<float> { typedef numext::int32_t type; };
+template<> struct make_integer<double> { typedef numext::int64_t type; };
+template<> struct make_integer<half> { typedef numext::int16_t type; };
+template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
+
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- const Packet cst_126f = pset1<Packet>(126.0f);
- const Packet cst_half = pset1<Packet>(0.5f);
- const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
- exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f);
- return por(pand(a, cst_inv_mant_mask), cst_half);
+ EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
+ return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
}
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pfrexp_double(const Packet& a, Packet& exponent) {
- typedef typename unpacket_traits<Packet>::integer_packet PacketI;
- const Packet cst_1022d = pset1<Packet>(1022.0);
- const Packet cst_half = pset1<Packet>(0.5);
- const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull));
- exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
- return por(pand(a, cst_inv_mant_mask), cst_half);
+// Safely applies frexp, correctly handles denormals.
+// Assumes IEEE floating point format.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent) {
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
+
+ EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
+ EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
+ EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
+
+ EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
+ ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000
+ const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
+ const Packet half = pset1<Packet>(Scalar(0.5));
+ const Packet zero = pzero(a);
+ const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
+
+ // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1).
+ const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
+ EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24
+ // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
+ const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
+ const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
+ const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
+
+ // Determine exponent offset: -126 if normal, -126-24 if denormal
+ const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126
+ Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
+ const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
+ exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
+
+ // Determine exponent and mantissa from normalized_a.
+ exponent = pfrexp_generic_get_biased_exponent(normalized_a);
+ // Zero, Inf and NaN return 'a' unmodified, exponent is zero
+ // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
+ const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255
+ const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
+ const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
+ const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
+ exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
+ return m;
}
// Safely applies ldexp, correctly handles overflows, underflows and denormals.
// Assumes IEEE floating point format.
-template<typename Packet>
-struct pldexp_impl {
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent) {
+ // We want to return a * 2^exponent, allowing for all possible integer
+ // exponents without overflowing or underflowing in intermediate
+ // computations.
+ //
+ // Since 'a' and the output can be denormal, the maximum range of 'exponent'
+ // to consider for a float is:
+ // -255-23 -> 255+23
+ // Below -278 any finite float 'a' will become zero, and above +278 any
+ // finite float will become inf, including when 'a' is the smallest possible
+ // denormal.
+ //
+ // Unfortunately, 2^(278) cannot be represented using either one or two
+ // finite normal floats, so we must split the scale factor into at least
+ // three parts. It turns out to be faster to split 'exponent' into four
+ // factors, since [exponent>>2] is much faster to compute that [exponent/3].
+ //
+ // Set e = min(max(exponent, -278), 278);
+ // b = floor(e/4);
+ // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
+ //
+ // This will avoid any intermediate overflows and correctly handle 0, inf,
+ // NaN cases.
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
- enum {
- TotalBits = sizeof(Scalar) * CHAR_BIT,
- MantissaBits = std::numeric_limits<Scalar>::digits - 1,
- ExponentBits = int(TotalBits) - int(MantissaBits) - 1
- };
-
- static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
- Packet run(const Packet& a, const Packet& exponent) {
- // We want to return a * 2^exponent, allowing for all possible integer
- // exponents without overflowing or underflowing in intermediate
- // computations.
- //
- // Since 'a' and the output can be denormal, the maximum range of 'exponent'
- // to consider for a float is:
- // -255-23 -> 255+23
- // Below -278 any finite float 'a' will become zero, and above +278 any
- // finite float will become inf, including when 'a' is the smallest possible
- // denormal.
- //
- // Unfortunately, 2^(278) cannot be represented using either one or two
- // finite normal floats, so we must split the scale factor into at least
- // three parts. It turns out to be faster to split 'exponent' into four
- // factors, since [exponent>>2] is much faster to compute that [exponent/3].
- //
- // Set e = min(max(exponent, -278), 278);
- // b = floor(e/4);
- // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
- //
- // This will avoid any intermediate overflows and correctly handle 0, inf,
- // NaN cases.
- const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278
- const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
- const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
- PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
- Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
- Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
- b = psub(psub(psub(e, b), b), b); // e - 3b
- c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
- out = pmul(out, c);
- return out;
- }
-};
+ EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
+ EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
+ EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
+
+ const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278
+ const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127
+ const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
+ PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
+ Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b
+ Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
+ b = psub(psub(psub(e, b), b), b); // e - 3b
+ c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b)
+ out = pmul(out, c);
+ return out;
+}
// Explicitly multiplies
// a * (2^e)
// clamping e to the range
-// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent]
+// [numeric_limits<Scalar>::min_exponent-2, numeric_limits<Scalar>::max_exponent]
//
// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
// if 2^e doesn't fit into a normal floating-point Scalar.
@@ -111,7 +145,7 @@ struct pldexp_fast_impl {
typedef typename unpacket_traits<PacketI>::type ScalarI;
enum {
TotalBits = sizeof(Scalar) * CHAR_BIT,
- MantissaBits = std::numeric_limits<Scalar>::digits - 1,
+ MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
@@ -126,14 +160,6 @@ struct pldexp_fast_impl {
}
};
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_float(const Packet& a, const Packet& exponent)
-{ return pldexp_impl<Packet>::run(a, exponent); }
-
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_double(const Packet& a, const Packet& exponent)
-{ return pldexp_impl<Packet>::run(a, exponent); }
-
// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
index 96c572fd3..637e5f4af 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h
@@ -25,29 +25,23 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */);
* Some generic implementations to be used by implementors
***************************************************************************/
-/** Default implementation of pfrexp for float.
+/** Default implementation of pfrexp.
* It is expected to be called by implementers of template<> pfrexp.
*/
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pfrexp_float(const Packet& a, Packet& exponent);
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic(const Packet& a, Packet& exponent);
-/** Default implementation of pfrexp for double.
- * It is expected to be called by implementers of template<> pfrexp.
- */
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pfrexp_double(const Packet& a, Packet& exponent);
-
-/** Default implementation of pldexp for float.
- * It is expected to be called by implementers of template<> pldexp.
- */
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_float(const Packet& a, const Packet& exponent);
+// Extracts the biased exponent value from Packet p, and casts the results to
+// a floating-point Packet type. Used by pfrexp_generic. Override this if
+// there is no unpacket_traits<Packet>::integer_packet.
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pfrexp_generic_get_biased_exponent(const Packet& p);
-/** Default implementation of pldexp for double.
+/** Default implementation of pldexp.
* It is expected to be called by implementers of template<> pldexp.
*/
-template<typename Packet> EIGEN_STRONG_INLINE Packet
-pldexp_double(const Packet& a, const Packet& exponent);
+template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
+Packet pldexp_generic(const Packet& a, const Packet& exponent);
/** \internal \returns log(x) for single precision float */
template <typename Packet>