aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-09 20:05:49 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-09 20:05:49 +0000
commit4e4d3f32d168ed9ce09d950f099a60ddcd11240f (patch)
tree3e52ae5b43c238679f69f3caf4d908d4afb16f13 /Eigen/src/Core
parent7a8d3d5b81cb528f7f084b63686ffb20494053f6 (diff)
Clean up packetmath tests and fix various bugs to make bfloat16 pass (almost) all packetmath tests with SSE, AVX, and AVX512.
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h5
-rw-r--r--Eigen/src/Core/arch/AVX/TypeCasting.h48
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h7
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h27
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h5
5 files changed, 61 insertions, 31 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index cf7146cbc..d5dc6a174 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -1205,6 +1205,11 @@ EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
}
+template <>
+EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) {
+ return F32ToBf16(plset<Packet8f>(static_cast<float>(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
return _mm_or_si128(a,b);
}
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index c669a7f60..d507fb67b 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -35,23 +35,6 @@ struct type_casting_traits<int, float> {
};
-
-template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
- return _mm256_cvttps_epi32(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i& a) {
- return _mm256_cvtepi32_ps(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
- return _mm256_castps_si256(a);
-}
-
-template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Packet8i& a) {
- return _mm256_castsi256_ps(a);
-}
-
#ifndef EIGEN_VECTORIZE_AVX512
template <>
@@ -63,9 +46,6 @@ struct type_casting_traits<Eigen::half, float> {
};
};
-template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
- return half2float(a);
-}
template <>
struct type_casting_traits<float, Eigen::half> {
@@ -85,10 +65,6 @@ struct type_casting_traits<bfloat16, float> {
};
};
-template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
- return Bf16ToF32(a);
-}
-
template <>
struct type_casting_traits<float, bfloat16> {
enum {
@@ -100,6 +76,30 @@ struct type_casting_traits<float, bfloat16> {
#endif // EIGEN_VECTORIZE_AVX512
+template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
+ return _mm256_cvttps_epi32(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i& a) {
+ return _mm256_cvtepi32_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
+ return _mm256_castps_si256(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f,Packet8i>(const Packet8i& a) {
+ return _mm256_castsi256_ps(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
+ return half2float(a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
+
template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 76f3366d7..8b946b3e1 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1626,8 +1626,6 @@ template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
- // There is no half-size packet for current Packet16bf.
- // TODO: support as SSE path.
typedef Packet8bf half;
enum {
Vectorizable = 1,
@@ -1884,6 +1882,11 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
}
template <>
+EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
+ return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
+}
+
+template <>
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index 7c147ae34..4d5fa1bf8 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -103,8 +103,8 @@ struct numeric_limits<Eigen::bfloat16> {
static const bool has_infinity = true;
static const bool has_quiet_NaN = true;
static const bool has_signaling_NaN = true;
- static const float_denorm_style has_denorm = numeric_limits<float>::has_denorm;
- static const bool has_denorm_loss = numeric_limits<float>::has_denorm_loss;
+ static const float_denorm_style has_denorm = std::denorm_absent;
+ static const bool has_denorm_loss = false;
static const std::float_round_style round_style = numeric_limits<float>::round_style;
static const bool is_iec559 = false;
static const bool is_bounded = true;
@@ -551,18 +551,24 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
}
#if EIGEN_HAS_CXX11_MATH
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
- return bfloat16(::asinh(float(a)));
+ EIGEN_USING_STD(asinhf);
+ return bfloat16(asinhf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
- return bfloat16(::acosh(float(a)));
+ EIGEN_USING_STD(acoshf);
+ return bfloat16(acoshf(float(a)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
- return bfloat16(::atanh(float(a)));
+ EIGEN_USING_STD(atanhf);
+ return bfloat16(atanhf(float(a)));
}
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
return bfloat16(::floorf(float(a)));
}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
+ return bfloat16(::rintf(float(a)));
+}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
return bfloat16(::ceilf(float(a)));
}
@@ -581,6 +587,17 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bf
return f1 < f2 ? b : a;
}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fminf(f1, f2));
+}
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
+ const float f1 = static_cast<float>(a);
+ const float f2 = static_cast<float>(b);
+ return bfloat16(::fmaxf(f1, f2));
+}
+
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
os << static_cast<float>(v);
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 1246abeaa..6dbae8cee 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -3373,6 +3373,11 @@ template <> EIGEN_STRONG_INLINE Packet4bf pmax<Packet4bf>(const Packet4bf &a,
return F32ToBf16(pmax<Packet4f>(Bf16ToF32(a), Bf16ToF32(b)));
}
+template<> EIGEN_STRONG_INLINE Packet4bf plset<Packet4bf>(const bfloat16& a)
+{
+ return F32ToBf16(plset<Packet4f>(static_cast<float>(a)));
+}
+
template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) {
return por<Packet4us>(a, b);
}