From 29ebd84cb779eae01302a9f1e40cf06ca5eeeceb Mon Sep 17 00:00:00 2001
From: Antonio Sanchez
Date: Fri, 26 Feb 2021 13:59:46 -0800
Subject: Fix NEON sqrt for 32-bit, add prsqrt.

With !406, we accidentally broke ARM 32-bit NEON builds, since
`vsqrt_f32` is only available for 64-bit.

Here we add back the `rsqrt`-based `psqrt` implementation for 32-bit,
relying on a `prsqrt` implementation with better handling of edge cases.

Note that several of the 32-bit NEON packet tests are currently
failing - either due to denormal handling (NEON versions flush
to zero, but scalar paths don't) or due to accuracy (e.g. sin/cos).
---
 Eigen/src/Core/GenericPacketMath.h    |  2 +-
 Eigen/src/Core/arch/NEON/PacketMath.h | 48 +++++++++++++++++++++++++++++++++++++++++++
 test/packetmath.cpp                   |  3 ++-
 3 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index f8c7826db..bc0fe39a7 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -684,7 +684,7 @@ Packet plog2(const Packet& a) {
 
 /** \internal \returns the square-root of \a a (coeff-wise) */
 template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
-Packet psqrt(const Packet& a) { EIGEN_USING_STD(sqrt); return sqrt(a); }
+Packet psqrt(const Packet& a) { return numext::sqrt(a); }
 
 /** \internal \returns the reciprocal square-root of \a a (coeff-wise) */
 template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 9715bf4b2..f77a18a4f 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -202,6 +202,7 @@ struct packet_traits : default_packet_traits
     HasLog = 1,
     HasExp = 1,
     HasSqrt = 1,
+    HasRsqrt = 1,
     HasTanh = EIGEN_FAST_MATH,
     HasErf = EIGEN_FAST_MATH,
     HasBessel = 0, // Issues with accuracy.
@@ -3329,8 +3330,42 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
   return res;
 }
 
+template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
+  // Per-lane 1/sqrt(a). The refinement alone yields NaN for 0 and +inf (inf * 0), so those lanes are patched last.
+  Packet4f x = vrsqrteq_f32(a);
+  // Two Newton steps for 1/sqrt(a): x *= (3 - a*x*x) / 2, where vrsqrts(p, q) computes (3 - p*q) / 2.
+  x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+  x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+  const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+  return pselect(pcmp_eq(a, infinity), pzero(a), pselect(pcmp_eq(a, pzero(a)), por(infinity, a), x));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
+  // Per-lane 1/sqrt(a). The refinement alone yields NaN for 0 and +inf (inf * 0), so those lanes are patched last.
+  Packet2f x = vrsqrte_f32(a);
+  // Two Newton steps for 1/sqrt(a): x *= (3 - a*x*x) / 2, where vrsqrts(p, q) computes (3 - p*q) / 2.
+  x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+  x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+  const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+  return pselect(pcmp_eq(a, infinity), pzero(a), pselect(pcmp_eq(a, pzero(a)), por(infinity, a), x));
+}
+
+// vsqrt(q)_f32 exists only on A64; elsewhere sqrt(a) = a * rsqrt(a), with 0 and +inf lanes passed through unchanged.
+#if EIGEN_ARCH_ARM64
 template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);}
 template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); }
+#else
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
+  const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+  const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
+  return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
+  const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+  const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity));
+  return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a)));
+}
+#endif
 
 //---------- bfloat16 ----------
 // TODO: Add support for native armv8.6-a bfloat16_t
@@ -3722,6 +3757,7 @@ template<> struct packet_traits : default_packet_traits
     HasLog = 1,
     HasExp = 1,
     HasSqrt = 1,
+    HasRsqrt = 1,
     HasTanh = 0,
     HasErf = 0
   };
@@ -3933,6 +3969,17 @@
template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Pack
 
 template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) { return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
 
+template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
+  // Per-lane 1/sqrt(a). The refinement alone yields NaN for 0 and +inf (inf * 0), so those lanes are patched last.
+  Packet2d x = vrsqrteq_f64(a);
+  // Three Newton steps for 1/sqrt(a): x *= (3 - a*x*x) / 2, where vrsqrts(p, q) computes (3 - p*q) / 2.
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
+  return pselect(pcmp_eq(a, infinity), pzero(a), pselect(pcmp_eq(a, pzero(a)), por(infinity, a), x));
+}
+
 template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); }
 
 #endif // EIGEN_ARCH_ARM64
@@ -3978,6 +4025,7 @@ struct packet_traits : default_packet_traits {
     HasLog = 0,
     HasExp = 0,
     HasSqrt = 1,
+    HasRsqrt = 1,
     HasErf = EIGEN_FAST_MATH,
     HasBessel = 0, // Issues with accuracy.
     HasNdtri = 0,
diff --git a/test/packetmath.cpp b/test/packetmath.cpp
index ae7168fc8..ceafb9002 100644
--- a/test/packetmath.cpp
+++ b/test/packetmath.cpp
@@ -504,6 +504,7 @@ void packetmath() {
     data1[i] = numext::abs(internal::random());
   }
   CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
+  CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
 }
 
 // Notice that this definition works for complex types as well.
@@ -532,7 +533,7 @@ void packetmath_real() {
 
   CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
   CHECK_CWISE1_IF(PacketTraits::HasLog, log2, internal::plog2);
-  CHECK_CWISE1_IF(PacketTraits::HasRsqrt, 1 / std::sqrt, internal::prsqrt);
+  CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
 
   for (int i = 0; i < size; ++i) {
     data1[i] = Scalar(internal::random(-1, 1) * std::pow(10., internal::random(-3, 3)));
-- 
cgit v1.2.3