aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- Eigen/src/Core/arch/NEON/PacketMath.h | 10
1 files changed, 7 insertions, 3 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 709cebe4e..af1b48744 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3371,6 +3371,10 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p)
return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16));
}
+EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) {  // narrows a 4-lane f32 comparison mask to 4 x 16-bit lanes
+ return vmovn_u32(vreinterpretq_u32_f32(p));  // view the float lanes as u32 (f32 -> u32, not the reverse), then keep the low 16 bits of each lane
+}
+
template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) {
return pset1<Packet4us>(from.value);
}
@@ -3528,17 +3532,17 @@ template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a,
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));  // compare after widening to f32; the result is passed through the mask-narrowing helper instead of F32ToBf16
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));  // less-than evaluated on f32 packets; the result is passed through the mask-narrowing helper instead of F32ToBf16
}
template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b)
{
- return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
+ return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));  // less-or-equal evaluated on f32 packets; the result is passed through the mask-narrowing helper instead of F32ToBf16
}
template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a)