diff options
author | Sheng Yang <yang.sheng@intel.com> | 2020-07-14 01:34:03 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-07-14 01:34:03 +0000 |
commit | 56b3e3f3f8ca9972ca390c8296fde363bdab271c (patch) | |
tree | 5d06bf0995ed07dd232e346369e71f70561b5d9c /Eigen/src/Core/arch/AVX512/PacketMath.h | |
parent | 4ab32e2de2511746e2108563a43cbbeb1922fbf2 (diff) |
AVX path for BF16
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/PacketMath.h | 73 |
1 files changed, 34 insertions, 39 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index f55a50596..2b6693eed 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -1633,13 +1633,13 @@ template <> struct packet_traits<bfloat16> : default_packet_traits { typedef Packet16bf type; // There is no half-size packet for current Packet16bf. - // TODO: support as SSE/AVX path. - typedef Packet16bf half; + // TODO: support as SSE path. + typedef Packet8bf half; enum { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, HasBlend = 0, HasInsert = 1, HasSin = EIGEN_FAST_MATH, @@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf> { typedef bfloat16 type; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; - typedef Packet16bf half; + typedef Packet8bf half; }; template <> @@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16)); } -// Convert float to bfloat16 according to round-to-even/denormals alogrithm. +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { Packet16bf r; // Flush input denormals value to zero with hardware capability. _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#if defined(EIGEN_VECTORIZE_AVX512DQ) __m512 flush = _mm512_and_ps(a, a); +#else + __m512 flush = _mm512_max_ps(a, a); +#endif // EIGEN_VECTORIZE_AVX512DQ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); #if defined(EIGEN_VECTORIZE_AVX512BF16) @@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { // output.value = static_cast<uint16_t>(input); r.i = _mm512_cvtepi32_epi16(t); -#endif +#endif // EIGEN_VECTORIZE_AVX512BF16 return r; } @@ -1912,6 +1916,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, } template <> +EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) { + Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1); + return padd<Packet8bf>(lane0, lane1); +} + +template <> EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) { return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p))); } @@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { // Swap hi and lo first because shuffle is in 128-bit lanes. res.i = _mm256_permute2x128_si256(a.i, a.i, 1); // Shuffle 8-bit values in src within 2*128-bit lanes. - res.i = _mm256_shuffle_epi8(a.i, m); + res.i = _mm256_shuffle_epi8(res.i, m); return res; } @@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) { __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); // NOTE: no unpacklo/hi instr in this case, so using permute instr. - kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); - kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); + kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); } EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) { |