From e19829c3b0802c01c942fe9d095688f8ce2dcc7b Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 25 Feb 2021 20:39:56 -0800 Subject: Fix floor/ceil for NEON fp16. Forgot to test this. Fixes bug introduced in !416. --- Eigen/src/Core/arch/NEON/PacketMath.h | 112 +++++++++++++++++----------------- 1 file changed, 56 insertions(+), 56 deletions(-) (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h') diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 8a2a14f4d..9715bf4b2 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -4182,62 +4182,6 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan(const Packet4hf& a, cons return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b))); } -template <> -EIGEN_STRONG_INLINE Packet8hf pfloor(const Packet8hf& a) { - const Packet8hf cst_1 = pset1(Eigen::half(1.0f)); - // Round to nearest. - Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); - // If greater, substract one. - uint16x8_t mask = vcgtq_f16(tmp, a); - mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); - tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask)); - // Handle saturation cases. - EIGEN_CONSTEXPR Packet8hf cst_max = pset1(static_cast(NumTraits::highest())); - return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); -} - -template <> -EIGEN_STRONG_INLINE Packet4hf pfloor(const Packet4hf& a) { - const Packet4hf cst_1 = pset1(Eigen::half(1.0f)); - // Round to nearest. - Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); - // If greater, substract one. - uint16x4_t mask = vcgt_f16(tmp, a); - mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); - tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask)); - // Handle saturation cases. - EIGEN_CONSTEXPR Packet4hf cst_max = pset1(static_cast(NumTraits::highest())); - return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); -} - -template <> -EIGEN_STRONG_INLINE Packet8hf pceil(const Packet8hf& a) { - const Packet8hf cst_1 = pset1(Eigen::half(1.0f)); - // Round to nearest. - Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); - // If smaller, add one. - uint16x8_t mask = vcltq_f16(tmp, a); - mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); - tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask)); - // Handle saturation cases. - EIGEN_CONSTEXPR Packet8hf cst_max = pset1(static_cast(NumTraits::highest())); - return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); -} - -template <> -EIGEN_STRONG_INLINE Packet4hf pceil(const Packet4hf& a) { - const Packet4hf cst_1 = pset1(Eigen::half(1.0f)); - // Round to nearest. - Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); - // If smaller, add one. - uint16x4_t mask = vclt_f16(tmp, a); - mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); - tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask)); - // Handle saturation cases. - EIGEN_CONSTEXPR Packet4hf cst_max = pset1(static_cast(NumTraits::highest())); - return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); -} - template <> EIGEN_STRONG_INLINE Packet8hf psqrt(const Packet8hf& a) { return vsqrtq_f16(a); @@ -4472,6 +4416,62 @@ EIGEN_STRONG_INLINE Packet4hf pabs(const Packet4hf& a) { return vabs_f16(a); } +template <> +EIGEN_STRONG_INLINE Packet8hf pfloor(const Packet8hf& a) { + const Packet8hf cst_1 = pset1(Eigen::half(1.0f)); + // Round to nearest. + Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); + // If greater, substract one. + uint16x8_t mask = vcgtq_f16(tmp, a); + mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); + tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask)); + // Handle saturation cases. + const Packet8hf cst_max = pset1(static_cast(NumTraits::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pfloor(const Packet4hf& a) { + const Packet4hf cst_1 = pset1(Eigen::half(1.0f)); + // Round to nearest. + Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); + // If greater, substract one. + uint16x4_t mask = vcgt_f16(tmp, a); + mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); + tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask)); + // Handle saturation cases. + const Packet4hf cst_max = pset1(static_cast(NumTraits::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pceil(const Packet8hf& a) { + const Packet8hf cst_1 = pset1(Eigen::half(1.0f)); + // Round to nearest. + Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); + // If smaller, add one. + uint16x8_t mask = vcltq_f16(tmp, a); + mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); + tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask)); + // Handle saturation cases. + const Packet8hf cst_max = pset1(static_cast(NumTraits::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pceil(const Packet4hf& a) { + const Packet4hf cst_1 = pset1(Eigen::half(1.0f)); + // Round to nearest. + Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); + // If smaller, add one. + uint16x4_t mask = vclt_f16(tmp, a); + mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); + tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask)); + // Handle saturation cases. + const Packet4hf cst_max = pset1(static_cast(NumTraits::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + template <> EIGEN_STRONG_INLINE Eigen::half predux(const Packet8hf& a) { float16x4_t a_lo, a_hi, sum; -- cgit v1.2.3