diff options
Diffstat (limited to 'Eigen/src/Core/arch/AVX/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX/PacketMath.h | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index b17c34015..cf7146cbc 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -32,11 +32,13 @@ typedef __m256 Packet8f; typedef __m256i Packet8i; typedef __m256d Packet4d; typedef eigen_packet_wrapper<__m128i, 2> Packet8h; +typedef eigen_packet_wrapper<__m128i, 3> Packet8bf; template<> struct is_arithmetic<__m256> { enum { value = true }; }; template<> struct is_arithmetic<__m256i> { enum { value = true }; }; template<> struct is_arithmetic<__m256d> { enum { value = true }; }; template<> struct is_arithmetic<Packet8h> { enum { value = true }; }; +template<> struct is_arithmetic<Packet8bf> { enum { value = true }; }; #define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \ const Packet8f p8f_##NAME = pset1<Packet8f>(X) @@ -134,6 +136,40 @@ struct packet_traits<Eigen::half> : default_packet_traits { HasBlend = 0 }; }; + +template <> +struct packet_traits<bfloat16> : default_packet_traits { + typedef Packet8bf type; + // There is no half-size packet for current Packet8bf. + // TODO: support as SSE path. + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasCmp = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasNdtri = 1, + HasBessel = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; +}; #endif template<> struct scalar_div_cost<float,true> { enum { value = 14 }; }; @@ -165,6 +201,14 @@ template<> struct unpacket_traits<Packet4d> { enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; }; +template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; + +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 16x16 to 8x16. +EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) { + return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), + _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); +} template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); } template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); } @@ -1032,6 +1076,307 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) { kernel.packet[3] = pload<Packet8h>(out[3]); } +EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + __m256i extend = _mm256_cvtepu16_epi32(a); + return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16)); +#else + __m128i lo = _mm_cvtepu16_epi32(a); + __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8)); + __m128i lo_shift = _mm_slli_epi32(lo, 16); + __m128i hi_shift = _mm_slli_epi32(hi, 16); + return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1)); +#endif +} + +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { + Packet8bf r; + + // Flush input denormals value to zero with hardware capability. + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); + __m256 flush = _mm256_and_ps(a, a); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); + + __m256i input = _mm256_castps_si256(flush); + +#ifdef EIGEN_VECTORIZE_AVX2 + // uint32_t lsb = (input >> 16); + __m256i t = _mm256_srli_epi32(input, 16); + // uint32_t lsb = lsb & 1; + t = _mm256_and_si256(t, _mm256_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff)); + // input += rounding_bias; + t = _mm256_add_epi32(t, input); + // input = input >> 16; + t = _mm256_srli_epi32(t, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m256i nan = _mm256_set1_epi32(0x7fc0); + t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); + // output.value = static_cast<uint16_t>(input); + return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), + _mm256_extractf128_si256(t, 1)); +#else + // uint32_t lsb = (input >> 16); + __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16); + __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16); + // uint32_t lsb = lsb & 1; + lo = _mm_and_si128(lo, _mm_set1_epi32(1)); + hi = _mm_and_si128(hi, _mm_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff)); + hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff)); + // input += rounding_bias; + lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0)); + hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1)); + // input = input >> 16; + lo = _mm_srli_epi32(lo, 16); + hi = _mm_srli_epi32(hi, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m128i nan = _mm_set1_epi32(0x7fc0); + lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); + hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); + // output.value = static_cast<uint16_t>(input); + return _mm_packus_epi32(lo, hi); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) { + return _mm_set1_epi16(from.value); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) { + return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<unsigned short>(_mm_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) { + return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) { + return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) { + _mm_store_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploaddup<Packet8bf>(const bfloat16* from) { + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + return _mm_set_epi16(d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploadquad<Packet8bf>(const bfloat16* from) { + unsigned short a = from[0].value; + unsigned short b = from[1].value; + return _mm_set_epi16(b, b, b, b, a, a, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) { + return _mm_cmpeq_epi32(a, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) { + return _mm_or_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) { + return _mm_xor_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) { + return _mm_and_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) { + return _mm_andnot_si128(b,a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) { + return _mm_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a) +{ + return F32ToBf16(pround<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) { + return F32ToBf16(print<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) { + return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) { + return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { + Packet8bf sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000)); + return _mm_xor_si128(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); +} + + +template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride) +{ + return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); +} + +template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride) +{ + EIGEN_ALIGN32 bfloat16 aux[8]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) { + return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) { + return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) { + return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) { + return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm_shuffle_epi8(a,m); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock<Packet8bf,8>& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + __m128i e = kernel.packet[4]; + __m128i f = kernel.packet[5]; + __m128i g = kernel.packet[6]; + __m128i h = kernel.packet[7]; + + __m128i a03b03 = _mm_unpacklo_epi16(a, b); + __m128i c03d03 = _mm_unpacklo_epi16(c, d); + __m128i e03f03 = _mm_unpacklo_epi16(e, f); + __m128i g03h03 = _mm_unpacklo_epi16(g, h); + __m128i a47b47 = _mm_unpackhi_epi16(a, b); + __m128i c47d47 = _mm_unpackhi_epi16(c, d); + __m128i e47f47 = _mm_unpackhi_epi16(e, f); + __m128i g47h47 = _mm_unpackhi_epi16(g, h); + + __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); + __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); + __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); + __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); + __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); + __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); + __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); + __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); + + kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); + kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); + kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); + kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock<Packet8bf,4>& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + + __m128i ab_03 = _mm_unpacklo_epi16(a, b); + __m128i cd_03 = _mm_unpacklo_epi16(c, d); + __m128i ab_47 = _mm_unpackhi_epi16(a, b); + __m128i cd_47 = _mm_unpackhi_epi16(c, d); + + kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03); + kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03); + kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47); + kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47); +} + } // end namespace internal } // end namespace Eigen |