about summary refs log tree commit diff homepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
author Gravatar Antonio Sanchez <cantonios@google.com> 2021-02-25 14:29:49 -0800
committer Gravatar Antonio Sánchez <cantonios@google.com> 2021-02-27 22:42:07 +0000
commit 1e0c7d4f4933b12a325dbaa2c79ce946bb13f7d6 (patch)
tree 30d28ba3618296434df793ad25d06f68e6c98d65 /Eigen/src/Core/arch/NEON/PacketMath.h
parent 976ae0ca6f381a855daddcba73de72737be2e8a7 (diff)
Add print for SSE/NEON, use NEON rounding intrinsics if available.
In SSE, by adding/subtracting 2^MantissaBits, we force rounding according to the current rounding mode. For NEON, we use the provided intrinsics for rint/floor/ceil if available (armv8). Related to #1969.
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r-- Eigen/src/Core/arch/NEON/PacketMath.h 217
1 file changed, 100 insertions, 117 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index f77a18a4f..ec6ea90c5 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -196,6 +196,7 @@ struct packet_traits<float> : default_packet_traits
HasDiv = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -3182,63 +3183,88 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
{ return vbslq_u64(mask, a, b); }
+// Use armv8 rounding intrinsics if available.
+#if EIGEN_ARCH_ARMV8
+template<> EIGEN_STRONG_INLINE Packet2f print<Packet2f>(const Packet2f& a)
+{ return vrndn_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a)
+{ return vrndnq_f32(a); }
template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
-{
- const Packet2f cst_1 = pset1<Packet2f>(1.0f);
- // Round to nearest.
- Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a));
- // If greater, subtract one.
- Packet2ui mask = vcgt_f32(tmp, a);
- mask = vand_u32(mask, vreinterpret_u32_f32(cst_1));
- tmp = vsub_f32(tmp, vreinterpret_f32_u32(mask));
- // Handle saturation cases.
- const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+{ return vrndm_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
+{ return vrndmq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{ return vrndp_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
+{ return vrndpq_f32(a); }
+
+#else
+
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet4f offset =
+ pselect(pcmp_lt(a, pzero(a)),
+ pset1<Packet4f>(-static_cast<float>(1<<23)),
+ pset1<Packet4f>(+static_cast<float>(1<<23)));
+ return psub(padd(a, offset), offset);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a) {
+ // Adds and subtracts signum(a) * 2^23 to force rounding.
+ const Packet2f offset =
+ pselect(pcmp_lt(a, pzero(a)),
+ pset1<Packet2f>(-static_cast<float>(1<<23)),
+ pset1<Packet2f>(+static_cast<float>(1<<23)));
+ return psub(padd(a, offset), offset);
}
template<> EIGEN_STRONG_INLINE Packet4f pfloor<Packet4f>(const Packet4f& a)
{
const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- // Round to nearest.
- Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a));
+ Packet4f tmp = print<Packet4f>(a);
// If greater, subtract one.
- Packet4ui mask = vcgtq_f32(tmp, a);
- mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1));
- tmp = vsubq_f32(tmp, vreinterpretq_f32_u32(mask));
- // Handle saturation cases.
- const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+ Packet4f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
}
-template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+template<> EIGEN_STRONG_INLINE Packet2f pfloor<Packet2f>(const Packet2f& a)
{
const Packet2f cst_1 = pset1<Packet2f>(1.0f);
- // Round to nearest.
- Packet2f tmp = vcvt_f32_s32(vcvt_s32_f32(a));
- // If smaller, add one.
- Packet2ui mask = vclt_f32(tmp, a);
- mask = vand_u32(mask, vreinterpret_u32_f32(cst_1));
- tmp = vadd_f32(tmp, vreinterpret_f32_u32(mask));
- // Handle saturation cases.
- const Packet2f cst_max = pset1<Packet2f>(static_cast<float>(NumTraits<int32_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+ Packet2f tmp = print<Packet2f>(a);
+ // If greater, subtract one.
+ Packet2f mask = pcmp_lt(a, tmp);
+ mask = pand(mask, cst_1);
+ return psub(tmp, mask);
}
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a)
{
const Packet4f cst_1 = pset1<Packet4f>(1.0f);
- // Round to nearest.
- Packet4f tmp = vcvtq_f32_s32(vcvtq_s32_f32(a));
+ Packet4f tmp = print<Packet4f>(a);
+ // If smaller, add one.
+ Packet4f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil<Packet2f>(const Packet2f& a)
+{
+ const Packet2f cst_1 = pset1<Packet2f>(1.0);
+ Packet2f tmp = print<Packet2f>(a);
// If smaller, add one.
- Packet4ui mask = vcltq_f32(tmp, a);
- mask = vandq_u32(mask, vreinterpretq_u32_f32(cst_1));
- tmp = vaddq_f32(tmp, vreinterpretq_f32_u32(mask));
- // Handle saturation cases.
- const Packet4f cst_max = pset1<Packet4f>(static_cast<float>(NumTraits<int32_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
+ Packet2f mask = pcmp_lt(tmp, a);
+ mask = pand(mask, cst_1);
+ return padd(tmp, mask);
}
+#endif
+
/**
* Computes the integer square root
* @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
@@ -3404,6 +3430,7 @@ template<> struct packet_traits<bfloat16> : default_packet_traits
HasDiv = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
@@ -3565,6 +3592,11 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4bf pselect(const Packet4
return pselect<Packet4us>(mask, a, b);
}
+template<> EIGEN_STRONG_INLINE Packet4bf print<Packet4bf>(const Packet4bf& a)
+{
+ return F32ToBf16(print<Packet4f>(Bf16ToF32(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet4bf pfloor<Packet4bf>(const Packet4bf& a)
{
return F32ToBf16(pfloor<Packet4f>(Bf16ToF32(a)));
@@ -3751,6 +3783,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasDiv = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasSin = 0,
HasCos = 0,
@@ -3932,33 +3965,14 @@ ptranspose(PacketBlock<Packet2d, 2>& kernel)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b)
{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a)
+{ return vrndnq_f64(a); }
+
template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
-{
- const Packet2d cst_1 = pset1<Packet2d>(1.0);
- // Round to nearest.
- Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
- // If greater, substract 1.
- uint64x2_t mask = vcgtq_f64(tmp, a);
- mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
- tmp = vsubq_f64(tmp, vreinterpretq_f64_u64(mask));
- // Handle saturation cases.
- const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
+{ return vrndmq_f64(a); }
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a)
-{
- const Packet2d cst_1 = pset1<Packet2d>(1.0);
- // Round to nearest.
- Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
- // If smaller, add one.
- uint64x2_t mask = vcltq_f64(tmp, a);
- mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
- tmp = vaddq_f64(tmp, vreinterpretq_f64_u64(mask));
- // Handle saturation cases.
- const Packet2d cst_max = pset1<Packet2d>(static_cast<double>(NumTraits<int64_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
+{ return vrndpq_f64(a); }
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
{ return pldexp_generic(a, exponent); }
@@ -4020,6 +4034,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasDiv = 1,
HasFloor = 1,
HasCeil = 1,
+ HasRint = 1,
HasSin = 0,
HasCos = 0,
HasLog = 0,
@@ -4231,6 +4246,30 @@ EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, cons
}
template <>
+EIGEN_STRONG_INLINE Packet8hf print<Packet8hf>(const Packet8hf& a)
+{ return vrndnq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf print<Packet4hf>(const Packet4hf& a)
+{ return vrndn_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a)
+{ return vrndmq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a)
+{ return vrndm_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a)
+{ return vrndpq_f16(a); }
+
+template <>
+EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a)
+{ return vrndp_f16(a); }
+
+template <>
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
return vsqrtq_f16(a);
}
@@ -4465,62 +4504,6 @@ EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
}
template <>
-EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
- const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
- // If greater, substract one.
- uint16x8_t mask = vcgtq_f16(tmp, a);
- mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
- tmp = vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
- // Handle saturation cases.
- const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
- const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
- // If greater, substract one.
- uint16x4_t mask = vcgt_f16(tmp, a);
- mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
- tmp = vsub_f16(tmp, vreinterpret_f16_u16(mask));
- // Handle saturation cases.
- const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet8hf pceil<Packet8hf>(const Packet8hf& a) {
- const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
- // If smaller, add one.
- uint16x8_t mask = vcltq_f16(tmp, a);
- mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
- tmp = vaddq_f16(tmp, vreinterpretq_f16_u16(mask));
- // Handle saturation cases.
- const Packet8hf cst_max = pset1<Packet8hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
-EIGEN_STRONG_INLINE Packet4hf pceil<Packet4hf>(const Packet4hf& a) {
- const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
- // Round to nearest.
- Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
- // If smaller, add one.
- uint16x4_t mask = vclt_f16(tmp, a);
- mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
- tmp = vadd_f16(tmp, vreinterpret_f16_u16(mask));
- // Handle saturation cases.
- const Packet4hf cst_max = pset1<Packet4hf>(static_cast<Eigen::half>(NumTraits<int16_t>::highest()));
- return pselect(pcmp_lt(pabs(a), cst_max), tmp, a);
-}
-
-template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
float16x4_t a_lo, a_hi, sum;