path: root/Eigen/src/Core/arch/AVX512
author    Sheng Yang <yang.sheng@intel.com>            2020-07-14 01:34:03 +0000
committer Rasmus Munk Larsen <rmlarsen@google.com>     2020-07-14 01:34:03 +0000
commit    56b3e3f3f8ca9972ca390c8296fde363bdab271c (patch)
tree      5d06bf0995ed07dd232e346369e71f70561b5d9c /Eigen/src/Core/arch/AVX512
parent    4ab32e2de2511746e2108563a43cbbeb1922fbf2 (diff)
AVX path for BF16
Diffstat (limited to 'Eigen/src/Core/arch/AVX512')
-rw-r--r--  Eigen/src/Core/arch/AVX512/MathFunctions.h   47
-rw-r--r--  Eigen/src/Core/arch/AVX512/PacketMath.h      73
2 files changed, 43 insertions(+), 77 deletions(-)
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index b86afced6..83af5f5de 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -135,10 +135,7 @@ plog<Packet16f>(const Packet16f& _x) {
p16f_minus_inf);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -264,10 +261,7 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
-template <>
-EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
// prsqrt for double.
#if EIGEN_FAST_MATH
@@ -435,20 +423,14 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
#endif
@@ -461,31 +443,20 @@ psin<Packet16f>(const Packet16f& _x) {
}
template <>
-EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
-EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
} // end namespace internal
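Note: the explicit Packet16bf specializations deleted above are all of the form "convert to float, apply the Packet16f function, convert back", and the patch replaces them with the BF16_PACKET_FUNCTION macro whose definition is not part of this diff (it is shared with the AVX path). Based on the removed bodies, it presumably expands to roughly the following sketch; the conversion helpers Bf16ToF32/F32ToBf16 are the ones already used in this file:

    // Sketch only: assumed expansion of BF16_PACKET_FUNCTION, inferred from the
    // specializations this patch deletes. The real definition lives outside this diff.
    #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF, METHOD)                 \
      template <>                                                             \
      EIGEN_STRONG_INLINE PACKET_BF METHOD<PACKET_BF>(const PACKET_BF& _x) {  \
        return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x)));                    \
      }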
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index f55a50596..2b6693eed 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1633,13 +1633,13 @@ template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
// There is no half-size packet for current Packet16bf.
- // TODO: support as SSE/AVX path.
- typedef Packet16bf half;
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
@@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf>
{
typedef bfloat16 type;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
- typedef Packet16bf half;
+ typedef Packet8bf half;
};
template <>
@@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
}
-// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+#if defined(EIGEN_VECTORIZE_AVX512DQ)
__m512 flush = _mm512_and_ps(a, a);
+#else
+ __m512 flush = _mm512_max_ps(a, a);
+#endif // EIGEN_VECTORIZE_AVX512DQ
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
#if defined(EIGEN_VECTORIZE_AVX512BF16)
@@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
// output.value = static_cast<uint16_t>(input);
r.i = _mm512_cvtepi32_epi16(t);
-#endif
+#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
}
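The DQ guard added to the denormal-flush step reflects that _mm512_and_ps on 512-bit registers requires AVX512DQ, while _mm512_max_ps is available in base AVX512F; both are used only as pass-through operations so that the temporarily enabled denormals-are-zero mode flushes subnormal inputs. For reference, the round-to-nearest-even conversion that the non-AVX512BF16 branch performs with vector intrinsics corresponds, per scalar element, to roughly the sketch below (illustrative only, not code from the patch; NaN handling is omitted here):

    #include <cstdint>
    #include <cstring>

    // Scalar model of float -> bfloat16 with round-to-nearest-even (sketch).
    uint16_t f32_to_bf16_rne(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      uint32_t lsb = (bits >> 16) & 1;        // least-significant bit of the bf16 result
      uint32_t rounding_bias = 0x7FFF + lsb;  // ties round to even
      bits += rounding_bias;
      return static_cast<uint16_t>(bits >> 16);
    }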
@@ -1912,6 +1916,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
}
template <>
+EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
+ Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
+ return padd<Packet8bf>(lane0, lane1);
+}
+
+template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
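The new predux_half_dowto4 overload reduces a 16-element bfloat16 packet to an 8-element one by splitting off the two 128-bit lanes and adding them elementwise with padd. A scalar model of the intended semantics (illustrative only, not part of the patch):

    // out[i] = a[i] + a[i + 8] for i in [0, 8), in bfloat16 arithmetic.
    void predux_half_model(const Eigen::bfloat16 a[16], Eigen::bfloat16 out[8]) {
      for (int i = 0; i < 8; ++i) out[i] = a[i] + a[i + 8];
    }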
@@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
// Swap hi and lo first because shuffle is in 128-bit lanes.
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
- res.i = _mm256_shuffle_epi8(a.i, m);
+ res.i = _mm256_shuffle_epi8(res.i, m);
return res;
}
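The one-line change in preverse fixes a dropped intermediate: the lane swap produced by _mm256_permute2x128_si256 was previously discarded because the byte shuffle read from a.i again. Shuffling res.i lets the two steps compose into a full reversal, i.e. roughly (illustrative sketch of the intended result):

    // Scalar model of preverse on a 16-element bf16 packet.
    void preverse_model(const Eigen::bfloat16 in[16], Eigen::bfloat16 out[16]) {
      for (int i = 0; i < 16; ++i) out[i] = in[15 - i];
    }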
@@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
- kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
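The second half of this ptranspose fix replaces the duplicated 0x20 selectors with 0x31; previously packets 8-15 repeated packets 0-7 and the high halves of the interleaved rows were lost. In _mm256_permute2x128_si256(x, y, imm), the low nibble of imm selects the result's low 128-bit lane and the high nibble selects its high lane (0 = low of x, 1 = high of x, 2 = low of y, 3 = high of y), so:

    // imm = 0x20: result = { low128(x),  low128(y)  }  -> first  8 transposed packets
    // imm = 0x31: result = { high128(x), high128(y) }  -> second 8 transposed packets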
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {