From 7ff0b7a980ceffe7d0e72ebac924f514f7874e9b Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 12 Feb 2021 11:32:29 -0800 Subject: Updated pfrexp implementation. The original implementation fails for 0, denormals, inf, and NaN. See #2150 --- Eigen/src/Core/arch/AVX/PacketMath.h | 22 +-- Eigen/src/Core/arch/AVX512/PacketMath.h | 27 ++-- Eigen/src/Core/arch/AltiVec/PacketMath.h | 62 +++++--- .../Core/arch/Default/GenericPacketMathFunctions.h | 164 ++++++++++++--------- .../arch/Default/GenericPacketMathFunctionsFwd.h | 28 ++-- Eigen/src/Core/arch/NEON/PacketMath.h | 12 +- Eigen/src/Core/arch/SSE/PacketMath.h | 19 ++- Eigen/src/Core/arch/SVE/PacketMath.h | 4 +- 8 files changed, 195 insertions(+), 143 deletions(-) (limited to 'Eigen/src/Core') diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 34bc242ca..23a2da8e9 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -734,12 +734,13 @@ template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a) } template<> EIGEN_STRONG_INLINE Packet8f pfrexp(const Packet8f& a, Packet8f& exponent) { - return pfrexp_float(a,exponent); + return pfrexp_generic(a,exponent); } -template<> EIGEN_STRONG_INLINE Packet4d pfrexp(const Packet4d& a, Packet4d& exponent) { - const Packet4d cst_1022d = pset1(1022.0); - const Packet4d cst_half = pset1(0.5); +// Extract exponent without existence of Packet4l. 
+template<> +EIGEN_STRONG_INLINE +Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) { const Packet4d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); __m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask)); #ifdef EIGEN_VECTORIZE_AVX2 @@ -754,15 +755,18 @@ template<> EIGEN_STRONG_INLINE Packet4d pfrexp(const Packet4d& a, Pack #endif Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3)); Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3)); - exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0); + Packet4d exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0); exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1); - exponent = psub(exponent, cst_1022d); - const Packet4d cst_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - return por(pand(a, cst_mant_mask), cst_half); + return exponent; +} + + +template<> EIGEN_STRONG_INLINE Packet4d pfrexp(const Packet4d& a, Packet4d& exponent) { + return pfrexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet8f pldexp(const Packet8f& a, const Packet8f& exponent) { - return pldexp_float(a,exponent); + return pldexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet4d pldexp(const Packet4d& a, const Packet4d& exponent) { diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index b9e9fdbfd..f8741372d 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -895,25 +895,28 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) { template<> EIGEN_STRONG_INLINE Packet16f pfrexp(const Packet16f& a, Packet16f& exponent){ - return pfrexp_float(a, exponent); + return pfrexp_generic(a, exponent); +} + +// Extract exponent without existence of Packet8l. 
+template<> +EIGEN_STRONG_INLINE +Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) { + const Packet8d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); + #ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)); + #else + return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52))); + #endif } template<> EIGEN_STRONG_INLINE Packet8d pfrexp(const Packet8d& a, Packet8d& exponent) { - const Packet8d cst_1022d = pset1(1022.0); -#ifdef EIGEN_TEST_AVX512DQ - exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d); -#else - exponent = psub(_mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(a), 52))), - cst_1022d); -#endif - const Packet8d cst_half = pset1(0.5); - const Packet8d cst_inv_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - return por(pand(a, cst_inv_mant_mask), cst_half); + return pfrexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet16f pldexp(const Packet16f& a, const Packet16f& exponent) { - return pldexp_float(a,exponent); + return pldexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, const Packet8d& exponent) { diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index f29b59590..df04b8e0f 100755 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -1160,43 +1160,48 @@ template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { return pand(p8us_abs_mask, a); } -template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) +template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a) { return vec_sra(a,reinterpret_cast(pset1(N))); } -template EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a) +template EIGEN_STRONG_INLINE Packet4i 
plogical_shift_right(const Packet4i& a) { return vec_sr(a,reinterpret_cast(pset1(N))); } -template EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) +template EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a) { return vec_sl(a,reinterpret_cast(pset1(N))); } -template EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a) +template EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a) { const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); Packet4ui r = vec_sl(reinterpret_cast(a), p4ui_mask); return reinterpret_cast(r); } -template EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a) +template EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a) { const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); Packet4ui r = vec_sr(reinterpret_cast(a), p4ui_mask); return reinterpret_cast(r); } -template EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a) { const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); return vec_sr(a, p4ui_mask); } -template EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a) { const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); return vec_sl(a, p4ui_mask); } -template EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) +template EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a) { const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); return vec_sl(a, p8us_mask); } +template EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); + return vec_sr(a, p8us_mask); +} EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){ return plogical_shift_left<16>(reinterpret_cast(bf.m_val)); @@ -1323,14 +1328,14 @@ template<> EIGEN_STRONG_INLINE Packet8bf pexp (const Packet8bf& a){ } template<> 
EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) { - return pldexp_float(a,exponent); + return pldexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet8bf pldexp (const Packet8bf& a, const Packet8bf& exponent){ - BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent); + BF16_TO_F32_BINARY_OP_WRAPPER(pldexp, a, exponent); } template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) { - return pfrexp_float(a,exponent); + return pfrexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet8bf pfrexp (const Packet8bf& a, Packet8bf& e){ Packet4f a_even = Bf16ToF32Even(a); @@ -2324,6 +2329,11 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { return v; } +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(unsigned long from) { + Packet2l v = {static_cast(from), static_cast(from)}; + return reinterpret_cast(v); +} + template<> EIGEN_STRONG_INLINE void pbroadcast4(const double *a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) @@ -2439,7 +2449,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs // a slow version that works with older compilers. // Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles // are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963 -static inline Packet2l ConvertToPacket2l(const Packet2d& x) { +template<> +inline Packet2l pcast(const Packet2d& x) { #if EIGEN_GNUC_AT_LEAST(5, 4) || \ (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1) return vec_cts(x, 0); // TODO: check clang version. @@ -2452,6 +2463,15 @@ static inline Packet2l ConvertToPacket2l(const Packet2d& x) { #endif } +template<> +inline Packet2d pcast(const Packet2l& x) { + unsigned long long tmp[2]; + memcpy(tmp, &x, sizeof(tmp)); + Packet2d d = { static_cast(tmp[0]), + static_cast(tmp[1]) }; + return d; +} + // Packet2l shifts. // For POWER8 we simply use vec_sr/l. 
@@ -2569,7 +2589,7 @@ EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) { template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) { // Clamp exponent to [-2099, 2099] const Packet2d max_exponent = pset1(2099.0); - const Packet2l e = ConvertToPacket2l(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + const Packet2l e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); // Split 2^e into four factors and multiply: const Packet2l bias = { 1023, 1023 }; @@ -2582,14 +2602,16 @@ template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, cons return out; } + +// Extract exponent without relying on an integer_packet for Packet2d. +template<> +EIGEN_STRONG_INLINE +Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) { + return pcast(plogical_shift_right<52>(reinterpret_cast(pabs(a)))); +} + template<> EIGEN_STRONG_INLINE Packet2d pfrexp (const Packet2d& a, Packet2d& exponent) { - double exp[2] = { exponent[0], exponent[1] }; - Packet2d ret = { pfrexp(a[0], exp[0]), pfrexp(a[1], exp[1]) }; - exponent[0] = exp[0]; - exponent[1] = exp[1]; - return ret; -// This doesn't currently work (no integer_packet for Packet2d - but adding it causes other problems) -// return pfrexp_double(a, exponent); + return pfrexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 452019ecc..42d310ab2 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -25,80 +25,114 @@ pset(const typename unpacket_traits::type (&a)[N] /* a */) { return pload(a); } -template EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent) { +// Creates a Scalar integer type with same bit-width. 
+template struct make_integer; +template<> struct make_integer { typedef numext::int32_t type; }; +template<> struct make_integer { typedef numext::int64_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; + +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& a) { + typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::integer_packet PacketI; - const Packet cst_126f = pset1(126.0f); - const Packet cst_half = pset1(0.5f); - const Packet cst_inv_mant_mask = pset1frombits(~0x7f800000u); - exponent = psub(pcast(plogical_shift_right<23>(preinterpret(pabs(a)))), cst_126f); - return por(pand(a, cst_inv_mant_mask), cst_half); + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + return pcast(plogical_shift_right(preinterpret(pabs(a)))); } -template EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent) { - typedef typename unpacket_traits::integer_packet PacketI; - const Packet cst_1022d = pset1(1022.0); - const Packet cst_half = pset1(0.5); - const Packet cst_inv_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - exponent = psub(pcast(plogical_shift_right<52>(preinterpret(pabs(a)))), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); +// Safely applies frexp, correctly handles denormals. +// Assumes IEEE floating point format. 
+template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent) { + typedef typename unpacket_traits::type Scalar; + typedef typename make_unsigned::type>::type ScalarUI; + + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = + ~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000 + const Packet sign_mantissa_mask = pset1frombits(static_cast(scalar_sign_mantissa_mask)); + const Packet half = pset1(Scalar(0.5)); + const Packet zero = pzero(a); + const Packet normal_min = pset1((numext::numeric_limits::min)()); // Minimum normal value, 2^-126 + + // To handle denormals, normalize by multiplying by 2^(mantissa_bits+1). + const Packet is_denormal = pcmp_lt(pabs(a), normal_min); + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24 + // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. + const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 + const Packet normalization_factor = pset1(scalar_normalization_factor); + const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); + + // Determine exponent offset: -126 if normal, -126-24 if denormal + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126 + Packet exponent_offset = pset1(scalar_exponent_offset); + const Packet normalization_offset = pset1(-Scalar(scalar_normalization_offset)); // -24 + exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); + + // Determine exponent and mantissa from normalized_a. 
+ exponent = pfrexp_generic_get_biased_exponent(normalized_a); + // Zero, Inf and NaN return 'a' unmodified, exponent is zero + // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255 + const Packet non_finite_exponent = pset1(scalar_non_finite_exponent); + const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); + const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); + exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); + return m; } // Safely applies ldexp, correctly handles overflows, underflows and denormals. // Assumes IEEE floating point format. -template -struct pldexp_impl { +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent) { + // We want to return a * 2^exponent, allowing for all possible integer + // exponents without overflowing or underflowing in intermediate + // computations. + // + // Since 'a' and the output can be denormal, the maximum range of 'exponent' + // to consider for a float is: + // -255-23 -> 255+23 + // Below -278 any finite float 'a' will become zero, and above +278 any + // finite float will become inf, including when 'a' is the smallest possible + // denormal. + // + // Unfortunately, 2^(278) cannot be represented using either one or two + // finite normal floats, so we must split the scale factor into at least + // three parts. It turns out to be faster to split 'exponent' into four + // factors, since [exponent>>2] is much faster to compute than [exponent/3]. + // + // Set e = min(max(exponent, -278), 278); + // b = floor(e/4); + // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) + // + // This will avoid any intermediate overflows and correctly handle 0, inf, + // NaN cases. 
typedef typename unpacket_traits::integer_packet PacketI; typedef typename unpacket_traits::type Scalar; typedef typename unpacket_traits::type ScalarI; - enum { - TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits::digits - 1, - ExponentBits = int(TotalBits) - int(MantissaBits) - 1 - }; - - static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC - Packet run(const Packet& a, const Packet& exponent) { - // We want to return a * 2^exponent, allowing for all possible integer - // exponents without overflowing or underflowing in intermediate - // computations. - // - // Since 'a' and the output can be denormal, the maximum range of 'exponent' - // to consider for a float is: - // -255-23 -> 255+23 - // Below -278 any finite float 'a' will become zero, and above +278 any - // finite float will become inf, including when 'a' is the smallest possible - // denormal. - // - // Unfortunately, 2^(278) cannot be represented using either one or two - // finite normal floats, so we must split the scale factor into at least - // three parts. It turns out to be faster to split 'exponent' into four - // factors, since [exponent>>2] is much faster to compute that [exponent/3]. - // - // Set e = min(max(exponent, -278), 278); - // b = floor(e/4); - // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) - // - // This will avoid any intermediate overflows and correctly handle 0, inf, - // NaN cases. 
- const Packet max_exponent = pset1(Scalar( (ScalarI(1)<((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 - const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); - PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); - Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b - Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) - b = psub(psub(psub(e, b), b), b); // e - 3b - c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) - out = pmul(out, c); - return out; - } -}; + EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT; + EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits::digits - 1; + EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1; + + const Packet max_exponent = pset1(Scalar((ScalarI(1)<((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127 + const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); + Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b + Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) + out = pmul(out, c); + return out; +} // Explicitly multiplies // a * (2^e) // clamping e to the range -// [std::numeric_limits::min_exponent-2, std::numeric_limits::max_exponent] +// [numeric_limits::min_exponent-2, numeric_limits::max_exponent] // // This is approx 7x faster than pldexp_impl, but will prematurely over/underflow // if 2^e doesn't fit into a normal floating-point Scalar. 
@@ -111,7 +145,7 @@ struct pldexp_fast_impl { typedef typename unpacket_traits::type ScalarI; enum { TotalBits = sizeof(Scalar) * CHAR_BIT, - MantissaBits = std::numeric_limits::digits - 1, + MantissaBits = numext::numeric_limits::digits - 1, ExponentBits = int(TotalBits) - int(MantissaBits) - 1 }; @@ -126,14 +160,6 @@ struct pldexp_fast_impl { } }; -template EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent) -{ return pldexp_impl::run(a, exponent); } - -template EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent) -{ return pldexp_impl::run(a, exponent); } - // Natural or base 2 logarithm. // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 96c572fd3..637e5f4af 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -25,29 +25,23 @@ pset(const typename unpacket_traits::type (&a)[N] /* a */); * Some generic implementations to be used by implementors ***************************************************************************/ -/** Default implementation of pfrexp for float. +/** Default implementation of pfrexp. * It is expected to be called by implementers of template<> pfrexp. */ -template EIGEN_STRONG_INLINE Packet -pfrexp_float(const Packet& a, Packet& exponent); +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent); -/** Default implementation of pfrexp for double. - * It is expected to be called by implementers of template<> pfrexp. - */ -template EIGEN_STRONG_INLINE Packet -pfrexp_double(const Packet& a, Packet& exponent); - -/** Default implementation of pldexp for float. 
- * It is expected to be called by implementers of template<> pldexp. - */ -template EIGEN_STRONG_INLINE Packet -pldexp_float(const Packet& a, const Packet& exponent); +// Extracts the biased exponent value from Packet p, and casts the results to +// a floating-point Packet type. Used by pfrexp_generic. Override this if +// there is no unpacket_traits::integer_packet. +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& p); -/** Default implementation of pldexp for double. +/** Default implementation of pldexp. * It is expected to be called by implementers of template<> pldexp. */ -template EIGEN_STRONG_INLINE Packet -pldexp_double(const Packet& a, const Packet& exponent); +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent); /** \internal \returns log(x) for single precision float */ template diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 98d78751a..2e06befc2 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -2402,14 +2402,14 @@ template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) { template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; } template<> EIGEN_STRONG_INLINE Packet2f pfrexp(const Packet2f& a, Packet2f& exponent) -{ return pfrexp_float(a,exponent); } +{ return pfrexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) -{ return pfrexp_float(a,exponent); } +{ return pfrexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet2f pldexp(const Packet2f& a, const Packet2f& exponent) -{ return pldexp_float(a,exponent); } +{ return pldexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) -{ return pldexp_float(a,exponent); } +{ return pldexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE 
float predux(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); } template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a) @@ -3907,10 +3907,10 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2 { return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); } template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) -{ return pldexp_double(a, exponent); } +{ return pldexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) -{ return pfrexp_double(a,exponent); } +{ return pfrexp_generic(a,exponent); } template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) { return vreinterpretq_f64_u64(vdupq_n_u64(from)); } diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 422b0fce7..401a497d2 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -887,21 +887,24 @@ template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) } template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) { - return pfrexp_float(a,exponent); + return pfrexp_generic(a,exponent); } -template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) { - const Packet2d cst_1022d = pset1(1022.0); - const Packet2d cst_half = pset1(0.5); +// Extract exponent without existence of Packet2l. 
+template<> +EIGEN_STRONG_INLINE +Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) { const Packet2d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); __m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52); - exponent = psub(_mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)), cst_1022d); - const Packet2d cst_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - return por(pand(a, cst_mant_mask), cst_half); + return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)); +} + +template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) { + return pfrexp_generic(a, exponent); } template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) { - return pldexp_float(a,exponent); + return pldexp_generic(a,exponent); } // We specialize pldexp here, since the generic implementation uses Packet2l, which is not well diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h index 98585e6a9..4877b6d80 100644 --- a/Eigen/src/Core/arch/SVE/PacketMath.h +++ b/Eigen/src/Core/arch/SVE/PacketMath.h @@ -669,7 +669,7 @@ EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) template <> EIGEN_STRONG_INLINE PacketXf pfrexp(const PacketXf& a, PacketXf& exponent) { - return pfrexp_float(a, exponent); + return pfrexp_generic(a, exponent); } template <> @@ -747,7 +747,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) template<> EIGEN_STRONG_INLINE PacketXf pldexp(const PacketXf& a, const PacketXf& exponent) { - return pldexp_float(a, exponent); + return pldexp_generic(a, exponent); } } // namespace internal -- cgit v1.2.3