aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-02-25 20:39:56 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2021-02-25 20:39:56 -0800
commite19829c3b0802c01c942fe9d095688f8ce2dcc7b (patch)
tree35b2298f0800a573a156b20cbfcb480943583e24 /Eigen/src/Core/arch/NEON/PacketMath.h
parent5529db7524b93208f3d97f5fadc53aff1de70190 (diff)
Fix floor/ceil for NEON fp16.
Forgot to test this. Fixes bug introduced in !416.
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h112
1 files changed, 56 insertions, 56 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 8a2a14f4d..9715bf4b2 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -4183,62 +4183,6 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, cons
}
template <>
-EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
- const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
- // If greater, substract one.
- uint16x8_t mask = vcgtq_f16(tmp, a);
- mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
- tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
- // Handle saturation cases.
- EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
- const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
- // If greater, substract one.
- uint16x4_t mask = vcgt_f16(tmp, a);
- mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
- tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
- // Handle saturation cases.
- EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
- const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
- // If smaller, add one.
- uint16x8_t mask = vcltq_f16(tmp, a);
- mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
- tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
- // Handle saturation cases.
- EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
- const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
- // If smaller, add one.
- uint16x4_t mask = vclt_f16(tmp, a);
- mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
- tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
- // Handle saturation cases.
- EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
return vsqrtq_f16(a);
}
@@ -4473,6 +4417,62 @@ EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
}
template <>
+EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
+ const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
+ // If greater, substract one.
+ uint16x8_t mask = vcgtq_f16(tmp, a);
+ mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
+ tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
+ // Handle saturation cases.
+ const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
+ const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
+ // If greater, substract one.
+ uint16x4_t mask = vcgt_f16(tmp, a);
+ mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
+ tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
+ // Handle saturation cases.
+ const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
+ const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
+ // If smaller, add one.
+ uint16x8_t mask = vcltq_f16(tmp, a);
+ mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
+ tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
+ // Handle saturation cases.
+ const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
+ const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
+ // If smaller, add one.
+ uint16x4_t mask = vclt_f16(tmp, a);
+ mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
+ tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
+ // Handle saturation cases.
+ const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
float16x4_t a_lo, a_hi, sum;