From 56b3e3f3f8ca9972ca390c8296fde363bdab271c Mon Sep 17 00:00:00 2001 From: Sheng Yang Date: Tue, 14 Jul 2020 01:34:03 +0000 Subject: AVX path for BF16 --- CMakeLists.txt | 6 + Eigen/src/Core/arch/AVX/MathFunctions.h | 33 ++- Eigen/src/Core/arch/AVX/PacketMath.h | 345 +++++++++++++++++++++++++++++ Eigen/src/Core/arch/AVX/TypeCasting.h | 26 +++ Eigen/src/Core/arch/AVX512/MathFunctions.h | 47 +--- Eigen/src/Core/arch/AVX512/PacketMath.h | 73 +++--- Eigen/src/Core/arch/Default/BFloat16.h | 7 + cmake/EigenTesting.cmake | 12 + 8 files changed, 460 insertions(+), 89 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 28103856e..867d068d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,6 +239,12 @@ if(NOT MSVC) message(STATUS "Enabling FMA in tests/examples") endif() + option(EIGEN_TEST_AVX2 "Enable/Disable AVX2 in tests/examples" OFF) + if(EIGEN_TEST_AVX2) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma") + message(STATUS "Enabling AVX2 in tests/examples") + endif() + option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF) if(EIGEN_TEST_AVX512) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma") diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index c5394430f..461696170 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -58,15 +58,15 @@ pexp(const Packet8f& _x) { // Hyperbolic Tangent function. template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f -ptanh(const Packet8f& x) { - return internal::generic_fast_tanh_float(x); +ptanh(const Packet8f& _x) { + return internal::generic_fast_tanh_float(_x); } // Exponential function for doubles. template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d -pexp(const Packet4d& x) { - return pexp_double(x); +pexp(const Packet4d& _x) { + return pexp_double(_x); } // Functions for sqrt. @@ -96,13 +96,13 @@ psqrt(const Packet8f& _x) { } #else template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED -Packet8f psqrt(const Packet8f& x) { - return _mm256_sqrt_ps(x); +Packet8f psqrt(const Packet8f& _x) { + return _mm256_sqrt_ps(_x); } #endif template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED -Packet4d psqrt(const Packet4d& x) { - return _mm256_sqrt_pd(x); +Packet4d psqrt(const Packet4d& _x) { + return _mm256_sqrt_pd(_x); } #if EIGEN_FAST_MATH @@ -140,18 +140,27 @@ Packet8f prsqrt(const Packet8f& _x) { #else template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED -Packet8f prsqrt(const Packet8f& x) { +Packet8f prsqrt(const Packet8f& _x) { _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f); - return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x)); + return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x)); } #endif template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED -Packet4d prsqrt(const Packet4d& x) { +Packet4d prsqrt(const Packet4d& _x) { _EIGEN_DECLARE_CONST_Packet4d(one, 1.0); - return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x)); + return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x)); } +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index b17c34015..cf7146cbc 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -32,11 +32,13 @@ typedef __m256 Packet8f; typedef __m256i Packet8i; typedef __m256d Packet4d; typedef eigen_packet_wrapper<__m128i, 2> Packet8h; +typedef eigen_packet_wrapper<__m128i, 3> Packet8bf; template<> struct is_arithmetic<__m256> { enum { value = true }; }; template<> struct is_arithmetic<__m256i> { enum { value = true }; }; template<> struct is_arithmetic<__m256d> { enum { value = true }; }; template<> struct is_arithmetic { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; #define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \ const Packet8f p8f_##NAME = pset1(X) @@ -134,6 +136,40 @@ struct packet_traits : default_packet_traits { HasBlend = 0 }; }; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8bf type; + // There is no half-size packet for current Packet8bf. + // TODO: support as SSE path. + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasCmp = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasNdtri = 1, + HasBessel = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; +}; #endif template<> struct scalar_div_cost { enum { value = 14 }; }; @@ -165,6 +201,14 @@ template<> struct unpacket_traits { enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; template<> struct unpacket_traits { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; }; +template<> struct unpacket_traits { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; + +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 16x16 to 8x16. +EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) { + return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), + _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); +} template<> EIGEN_STRONG_INLINE Packet8f pset1(const float& from) { return _mm256_set1_ps(from); } template<> EIGEN_STRONG_INLINE Packet4d pset1(const double& from) { return _mm256_set1_pd(from); } @@ -1032,6 +1076,307 @@ ptranspose(PacketBlock& kernel) { kernel.packet[3] = pload(out[3]); } +EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + __m256i extend = _mm256_cvtepu16_epi32(a); + return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16)); +#else + __m128i lo = _mm_cvtepu16_epi32(a); + __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8)); + __m128i lo_shift = _mm_slli_epi32(lo, 16); + __m128i hi_shift = _mm_slli_epi32(hi, 16); + return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1)); +#endif +} + +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { + Packet8bf r; + + // Flush input denormals value to zero with hardware capability. + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); + __m256 flush = _mm256_and_ps(a, a); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); + + __m256i input = _mm256_castps_si256(flush); + +#ifdef EIGEN_VECTORIZE_AVX2 + // uint32_t lsb = (input >> 16); + __m256i t = _mm256_srli_epi32(input, 16); + // uint32_t lsb = lsb & 1; + t = _mm256_and_si256(t, _mm256_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff)); + // input += rounding_bias; + t = _mm256_add_epi32(t, input); + // input = input >> 16; + t = _mm256_srli_epi32(t, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m256i nan = _mm256_set1_epi32(0x7fc0); + t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); + // output.value = static_cast(input); + return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), + _mm256_extractf128_si256(t, 1)); +#else + // uint32_t lsb = (input >> 16); + __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16); + __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16); + // uint32_t lsb = lsb & 1; + lo = _mm_and_si128(lo, _mm_set1_epi32(1)); + hi = _mm_and_si128(hi, _mm_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff)); + hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff)); + // input += rounding_bias; + lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0)); + hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1)); + // input = input >> 16; + lo = _mm_srli_epi32(lo, 16); + hi = _mm_srli_epi32(hi, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m128i nan = _mm_set1_epi32(0x7fc0); + lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); + hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); + // output.value = static_cast(input); + return _mm_packus_epi32(lo, hi); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8bf pset1(const bfloat16& from) { + return _mm_set1_epi16(from.value); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& from) { + return bfloat16_impl::raw_uint16_to_bfloat16(static_cast(_mm_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pload(const bfloat16* from) { + return _mm_load_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploadu(const bfloat16* from) { + return _mm_loadu_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet8bf& from) { + _mm_store_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet8bf& from) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploaddup(const bfloat16* from) { + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + return _mm_set_epi16(d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploadquad(const bfloat16* from) { + unsigned short a = from[0].value; + unsigned short b = from[1].value; + return _mm_set_epi16(b, b, b, b, a, a, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) { + return _mm_cmpeq_epi32(a, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + return F32ToBf16(pabs(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmin(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmax(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) { + return _mm_or_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) { + return _mm_xor_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) { + return _mm_and_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) { + return _mm_andnot_si128(b,a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) { + return _mm_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pround(const Packet8bf& a) +{ + return F32ToBf16(pround(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf print(const Packet8bf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pceil(const Packet8bf& a) { + return F32ToBf16(pceil(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pfloor(const Packet8bf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { + Packet8bf sign_mask = _mm_set1_epi16(static_cast(0x8000)); + return _mm_xor_si128(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf padd(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + + +template<> EIGEN_STRONG_INLINE Packet8bf pgather(const bfloat16* from, Index stride) +{ + return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); +} + +template<> EIGEN_STRONG_INLINE void pscatter(bfloat16* to, const Packet8bf& from, Index stride) +{ + EIGEN_ALIGN32 bfloat16 aux[8]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) { + return static_cast(predux(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) { + return static_cast(predux_max(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) { + return static_cast(predux_min(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) { + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm_shuffle_epi8(a,m); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + __m128i e = kernel.packet[4]; + __m128i f = kernel.packet[5]; + __m128i g = kernel.packet[6]; + __m128i h = kernel.packet[7]; + + __m128i a03b03 = _mm_unpacklo_epi16(a, b); + __m128i c03d03 = _mm_unpacklo_epi16(c, d); + __m128i e03f03 = _mm_unpacklo_epi16(e, f); + __m128i g03h03 = _mm_unpacklo_epi16(g, h); + __m128i a47b47 = _mm_unpackhi_epi16(a, b); + __m128i c47d47 = _mm_unpackhi_epi16(c, d); + __m128i e47f47 = _mm_unpackhi_epi16(e, f); + __m128i g47h47 = _mm_unpackhi_epi16(g, h); + + __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); + __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); + __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); + __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); + __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); + __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); + __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); + __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); + + kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); + kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); + kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); + kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + + __m128i ab_03 = _mm_unpacklo_epi16(a, b); + __m128i cd_03 = _mm_unpacklo_epi16(c, d); + __m128i ab_47 = _mm_unpackhi_epi16(a, b); + __m128i cd_47 = _mm_unpackhi_epi16(c, d); + + kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03); + kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03); + kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47); + kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h index 181043588..c669a7f60 100644 --- a/Eigen/src/Core/arch/AVX/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX/TypeCasting.h @@ -76,12 +76,38 @@ struct type_casting_traits { }; }; +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8bf& a) { + return Bf16ToF32(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + #endif // EIGEN_VECTORIZE_AVX512 template<> EIGEN_STRONG_INLINE Packet8h pcast(const Packet8f& a) { return float2half(a); } +template<> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8f& a) { + return F32ToBf16(a); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index b86afced6..83af5f5de 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -135,10 +135,7 @@ plog(const Packet16f& _x) { p16f_minus_inf); } -template <> -EIGEN_STRONG_INLINE Packet16bf plog(const Packet16bf& _x) { - return F32ToBf16(plog(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog) #endif // Exponential function. Works by writing "x = m*log(2) + r" where @@ -264,10 +261,7 @@ pexp(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ -template <> -EIGEN_STRONG_INLINE Packet16bf pexp(const Packet16bf& _x) { - return F32ToBf16(pexp(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) // Functions for sqrt. // The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step @@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt(const Packet8d& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf psqrt(const Packet16bf& x) { - return F32ToBf16(psqrt(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) // prsqrt for float. #if defined(EIGEN_VECTORIZE_AVX512ER) @@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt(const Packet16f& x) { } #endif -template <> -EIGEN_STRONG_INLINE Packet16bf prsqrt(const Packet16bf& x) { - return F32ToBf16(prsqrt(Bf16ToF32(x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt) // prsqrt for double. #if EIGEN_FAST_MATH @@ -435,20 +423,14 @@ Packet16f plog1p(const Packet16f& _x) { return generic_plog1p(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf plog1p(const Packet16bf& _x) { - return F32ToBf16(plog1p(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pexpm1(const Packet16f& _x) { return generic_expm1(_x); } -template<> -EIGEN_STRONG_INLINE Packet16bf pexpm1(const Packet16bf& _x) { - return F32ToBf16(pexpm1(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1) #endif #endif @@ -460,32 +442,21 @@ psin(const Packet16f& _x) { return psin_float(_x); } -template <> -EIGEN_STRONG_INLINE Packet16bf psin(const Packet16bf& _x) { - return F32ToBf16(psin(Bf16ToF32(_x))); -} - template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f pcos(const Packet16f& _x) { return pcos_float(_x); } -template <> -EIGEN_STRONG_INLINE Packet16bf pcos(const Packet16bf& _x) { - return F32ToBf16(pcos(Bf16ToF32(_x))); -} - template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f ptanh(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } -template <> -EIGEN_STRONG_INLINE Packet16bf ptanh(const Packet16bf& _x) { - return F32ToBf16(ptanh(Bf16ToF32(_x))); -} +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index f55a50596..2b6693eed 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -1633,13 +1633,13 @@ template <> struct packet_traits : default_packet_traits { typedef Packet16bf type; // There is no half-size packet for current Packet16bf. - // TODO: support as SSE/AVX path. - typedef Packet16bf half; + // TODO: support as SSE path. + typedef Packet8bf half; enum { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, HasBlend = 0, HasInsert = 1, HasSin = EIGEN_FAST_MATH, @@ -1668,7 +1668,7 @@ struct unpacket_traits { typedef bfloat16 type; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; - typedef Packet16bf half; + typedef Packet8bf half; }; template <> @@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16)); } -// Convert float to bfloat16 according to round-to-even/denormals alogrithm. +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { Packet16bf r; // Flush input denormals value to zero with hardware capability. _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#if defined(EIGEN_VECTORIZE_AVX512DQ) __m512 flush = _mm512_and_ps(a, a); +#else + __m512 flush = _mm512_max_ps(a, a); +#endif // EIGEN_VECTORIZE_AVX512DQ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); #if defined(EIGEN_VECTORIZE_AVX512BF16) @@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { // output.value = static_cast(input); r.i = _mm512_cvtepi32_epi16(t); -#endif +#endif // EIGEN_VECTORIZE_AVX512BF16 return r; } @@ -1911,6 +1915,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax(const Packet16bf& a, return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); } +template <> +EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4(const Packet16bf& a) { + Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1); + return padd(lane0, lane1); +} + template <> EIGEN_STRONG_INLINE bfloat16 predux(const Packet16bf& p) { return static_cast(predux(Bf16ToF32(p))); @@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { // Swap hi and lo first because shuffle is in 128-bit lanes. res.i = _mm256_permute2x128_si256(a.i, a.i, 1); // Shuffle 8-bit values in src within 2*128-bit lanes. - res.i = _mm256_shuffle_epi8(a.i, m); + res.i = _mm256_shuffle_epi8(res.i, m); return res; } @@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); // NOTE: no unpacklo/hi instr in this case, so using permute instr. - kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); - kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, - 0x20); - kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, - 0x20); - kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, - 0x20); - kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, - 0x20); - kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, - 0x20); - kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, - 0x20); - kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, - 0x20); - kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, - 0x20); + kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); } EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index abf2ac933..34a4f0ced 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -23,6 +23,13 @@ limitations under the License. #define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type() #endif +#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \ + template <> \ + EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \ + PACKET_BF16 METHOD(const PACKET_BF16& _x) { \ + return F32ToBf16(METHOD(Bf16ToF32(_x))); \ + } + namespace Eigen { struct bfloat16; diff --git a/cmake/EigenTesting.cmake b/cmake/EigenTesting.cmake index 524223717..c98393e5e 100644 --- a/cmake/EigenTesting.cmake +++ b/cmake/EigenTesting.cmake @@ -310,6 +310,12 @@ macro(ei_testing_print_summary) message(STATUS "AVX: Using architecture defaults") endif() + if(EIGEN_TEST_AVX2) + message(STATUS "AVX2: ON") + else() + message(STATUS "AVX2: Using architecture defaults") + endif() + if(EIGEN_TEST_FMA) message(STATUS "FMA: ON") else() @@ -322,6 +328,12 @@ macro(ei_testing_print_summary) message(STATUS "AVX512: Using architecture defaults") endif() + if(EIGEN_TEST_AVX512DQ) + message(STATUS "AVX512DQ: ON") + else() + message(STATUS "AVX512DQ: Using architecture defaults") + endif() + if(EIGEN_TEST_ALTIVEC) message(STATUS "Altivec: ON") else() -- cgit v1.2.3