aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2021-02-25 14:29:49 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2021-02-25 14:39:26 -0800
commit5529db7524b93208f3d97f5fadc53aff1de70190 (patch)
tree776d264bc8af0004bcd5eb6468ddb5c2bb4ea299 /Eigen/src/Core/arch/NEON/PacketMath.h
parentecb7b19dfa6c4bbf7a4068e114a1c86aa88908fe (diff)
Fix SSE/NEON pfloor/pceil for saturated values.
The original will saturate if the input does not fit into an integer type. Here we fix this, returning the input if it doesn't have enough precision to have a fractional part. Also added `pceil` for NEON. Fixes #1969.
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h187
1 files changed, 136 insertions, 51 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index b2170e9f7..8a2a14f4d 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -194,7 +194,8 @@ struct packet_traits<float> : default_packet_traits
HasBlend = 0,
HasDiv = 1,
- HasFloor = 0,
+ HasFloor = 1,
+ HasCeil = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -1462,32 +1463,6 @@ template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt_or_nan<Packet2f>(const Packet2f&
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan<Packet4f>(const Packet4f& a, const Packet4f& b)
{ return vreinterpretq_f32_u32(vmvnq_u32(vcgeq_f32(a,b))); }
-// WARNING: this pfloor implementation makes sense for inputs that fit in
-// signed int32 integers (up to ~2.14e9), hence this is currently only used
-// by pexp and not exposed through HasFloor.
-template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
-{
- const Packet2f cst_1 = pset1<Packet2f>(1.0f);
- /* perform a floorf */
- Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a));
-
- /* if greater, substract 1 */
- Packet2ui mask = vcgt_f32(tmp, a);
- mask = vand_u32(mask, vreinterpret_u32_f32(cst_1));
- return vsub_f32(tmp, vreinterpret_f32_u32(mask));
-}
-template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
-{
- const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- /* perform a floorf */
- Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a));
-
- /* if greater, substract 1 */
- Packet4ui mask = vcgtq_f32(tmp, a);
- mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1));
- return vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
-}
-
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
template<> EIGEN_STRONG_INLINE Packet2f pand<Packet2f>(const Packet2f& a, const Packet2f& b)
{ return vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); }
@@ -3206,6 +3181,63 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
{ return vbslq_u64(mask, a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0f);
+ // Round to nearest.
+ Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a));
+ // If greater, subtract one.
+ Packet2ui mask = vcgt_f32(tmp, a);
+ mask = vand_u32(mask, vreinterpret_u32_f32(cst_1));
+ tmp = vsub_f32(tmp, vreinterpret_f32_u32(mask));
+ // Handle saturation cases.
+ const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ // Round to nearest.
+ Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a));
+ // If greater, subtract one.
+ Packet4ui mask = vcgtq_f32(tmp, a);
+ mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1));
+ tmp = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
+ // Handle saturation cases.
+ const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0f);
+ // Round to nearest.
+ Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a));
+ // If smaller, add one.
+ Packet2ui mask = vclt_f32(tmp, a);
+ mask = vand_u32(mask, vreinterpret_u32_f32(cst_1));
+ tmp = vadd_f32(tmp, vreinterpret_f32_u32(mask));
+ // Handle saturation cases.
+ const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{
+ const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+ // Round to nearest.
+ Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a));
+ // If smaller, add one.
+ Packet4ui mask = vcltq_f32(tmp, a);
+ mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1));
+ tmp = vaddq_f32(tmp, vreinterpretq_f32_u32(mask));
+ // Handle saturation cases.
+ const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
/**
* Computes the integer square root
* @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
@@ -3336,6 +3368,7 @@ template<> struct packet_traits<bfloat16> : default_packet_traits
HasBlend = 0,
HasDiv = 1,
HasFloor = 1,
+ HasCeil = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -3502,6 +3535,11 @@ template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
}
+template<> EIGEN_STRONG_INLINE Packet4bf pceil<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(pceil<Packet4f>(Bf16ToF32(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet4bf padd<Packet4bf>(const Packet4bf& a, const Packet4bf& b) {
@@ -3676,7 +3714,8 @@ template<> struct packet_traits<double> : default_packet_traits
HasBlend = 0,
HasDiv = 1,
- HasFloor = 0,
+ HasFloor = 1,
+ HasCeil = 1,
HasSin = 0,
HasCos = 0,
@@ -3754,21 +3793,6 @@ template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const
template<> EIGEN_STRONG_INLINE Packet2d pmax<PropagateNaN, Packet2d>(const Packet2d& a, const Packet2d& b) { return pmax<Packet2d>(a, b); }
-// WARNING: this pfloor implementation makes sense for inputs that fit in
-// signed int64 integers (up to ~9.22e18), hence this is currently only used
-// by pexp and not exposed through HasFloor.
-template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
-{
- const Packet2d cst_1 = pset1<Packet2d>(1.0);
- /* perform a floorf */
- const Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
-
- /* if greater, substract 1 */
- uint64x2_t mask = vcgtq_f64(tmp, a);
- mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
- return vsubq_f64(tmp, vreinterpretq_f64_u64(mask));
-}
-
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b)
{ return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
@@ -3872,6 +3896,34 @@ ptranspose(PacketBlock<Packet2d, 2>& kernel)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b)
{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ // Round to nearest.
+ Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
+ // If greater, substract 1.
+ uint64x2_t mask = vcgtq_f64(tmp, a);
+ mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
+ tmp = vsubq_f64(tmp, vreinterpretq_f64_u64(mask));
+ // Handle saturation cases.
+ const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ // Round to nearest.
+ Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
+ // If smaller, add one.
+ uint64x2_t mask = vcltq_f64(tmp, a);
+ mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
+ tmp = vaddq_f64(tmp, vreinterpretq_f64_u64(mask));
+ // Handle saturation cases.
+ const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
{ return pldexp_generic(a, exponent); }
@@ -3920,6 +3972,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasReduxp = 1,
HasDiv = 1,
HasFloor = 1,
+ HasCeil = 1,
HasSin = 0,
HasCos = 0,
HasLog = 0,
@@ -4132,25 +4185,57 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, cons
template <>
EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
- /* perform a floorf */
+ // Round to nearest.
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
-
- /* if greater, substract 1 */
+ // If greater, substract one.
uint16x8_t mask = vcgtq_f16(tmp, a);
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
- return vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
+ tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
+ // Handle saturation cases.
+ EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
}
template <>
EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
- /* perform a floorf */
+ // Round to nearest.
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
-
- /* if greater, substract 1 */
+ // If greater, substract one.
uint16x4_t mask = vcgt_f16(tmp, a);
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
- return vsub_f16(tmp, vreinterpret_f16_u16(mask));
+ tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
+ // Handle saturation cases.
+ EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
+ const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
+ // If smaller, add one.
+ uint16x8_t mask = vcltq_f16(tmp, a);
+ mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
+ tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
+ // Handle saturation cases.
+ EIGEN_CONSTEXPR Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
+ const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
+ // Round to nearest.
+ Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
+ // If smaller, add one.
+ uint16x4_t mask = vclt_f16(tmp, a);
+ mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
+ tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
+ // Handle saturation cases.
+ EIGEN_CONSTEXPR Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
+ return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
}
template <>