diff options
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/NEON/PacketMath.h | 187 |
1 files changed, 136 insertions, 51 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index b2170e9f7..8a2a14f4d 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -194,7 +194,8 @@ struct packet_traits<float> : default_packet_traits HasBlend = 0, HasDiv = 1, - HasFloor = 0, + HasFloor = 1, + HasCeil = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, @@ -1462,32 +1463,6 @@ template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt_or_nan<Packet2f>(const Packet2f& template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan<Packet4f>(const Packet4f& a, const Packet4f& b) { return vreinterpretq_f32_u32(vmvnq_u32(vcgeq_f32(a,b))); } -// WARNING: this pfloor implementation makes sense for inputs that fit in -// signed int32 integers (up to ~2.14e9), hence this is currently only used -// by pexp and not exposed through HasFloor. -template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a) -{ - const Packet2f cst_1 = pset1<Packet2f>(1.0f); - /* perform a floorf */ - Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a)); - - /* if greater, substract 1 */ - Packet2ui mask = vcgt_f32(tmp, a); - mask = vand_u32(mask, vreinterpret_u32_f32(cst_1)); - return vsub_f32(tmp, vreinterpret_f32_u32(mask)); -} -template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) -{ - const Packet4f cst_1 = pset1<Packet4f>(1.0f); - /* perform a floorf */ - Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a)); - - /* if greater, substract 1 */ - Packet4ui mask = vcgtq_f32(tmp, a); - mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1)); - return vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); -} - // Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics template<> EIGEN_STRONG_INLINE Packet2f pand<Packet2f>(const Packet2f& a, const Packet2f& b) { return vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } @@ -3206,6 +3181,63 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b) { return vbslq_u64(mask, a, b); } + +template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a) +{ + const Packet2f cst_1 = pset1<Packet2f>(1.0f); + // Round to nearest. + Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a)); + // If greater, subtract one. + Packet2ui mask = vcgt_f32(tmp, a); + mask = vand_u32(mask, vreinterpret_u32_f32(cst_1)); + tmp = vsub_f32(tmp, vreinterpret_f32_u32(mask)); + // Handle saturation cases. + const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a) +{ + const Packet4f cst_1 = pset1<Packet4f>(1.0f); + // Round to nearest. + Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a)); + // If greater, subtract one. + Packet4ui mask = vcgtq_f32(tmp, a); + mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1)); + tmp = vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); + // Handle saturation cases. + const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a) +{ + const Packet2f cst_1 = pset1<Packet2f>(1.0f); + // Round to nearest. + Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a)); + // If smaller, add one. + Packet2ui mask = vclt_f32(tmp, a); + mask = vand_u32(mask, vreinterpret_u32_f32(cst_1)); + tmp = vadd_f32(tmp, vreinterpret_f32_u32(mask)); + // Handle saturation cases. + const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) +{ + const Packet4f cst_1 = pset1<Packet4f>(1.0f); + // Round to nearest. + Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a)); + // If smaller, add one. + Packet4ui mask = vcltq_f32(tmp, a); + mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1)); + tmp = vaddq_f32(tmp, vreinterpretq_f32_u32(mask)); + // Handle saturation cases. + const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + /** * Computes the integer square root * @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result @@ -3336,6 +3368,7 @@ template<> struct packet_traits<bfloat16> : default_packet_traits HasBlend = 0, HasDiv = 1, HasFloor = 1, + HasCeil = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, @@ -3502,6 +3535,11 @@ template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a) return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a))); } +template<> EIGEN_STRONG_INLINE Packet4bf pceil<Packet4bf>(const Packet4bf& a) +{ + return F32ToBf16(pceil<Packet4f>(Bf16ToF32(a))); +} + template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; } template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) { @@ -3676,7 +3714,8 @@ template<> struct packet_traits<double> : default_packet_traits HasBlend = 0, HasDiv = 1, - HasFloor = 0, + HasFloor = 1, + HasCeil = 1, HasSin = 0, HasCos = 0, @@ -3754,21 +3793,6 @@ template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmax<Packet2d>(a, b); } -// WARNING: this pfloor implementation makes sense for inputs that fit in -// signed int64 integers (up to ~9.22e18), hence this is currently only used -// by pexp and not exposed through HasFloor. -template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) -{ - const Packet2d cst_1 = pset1<Packet2d>(1.0); - /* perform a floorf */ - const Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a)); - - /* if greater, substract 1 */ - uint64x2_t mask = vcgtq_f64(tmp, a); - mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1)); - return vsubq_f64(tmp, vreinterpretq_f64_u64(mask)); -} - // Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); } @@ -3872,6 +3896,34 @@ ptranspose(PacketBlock<Packet2d, 2>& kernel) template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a) +{ + const Packet2d cst_1 = pset1<Packet2d>(1.0); + // Round to nearest. + Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a)); + // If greater, substract 1. + uint64x2_t mask = vcgtq_f64(tmp, a); + mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1)); + tmp = vsubq_f64(tmp, vreinterpretq_f64_u64(mask)); + // Handle saturation cases. + const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) +{ + const Packet2d cst_1 = pset1<Packet2d>(1.0); + // Round to nearest. + Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a)); + // If smaller, add one. + uint64x2_t mask = vcltq_f64(tmp, a); + mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1)); + tmp = vaddq_f64(tmp, vreinterpretq_f64_u64(mask)); + // Handle saturation cases. + const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) { return pldexp_generic(a, exponent); } @@ -3920,6 +3972,7 @@ struct packet_traits<Eigen::half> : default_packet_traits { HasReduxp = 1, HasDiv = 1, HasFloor = 1, + HasCeil = 1, HasSin = 0, HasCos = 0, HasLog = 0, @@ -4132,25 +4185,57 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, cons template <> EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) { const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f)); - /* perform a floorf */ + // Round to nearest. Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); - - /* if greater, substract 1 */ + // If greater, substract one. uint16x8_t mask = vcgtq_f16(tmp, a); mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); - return vsubq_f16(tmp, vreinterpretq_f16_u16(mask)); + tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask)); + // Handle saturation cases. + EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); } template <> EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) { const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f)); - /* perform a floorf */ + // Round to nearest. Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); - - /* if greater, substract 1 */ + // If greater, substract one. uint16x4_t mask = vcgt_f16(tmp, a); mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); - return vsub_f16(tmp, vreinterpret_f16_u16(mask)); + tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask)); + // Handle saturation cases. + EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) { + const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f)); + // Round to nearest. + Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a)); + // If smaller, add one. + uint16x8_t mask = vcltq_f16(tmp, a); + mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1)); + tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask)); + // Handle saturation cases. + EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) { + const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f)); + // Round to nearest. + Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a)); + // If smaller, add one. + uint16x4_t mask = vclt_f16(tmp, a); + mask = vand_u16(mask, vreinterpret_u16_f16(cst_1)); + tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask)); + // Handle saturation cases. + EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest())); + return pselect(pcmp_lt(pabs(a), cst_max), tmp, a); } template <> |