diff options
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/NEON/PacketMath.h | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 709cebe4e..af1b48744 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3371,6 +3371,10 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p) return reinterpret_cast<Packet4f>(vshlq_n_u32(vmovl_u16(p), 16)); } +EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) { + return vmovn_u32(vreinterpretq_f32_u32(p)); +} + template<> EIGEN_STRONG_INLINE Packet4bf pset1<Packet4bf>(const bfloat16& from) { return pset1<Packet4us>(from.value); } @@ -3528,17 +3532,17 @@ template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff<Packet4bf>(const Packet4bf& a, template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { - return F32ToBf16(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); + return F32MaskToBf16Mask(pcmp_eq<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); } template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { - return F32ToBf16(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); + return F32MaskToBf16Mask(pcmp_lt<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); } template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { - return F32ToBf16(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); + return F32MaskToBf16Mask(pcmp_le<Packet4f>(Bf16ToF32(a), Bf16ToF32(b))); } template<> EIGEN_STRONG_INLINE Packet4bf pnegate<Packet4bf>(const Packet4bf& a) |