aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-02-26 13:59:46 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2021-02-26 14:08:40 -0800
commit29ebd84cb779eae01302a9f1e40cf06ca5eeeceb (patch)
treeb5f4e4991a36f234b68be6d7c2b00d62d060f8e3 /Eigen/src/Core/arch/NEON/PacketMath.h
parentfe19714f8094a2b6d6dab0cdd3c32874d0ad66b9 (diff)
Fix NEON sqrt for 32-bit, add prsqrt.
With !406, we accidentally broke arm 32-bit NEON builds, since `vsqrt_f32` is only available for 64-bit. Here we add back the `rsqrt` implementation for 32-bit, relying on a `prsqrt` implementation with better handling of edge cases. Note that several of the 32-bit NEON packet tests are currently failing - either due to denormal handling (NEON versions flush to zero, but scalar paths don't) or due to accuracy (e.g. sin/cos).
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h48
1 files changed, 48 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 9715bf4b2..f77a18a4f 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -202,6 +202,7 @@ struct packet_traits<float> : default_packet_traits
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
+ HasRsqrt = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasBessel = 0, // Issues with accuracy.
@@ -3329,8 +3330,42 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) {
return res;
}
+// Coefficient-wise reciprocal square root, 1/sqrt(a), for a Packet4f.
+//
+// The vrsqrteq_f32 estimate is refined with two Newton-Raphson steps. Each step
+// multiplies `a` by the current estimate, which evaluates to 0 * inf = NaN both
+// for a == 0 (estimate is inf) and for a == +inf (estimate is 0), so those two
+// lanes are overwritten explicitly afterwards: a == 0 yields +inf and
+// a == +inf yields 0.
+template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) {
+  // Low-precision initial estimate of 1/sqrt(a).
+  Packet4f x = vrsqrteq_f32(a);
+  // Newton step: x <- x * (3 - a*x*x) / 2, where vrsqrtsq_f32(p, q) supplies
+  // the (3 - p*q) / 2 factor.
+  x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+  x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x);
+  const Packet4f zero = pzero(a);
+  const Packet4f infinity = pset1<Packet4f>(NumTraits<float>::infinity());
+  // Without this select, +inf inputs would come out of the refinement as NaN.
+  x = pselect(pcmp_eq(a, infinity), zero, x);
+  return pselect(pcmp_eq(a, zero), infinity, x);
+}
+
+// Coefficient-wise reciprocal square root, 1/sqrt(a), for a Packet2f.
+//
+// The vrsqrte_f32 estimate is refined with two Newton-Raphson steps. Each step
+// multiplies `a` by the current estimate, which evaluates to 0 * inf = NaN both
+// for a == 0 (estimate is inf) and for a == +inf (estimate is 0), so those two
+// lanes are overwritten explicitly afterwards: a == 0 yields +inf and
+// a == +inf yields 0.
+template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) {
+  // Low-precision initial estimate of 1/sqrt(a).
+  Packet2f x = vrsqrte_f32(a);
+  // Newton step: x <- x * (3 - a*x*x) / 2, where vrsqrts_f32(p, q) supplies
+  // the (3 - p*q) / 2 factor.
+  x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+  x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x);
+  const Packet2f zero = pzero(a);
+  const Packet2f infinity = pset1<Packet2f>(NumTraits<float>::infinity());
+  // Without this select, +inf inputs would come out of the refinement as NaN.
+  x = pselect(pcmp_eq(a, infinity), zero, x);
+  return pselect(pcmp_eq(a, zero), infinity, x);
+}
+
+// Unfortunately vsqrt_f32 is only available for A64.
+#if EIGEN_ARCH_ARM64
// Square root of a Packet4f, forwarded directly to the vsqrtq_f32 intrinsic.
template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x) {
  return vsqrtq_f32(_x);
}
// Square root of a Packet2f, forwarded directly to the vsqrt_f32 intrinsic.
template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x) {
  return vsqrt_f32(_x);
}
+#else
+// Square root of a Packet4f for builds where EIGEN_ARCH_ARM64 is not set:
+// evaluates a * prsqrt(a), and returns `a` itself in lanes equal to zero or
+// +infinity instead of using that product.
+template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) {
+  const Packet4f pos_inf = pset1<Packet4f>(NumTraits<float>::infinity());
+  const Packet4f zero_mask = pcmp_eq(a, pzero(a));
+  const Packet4f inf_mask = pcmp_eq(a, pos_inf);
+  const Packet4f passthrough = por(zero_mask, inf_mask);
+  const Packet4f product = pmul(a, prsqrt(a));
+  return pselect(passthrough, a, product);
+}
+// Square root of a Packet2f for builds where EIGEN_ARCH_ARM64 is not set:
+// evaluates a * prsqrt(a), and returns `a` itself in lanes equal to zero or
+// +infinity instead of using that product.
+template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) {
+  const Packet2f pos_inf = pset1<Packet2f>(NumTraits<float>::infinity());
+  const Packet2f zero_mask = pcmp_eq(a, pzero(a));
+  const Packet2f inf_mask = pcmp_eq(a, pos_inf);
+  const Packet2f passthrough = por(zero_mask, inf_mask);
+  const Packet2f product = pmul(a, prsqrt(a));
+  return pselect(passthrough, a, product);
+}
+#endif
//---------- bfloat16 ----------
// TODO: Add support for native armv8.6-a bfloat16_t
@@ -3722,6 +3757,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
+ HasRsqrt = 1,
HasTanh = 0,
HasErf = 0
};
@@ -3933,6 +3969,17 @@ template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Pack
// Builds a Packet2d by duplicating the 64-bit pattern `from` into every lane
// (vdupq_n_u64) and reinterpreting the result as f64 lanes.
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from) {
  return vreinterpretq_f64_u64(
      vdupq_n_u64(from));
}
+// Coefficient-wise reciprocal square root, 1/sqrt(a), for a Packet2d.
+//
+// The vrsqrteq_f64 estimate is refined with three Newton-Raphson steps. Each
+// step multiplies `a` by the current estimate, which evaluates to
+// 0 * inf = NaN both for a == 0 (estimate is inf) and for a == +inf (estimate
+// is 0), so those two lanes are overwritten explicitly afterwards: a == 0
+// yields +inf and a == +inf yields 0.
+template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) {
+  // Low-precision initial estimate of 1/sqrt(a).
+  Packet2d x = vrsqrteq_f64(a);
+  // Newton step: x <- x * (3 - a*x*x) / 2, where vrsqrtsq_f64(p, q) supplies
+  // the (3 - p*q) / 2 factor.
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x);
+  const Packet2d zero = pzero(a);
+  const Packet2d infinity = pset1<Packet2d>(NumTraits<double>::infinity());
+  // Without this select, +inf inputs would come out of the refinement as NaN.
+  x = pselect(pcmp_eq(a, infinity), zero, x);
+  return pselect(pcmp_eq(a, zero), infinity, x);
+}
+
// Square root of a Packet2d, forwarded directly to the vsqrtq_f64 intrinsic.
template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x) {
  return vsqrtq_f64(_x);
}
#endif // EIGEN_ARCH_ARM64
@@ -3978,6 +4025,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasLog = 0,
HasExp = 0,
HasSqrt = 1,
+ HasRsqrt = 1,
HasErf = EIGEN_FAST_MATH,
HasBessel = 0, // Issues with accuracy.
HasNdtri = 0,