diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-01-20 19:00:09 -0800 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2021-01-21 19:32:28 +0000 |
commit | b2126fd6b5e232d072ceadb1abb6695ae3352e2e (patch) | |
tree | b86944d559717eeee3589efa21dcfd30cbdd2f3d | |
parent | 25d8498f8ba29c8dc055dd56113facbdbe154345 (diff) |
Fix pfrexp/pldexp for half.
The recent addition of vectorized pow (!330) relies on `pfrexp` and
`pldexp`. This was missing for `Eigen::half` and `Eigen::bfloat16`.
Adding tests for these packet ops also exposed an issue with handling
negative values in `pfrexp`, returning an incorrect exponent.
Added the missing implementations, corrected the exponent in `pfrexp1`,
and added `packetmath` tests.
-rw-r--r-- | Eigen/src/Core/GenericPacketMath.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX/MathFunctions.h | 26 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 26 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/arch/NEON/MathFunctions.h | 12 | ||||
-rw-r--r-- | test/packetmath.cpp | 26 | ||||
-rw-r--r-- | test/packetmath_test_shared.h | 18 |
7 files changed, 112 insertions, 4 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 16119c1d8..b02a9f20b 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -442,7 +442,7 @@ template <typename Packet> EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) { int exp; EIGEN_USING_STD(frexp); - Packet result = frexp(a, &exp); + Packet result = static_cast<Packet>(frexp(a, &exp)); exponent = static_cast<Packet>(exp); return result; } @@ -453,7 +453,7 @@ EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) { template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pldexp(const Packet &a, const Packet &exponent) { EIGEN_USING_STD(ldexp) - return ldexp(a, static_cast<int>(exponent)); + return static_cast<Packet>(ldexp(a, static_cast<int>(exponent))); } /** \internal \returns the min of \a a and \a b (coeff-wise) */ diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index c0d362fa7..67041c812 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -184,6 +184,19 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) +template <> +EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) { + Packet8f fexponent; + const Packet8h out = float2half(pfrexp<Packet8f>(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) { + return float2half(pldexp<Packet8f>(half2float(a), half2float(exponent))); +} + BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog) @@ -195,6 +208,19 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) +template <> +EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) { + Packet8f fexponent; + const Packet8bf out = F32ToBf16(pfrexp<Packet8f>(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) { + return F32ToBf16(pldexp<Packet8f>(Bf16ToF32(a), Bf16ToF32(exponent))); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 66f3252cd..41929cb34 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -191,6 +191,32 @@ pexp<Packet8d>(const Packet8d& _x) { F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) +template <> +EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) { + Packet16f fexponent; + const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) { + return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) { + Packet16f fexponent; + const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) { + return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent))); +} + // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step // of Newton's method, at a cost of 1-2 bits of precision as opposed to the diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 9a1feb0d9..69c92a8cc 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -31,7 +31,7 @@ pfrexp_float(const Packet& a, Packet& exponent) { const Packet cst_126f = pset1<Packet>(126.0f); const Packet cst_half = pset1<Packet>(0.5f); const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u); - exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(a))), cst_126f); + exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f); return por(pand(a, cst_inv_mant_mask), cst_half); } @@ -41,7 +41,7 @@ pfrexp_double(const Packet& a, Packet& exponent) { const Packet cst_1022d = pset1<Packet>(1022.0); const Packet cst_half = pset1<Packet>(0.5); const Packet cst_inv_mant_mask = pset1frombits<Packet>(static_cast<uint64_t>(~0x7ff0000000000000ull)); - exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(a))), cst_1022d); + exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d); return por(pand(a, cst_inv_mant_mask), cst_half); } diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h index 28167b904..fa6615a85 100644 --- a/Eigen/src/Core/arch/NEON/MathFunctions.h +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -44,6 +44,18 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog) BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp) BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh) +template <> +EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) { + Packet4f fexponent; + const Packet4bf out = F32ToBf16(pfrexp<Packet4f>(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) { + return F32ToBf16(pldexp<Packet4f>(Bf16ToF32(a), Bf16ToF32(exponent))); +} //---------- double ---------- diff --git a/test/packetmath.cpp b/test/packetmath.cpp index ab9bec183..b7562e6a1 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -46,6 +46,21 @@ inline bool REF_MUL(const bool& a, const bool& b) { return a && b; } +template <typename T> +inline T REF_FREXP(const T& x, T& exp) { + int iexp; + EIGEN_USING_STD(frexp) + const T out = static_cast<T>(frexp(x, &iexp)); + exp = static_cast<T>(iexp); + return out; +} + +template <typename T> +inline T REF_LDEXP(const T& x, const T& exp) { + EIGEN_USING_STD(ldexp) + return static_cast<T>(ldexp(x, static_cast<int>(exp))); +} + // Uses pcast to cast from one array to another. template <typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio> struct pcast_array; @@ -552,6 +567,17 @@ void packetmath_real() { data2[i] = Scalar(internal::random<double>(-87, 88)); } CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp); + CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp); + for (int i = 0; i < PacketSize; ++i) { + data1[i] = Scalar(internal::random<double>(-1, 1)); + data2[i] = Scalar(internal::random<double>(-1, 1)); + } + for (int i = 0; i < PacketSize; ++i) { + data1[i+PacketSize] = Scalar(internal::random<int>(0, 4)); + data2[i+PacketSize] = Scalar(internal::random<double>(0, 4)); + } + CHECK_CWISE2_IF(PacketTraits::HasExp, REF_LDEXP, internal::pldexp); + for (int i = 0; i < size; ++i) { data1[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6))); data2[i] = Scalar(internal::random<double>(-1, 1) * std::pow(10., internal::random<double>(-6, 6))); diff --git a/test/packetmath_test_shared.h b/test/packetmath_test_shared.h index 46a42604b..027715a89 100644 --- a/test/packetmath_test_shared.h +++ b/test/packetmath_test_shared.h @@ -143,6 +143,9 @@ struct packet_helper template<typename T> inline void store(T* to, const Packet& x, unsigned long long umask) const { internal::pstoreu(to, x, umask); } + + template<typename T> + inline Packet& forward_reference(Packet& packet, T& /*scalar*/) const { return packet; } }; template<typename Packet> @@ -162,6 +165,9 @@ struct packet_helper<false,Packet> template<typename T> inline void store(T* to, const T& x, unsigned long long) const { *to = x; } + + template<typename T> + inline T& forward_reference(Packet& /*packet*/, T& scalar) const { return scalar; } }; #define CHECK_CWISE1_IF(COND, REFOP, POP) if(COND) { \ @@ -180,6 +186,18 @@ struct packet_helper<false,Packet> VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \ } +// One input, one output by reference. +#define CHECK_CWISE1_BYREF1_IF(COND, REFOP, POP) if(COND) { \ + test::packet_helper<COND,Packet> h; \ + for (int i=0; i<PacketSize; ++i) \ + ref[i] = Scalar(REFOP(data1[i], ref[i+PacketSize])); \ + Packet pout; \ + Scalar sout; \ + h.store(data2, POP(h.load(data1), h.forward_reference(pout, sout))); \ + h.store(data2+PacketSize, h.forward_reference(pout, sout)); \ + VERIFY(test::areApprox(ref, data2, 2 * PacketSize) && #POP); \ +} + #define CHECK_CWISE3_IF(COND, REFOP, POP) if (COND) { \ test::packet_helper<COND, Packet> h; \ for (int i = 0; i < PacketSize; ++i) \ |