From b2126fd6b5e232d072ceadb1abb6695ae3352e2e Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 20 Jan 2021 19:00:09 -0800 Subject: Fix pfrexp/pldexp for half. The recent addition of vectorized pow (!330) relies on `pfrexp` and `pldexp`. This was missing for `Eigen::half` and `Eigen::bfloat16`. Adding tests for these packet ops also exposed an issue with handling negative values in `pfrexp`, returning an incorrect exponent. Added the missing implementations, corrected the exponent in `pfrexp1`, and added `packetmath` tests. --- Eigen/src/Core/arch/AVX/MathFunctions.h | 26 ++++++++++++++++++++++ Eigen/src/Core/arch/AVX512/MathFunctions.h | 26 ++++++++++++++++++++++ .../Core/arch/Default/GenericPacketMathFunctions.h | 4 ++-- Eigen/src/Core/arch/NEON/MathFunctions.h | 12 ++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) (limited to 'Eigen/src/Core/arch') diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index c0d362fa7..67041c812 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -184,6 +184,19 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) +template <> +EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) { + Packet8f fexponent; + const Packet8h out = float2half(pfrexp(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) { + return float2half(pldexp(half2float(a), half2float(exponent))); +} + BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog) @@ -195,6 +208,19 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) +template <> +EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) { + Packet8f fexponent; + const Packet8bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 66f3252cd..41929cb34 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -191,6 +191,32 @@ pexp(const Packet8d& _x) { F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) +template <> +EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) { + Packet16f fexponent; + const Packet16h out = float2half(pfrexp(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) { + return float2half(pldexp(half2float(a), half2float(exponent))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) { + Packet16f fexponent; + const Packet16bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} + // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step // of Newton's method, at a cost of 1-2 bits of precision as opposed to the diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 9a1feb0d9..69c92a8cc 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -31,7 +31,7 @@ pfrexp_float(const Packet& a, Packet& exponent) { const Packet cst_126f = pset1(126.0f); const Packet cst_half = pset1(0.5f); const Packet cst_inv_mant_mask = pset1frombits(~0x7f800000u); - exponent = psub(pcast(plogical_shift_right<23>(preinterpret(a))), cst_126f); + exponent = psub(pcast(plogical_shift_right<23>(preinterpret(pabs(a)))), cst_126f); return por(pand(a, cst_inv_mant_mask), cst_half); } @@ -41,7 +41,7 @@ pfrexp_double(const Packet& a, Packet& exponent) { const Packet cst_1022d = pset1(1022.0); const Packet cst_half = pset1(0.5); const Packet cst_inv_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); - exponent = psub(pcast(plogical_shift_right<52>(preinterpret(a))), cst_1022d); + exponent = psub(pcast(plogical_shift_right<52>(preinterpret(pabs(a)))), cst_1022d); return por(pand(a, cst_inv_mant_mask), cst_half); } diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h index 28167b904..fa6615a85 100644 --- a/Eigen/src/Core/arch/NEON/MathFunctions.h +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -44,6 +44,18 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog) BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp) BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh) +template <> +EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) { + Packet4f fexponent; + const Packet4bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} //---------- double ---------- -- cgit v1.2.3