diff options
author | Sheng Yang <yang.sheng@intel.com> | 2020-07-14 01:34:03 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-07-14 01:34:03 +0000 |
commit | 56b3e3f3f8ca9972ca390c8296fde363bdab271c (patch) | |
tree | 5d06bf0995ed07dd232e346369e71f70561b5d9c /Eigen/src/Core/arch/AVX512 | |
parent | 4ab32e2de2511746e2108563a43cbbeb1922fbf2 (diff) |
AVX path for BF16
Diffstat (limited to 'Eigen/src/Core/arch/AVX512')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 47 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX512/PacketMath.h | 73 |
2 files changed, 43 insertions, 77 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index b86afced6..83af5f5de 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -135,10 +135,7 @@ plog<Packet16f>(const Packet16f& _x) { p16f_minus_inf); } -template <> -EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog) #endif // Exponential function. Works by writing "x = m*log(2) + r" where @@ -264,10 +261,7 @@ pexp<Packet8d>(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ -template <> -EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step @@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) { - return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) // prsqrt for float. #if defined(EIGEN_VECTORIZE_AVX512ER) @@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) { - return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt) // prsqrt for double. #if EIGEN_FAST_MATH @@ -435,20 +423,14 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) { return generic_plog1p(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pexpm1<Packet16f>(const Packet16f& _x) { return generic_expm1(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1) #endif #endif @@ -461,31 +443,20 @@ psin<Packet16f>(const Packet16f& _x) { } template <> -EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x))); -} - -template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pcos<Packet16f>(const Packet16f& _x) { return pcos_float(_x); } template <> -EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x))); -} - -template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f ptanh<Packet16f>(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } -template <> -EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) { - return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index f55a50596..2b6693eed 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -1633,13 +1633,13 @@ template <> struct packet_traits<bfloat16> : default_packet_traits { typedef Packet16bf type; // There is no half-size packet for current Packet16bf. - // TODO: support as SSE/AVX path. - typedef Packet16bf half; + // TODO: support as SSE path. + typedef Packet8bf half; enum { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, HasBlend = 0, HasInsert = 1, HasSin = EIGEN_FAST_MATH, @@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf> { typedef bfloat16 type; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; - typedef Packet16bf half; + typedef Packet8bf half; }; template <> @@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16)); } -// Convert float to bfloat16 according to round-to-even/denormals alogrithm. +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { Packet16bf r; // Flush input denormals value to zero with hardware capability. _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#if defined(EIGEN_VECTORIZE_AVX512DQ) __m512 flush = _mm512_and_ps(a, a); +#else + __m512 flush = _mm512_max_ps(a, a); +#endif // EIGEN_VECTORIZE_AVX512DQ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); #if defined(EIGEN_VECTORIZE_AVX512BF16) @@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { // output.value = static_cast<uint16_t>(input); r.i = _mm512_cvtepi32_epi16(t); -#endif +#endif // EIGEN_VECTORIZE_AVX512BF16 return r; } @@ -1912,6 +1916,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, } template <> +EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) { + Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1); + return padd<Packet8bf>(lane0, lane1); +} + +template <> EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) { return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p))); } @@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { // Swap hi and lo first because shuffle is in 128-bit lanes. res.i = _mm256_permute2x128_si256(a.i, a.i, 1); // Shuffle 8-bit values in src within 2*128-bit lanes. - res.i = _mm256_shuffle_epi8(a.i, m); + res.i = _mm256_shuffle_epi8(res.i, m); return res; } @@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) { __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); // NOTE: no unpacklo/hi instr in this case, so using permute instr. - kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); - kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); + kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); } EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) { |