aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-01 16:28:48 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-12-02 01:29:34 +0000
commit2627e2f2e6125cf09fa32789755135e84552275b (patch)
tree98aac84d6faf17d02688a2a4478c4c582932a812 /Eigen/src/Core/arch/NEON/PacketMath.h
parentddd48b242cbf3aa79ad13668b66089cece6d1ea0 (diff)
Fix neon cmp* functions for bf16.
The current impl corrupts the comparison masks when converting from float back to bfloat16. The resulting masks are then no longer all zeros or all ones, which breaks when used with `pselect` (e.g. in `pmin<PropagateNumbers>`). This was causing `packetmath_15` to fail on arm. Introducing a simple `F32MaskToBf16Mask` corrects this (takes the lower 16-bits for each float mask).
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h10
1 files changed, 7 insertions, 3 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 709cebe4e..af1b48744 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3371,6 +3371,10 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
}
+EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) {
+ return vmovn_u32(vreinterpretq_f32_u32(p));
+}
+
template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
return pset1<Packet4us>(from.value);
}
@@ -3528,17 +3532,17 @@ template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a,
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)