aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512
diff options
context:
space:
mode:
authorGravatar Teng Lu <teng.lu@intel.com>2020-06-20 19:16:24 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-06-20 19:16:24 +0000
commit386d809bde475c65b7940f290efe80e6a05878c4 (patch)
treec38e161a53393d15be0ddb02a7a4e22dec738484 /Eigen/src/Core/arch/AVX512
parent6b9c92fe7eff0dedb031cec38004c9c3667f3057 (diff)
Support BFloat16 in Eigen
Diffstat (limited to 'Eigen/src/Core/arch/AVX512')
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h56
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h516
-rw-r--r--Eigen/src/Core/arch/AVX512/TypeCasting.h26
3 files changed, 585 insertions, 13 deletions
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index 67043d01b..b86afced6 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -29,6 +29,12 @@ namespace internal {
#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \
const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X))
+// Declares a local constant named p16bf_<NAME> holding X broadcast to every
+// lane of a Packet16bf (via pset1<Packet16bf>).
+#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \
+ const Packet16bf p16bf_##NAME = pset1<Packet16bf>(X)
+
+// Declares a local constant named p16bf_<NAME> built by broadcasting the
+// integer X into a Packet16i and reinterpreting the result as a Packet16bf.
+// NOTE(review): Packet16i and Packet16bf appear to have different widths
+// (16 x 32-bit vs 16 x 16-bit); confirm a matching preinterpret overload
+// exists before using this macro.
+#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \
+ const Packet16bf p16bf_##NAME = preinterpret<Packet16bf,Packet16i>(pset1<Packet16i>(X))
+
// Natural logarithm
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
@@ -128,6 +134,11 @@ plog<Packet16f>(const Packet16f& _x) {
p16f_nan),
p16f_minus_inf);
}
+
+// Natural logarithm for bfloat16 packets, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f log_f32 = plog<Packet16f>(widened);
+  return F32ToBf16(log_f32);
+}
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -253,6 +264,10 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
+// Exponential for bfloat16 packets, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f exp_f32 = pexp<Packet16f>(widened);
+  return F32ToBf16(exp_f32);
+}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -303,12 +318,18 @@ template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
return _mm512_sqrt_ps(x);
}
+
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
return _mm512_sqrt_pd(x);
}
#endif
+// Square root for bfloat16 packets, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
+  const Packet16f widened = Bf16ToF32(x);
+  const Packet16f sqrt_f32 = psqrt<Packet16f>(widened);
+  return F32ToBf16(sqrt_f32);
+}
+
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@@ -316,7 +337,6 @@ template <>
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return _mm512_rsqrt28_ps(x);
}
-
#elif EIGEN_FAST_MATH
template <>
@@ -347,8 +367,7 @@ prsqrt<Packet16f>(const Packet16f& _x) {
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
- }
-
+}
#else
template <>
@@ -356,9 +375,13 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
_EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
}
-
#endif
+// Reciprocal square root for bfloat16 packets, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
+  const Packet16f widened = Bf16ToF32(x);
+  const Packet16f rsqrt_f32 = prsqrt<Packet16f>(widened);
+  return F32ToBf16(rsqrt_f32);
+}
+
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
@@ -412,10 +435,20 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
+// log(1 + x) for bfloat16 packets, evaluated in float precision.
+template<>
+EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f log1p_f32 = plog1p<Packet16f>(widened);
+  return F32ToBf16(log1p_f32);
+}
+
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
+
+// exp(x) - 1 for bfloat16 packets, evaluated in float precision.
+template<>
+EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f expm1_f32 = pexpm1<Packet16f>(widened);
+  return F32ToBf16(expm1_f32);
+}
#endif
#endif
@@ -428,17 +461,32 @@ psin<Packet16f>(const Packet16f& _x) {
}
template <>
+EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
+  // Sine for bfloat16 packets: widen, run the Packet16f kernel, narrow back.
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f sin_f32 = psin<Packet16f>(widened);
+  return F32ToBf16(sin_f32);
+}
+
+template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
+EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
+  // Cosine for bfloat16 packets: widen, run the Packet16f kernel, narrow back.
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f cos_f32 = pcos<Packet16f>(widened);
+  return F32ToBf16(cos_f32);
+}
+
+template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
+// Hyperbolic tangent for bfloat16 packets, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
+  const Packet16f widened = Bf16ToF32(_x);
+  const Packet16f tanh_f32 = ptanh<Packet16f>(widened);
+  return F32ToBf16(tanh_f32);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index ad37ad620..ed15e0889 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -362,6 +362,25 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
}
#endif
+// Narrows sixteen 32-bit comparison flags to sixteen 16-bit flags, keeping
+// the lane order. Used by the low-precision (16-bit) comparison wrappers.
+//
+// The work is done on 128-bit pieces with signed-saturating packs so that the
+// element order of the result matches the input:
+//   dst[15:0]    := Saturate16(rf[31:0])
+//   dst[31:16]   := Saturate16(rf[63:32])
+//   ...
+//   dst[255:240] := Saturate16(rf[511:480])
+EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
+  const __m256i lower_half = _mm256_castps_si256(extract256<0>(rf));
+  const __m256i upper_half = _mm256_castps_si256(extract256<1>(rf));
+
+  const __m128i lower_q0 = _mm256_extractf128_si256(lower_half, 0);
+  const __m128i lower_q1 = _mm256_extractf128_si256(lower_half, 1);
+  const __m128i upper_q0 = _mm256_extractf128_si256(upper_half, 0);
+  const __m128i upper_q1 = _mm256_extractf128_si256(upper_half, 1);
+
+  const __m128i packed_lower = _mm_packs_epi32(lower_q0, lower_q1);
+  const __m128i packed_upper = _mm_packs_epi32(upper_q0, upper_q1);
+
+  const __m256i widened = _mm256_castsi128_si256(packed_lower);
+  return _mm256_insertf128_si256(widened, packed_upper, 1);
+}
+
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
@@ -1342,15 +1361,7 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa
template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
- Packet16f rf = pcmp_eq(af, bf);
- // Pack the 32-bit flags into 16-bits flags.
- __m256i lo = _mm256_castps_si256(extract256<0>(rf));
- __m256i hi = _mm256_castps_si256(extract256<1>(rf));
- __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
- _mm256_extractf128_si256(lo, 1));
- __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
- _mm256_extractf128_si256(hi, 1));
- return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
+ return Pack32To16(pcmp_eq(af, bf));
}
template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
@@ -1607,6 +1618,493 @@ ptranspose(PacketBlock<Packet16h,4>& kernel) {
kernel.packet[3] = pload<Packet16h>(out[3]);
}
+// Packet of 16 bfloat16 values stored in one 256-bit register.
+// The integer view `i` is always present; the native `bh` view is only
+// declared when EIGEN_VECTORIZE_AVX512BF16 is defined.
+typedef union {
+#ifdef EIGEN_VECTORIZE_AVX512BF16
+ __m256bh bh;
+#endif
+ Packet8i i; // __m256i;
+} Packet16bf;
+
+// Marks Packet16bf as an arithmetic type for Eigen's type traits.
+template <>
+struct is_arithmetic<Packet16bf> {
+  enum { value = true };
+};
+
+// Packet traits for bfloat16 on AVX512: selects Packet16bf (16 lanes) as the
+// packet type and advertises which vectorized operations are enabled.
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet16bf type;
+ // There is no half-size packet for current Packet16bf.
+ // TODO: support as SSE/AVX path.
+ typedef Packet16bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 16,
+ HasHalfPacket = 0,
+ HasBlend = 0,
+ HasInsert = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ // The flags below are only set for GCC >= 5.3 or non-GCC compilers.
+#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
+ // These additionally require AVX512DQ.
+ // NOTE(review): HasNdtri/HasBessel (and HasErf below) are advertised, but
+ // no Packet16bf specializations for them are visible in this patch --
+ // confirm they are provided elsewhere.
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+#endif
+ HasExp = 1,
+ HasSqrt = EIGEN_FAST_MATH,
+ HasRsqrt = EIGEN_FAST_MATH,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+#endif
+ HasDiv = 1
+ };
+};
+
+// Reverse mapping from Packet16bf to its scalar type and layout properties.
+template <>
+struct unpacket_traits<Packet16bf> {
+  typedef bfloat16 type;
+  // No half-size packet is declared, so the packet is its own "half".
+  typedef Packet16bf half;
+  enum {
+    size = 16,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false,
+    masked_store_available = false
+  };
+};
+
+// Broadcasts the 16-bit pattern of `from` into all 16 lanes.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
+  Packet16bf broadcast;
+  broadcast.i = _mm256_set1_epi16(from.value);
+  return broadcast;
+}
+
+// Returns lane 0 of the packet as a scalar bfloat16.
+template <>
+EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
+  const int lane0_bits = _mm256_extract_epi16(from.i, 0);
+  bfloat16 first;
+  first.value = static_cast<unsigned short>(lane0_bits);
+  return first;
+}
+
+// Aligned load of 16 consecutive bfloat16 values starting at `from`.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  Packet16bf loaded;
+  loaded.i = _mm256_load_si256(src);
+  return loaded;
+}
+
+// Unaligned load of 16 consecutive bfloat16 values starting at `from`.
+template <>
+EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
+  const __m256i* src = reinterpret_cast<const __m256i*>(from);
+  Packet16bf loaded;
+  loaded.i = _mm256_loadu_si256(src);
+  return loaded;
+}
+
+// Aligned store of all 16 lanes of `from` to `to`.
+template <>
+EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
+                                          const Packet16bf& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  _mm256_store_si256(dst, from.i);
+}
+
+// Unaligned store of all 16 lanes of `from` to `to`.
+template <>
+EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
+                                           const Packet16bf& from) {
+  __m256i* dst = reinterpret_cast<__m256i*>(to);
+  _mm256_storeu_si256(dst, from.i);
+}
+
+// Loads the first eight scalars from `from` and duplicates each one:
+// {f0, f0, f1, f1, ..., f7, f7}.
+template<> EIGEN_STRONG_INLINE Packet16bf
+ploaddup<Packet16bf>(const bfloat16* from) {
+  const unsigned short s0 = from[0].value;
+  const unsigned short s1 = from[1].value;
+  const unsigned short s2 = from[2].value;
+  const unsigned short s3 = from[3].value;
+  const unsigned short s4 = from[4].value;
+  const unsigned short s5 = from[5].value;
+  const unsigned short s6 = from[6].value;
+  const unsigned short s7 = from[7].value;
+
+  // _mm256_set_epi16 lists its arguments from the highest lane to the lowest.
+  Packet16bf duplicated;
+  duplicated.i = _mm256_set_epi16(s7, s7, s6, s6, s5, s5, s4, s4,
+                                  s3, s3, s2, s2, s1, s1, s0, s0);
+  return duplicated;
+}
+
+// Loads the first four scalars from `from` and repeats each one four times:
+// {f0, f0, f0, f0, f1, f1, f1, f1, f2, ..., f3}.
+template<> EIGEN_STRONG_INLINE Packet16bf
+ploadquad(const bfloat16* from) {
+  const unsigned short s0 = from[0].value;
+  const unsigned short s1 = from[1].value;
+  const unsigned short s2 = from[2].value;
+  const unsigned short s3 = from[3].value;
+
+  // _mm256_set_epi16 lists its arguments from the highest lane to the lowest.
+  Packet16bf quadrupled;
+  quadrupled.i = _mm256_set_epi16(s3, s3, s3, s3, s2, s2, s2, s2,
+                                  s1, s1, s1, s1, s0, s0, s0, s0);
+  return quadrupled;
+}
+
+// Widens 16 bfloat16 lanes to float: each 16-bit pattern becomes the upper
+// half of a 32-bit float, and the lower 16 bits are zero.
+EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
+  const __m512i zero_extended = _mm512_cvtepu16_epi32(a.i);
+  const __m512i as_high_half = _mm512_slli_epi32(zero_extended, 16);
+  return _mm512_castsi512_ps(as_high_half);
+}
+
+// Convert float to bfloat16 according to the round-to-even/denormals algorithm.
+//
+// Each of the 16 float lanes of `a` is narrowed to bfloat16. With
+// EIGEN_VECTORIZE_AVX512BF16 the native conversion instruction is used;
+// otherwise the upper 16 bits are kept after adding a round-to-nearest-even
+// bias, and NaN inputs are replaced by the quiet NaN pattern 0x7fc0.
+EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
+  Packet16bf r;
+
+  // Flush input denormals value to zero with hardware capability.
+  // The caller's denormals-are-zero setting is saved and restored afterwards;
+  // unconditionally switching it OFF would silently disable DAZ for callers
+  // that had enabled it.
+  // NOTE(review): _mm512_and_ps is a bitwise operation -- confirm that it
+  // actually honors the DAZ mode and flushes denormal inputs.
+  const unsigned int saved_daz_mode = _MM_GET_DENORMALS_ZERO_MODE();
+  _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+  __m512 flush = _mm512_and_ps(a, a);
+  _MM_SET_DENORMALS_ZERO_MODE(saved_daz_mode);
+
+#if defined(EIGEN_VECTORIZE_AVX512BF16)
+  r.bh = _mm512_cvtneps_pbh(flush);
+#else
+  __m512i t;
+  __m512i input = _mm512_castps_si512(flush);
+  __m512i nan = _mm512_set1_epi32(0x7fc0);
+
+  // uint32_t lsb = (input >> 16) & 1;
+  t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
+  // uint32_t rounding_bias = 0x7fff + lsb;
+  t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
+  // input += rounding_bias;
+  t = _mm512_add_epi32(t, input);
+  // input = input >> 16;
+  t = _mm512_srli_epi32(t, 16);
+
+  // Check NaN before converting back to bf16: lanes where the input is
+  // ordered (not NaN) keep the rounded value, the others get the NaN pattern.
+  __mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
+  t = _mm512_mask_blend_epi32(mask, nan, t);
+
+  // output.value = static_cast<uint16_t>(input);
+  r.i = _mm512_cvtepi32_epi16(t);
+#endif
+
+  return r;
+}
+
+// "True" mask packet, delegating to the Packet8i implementation.
+template <>
+EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
+  Packet16bf all_true;
+  all_true.i = ptrue<Packet8i>(a.i);
+  return all_true;
+}
+
+// Bitwise OR of two packets, done on the integer view.
+template <>
+EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf combined;
+  combined.i = por<Packet8i>(a.i, b.i);
+  return combined;
+}
+
+// Bitwise XOR of two packets, done on the integer view.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf combined;
+  combined.i = pxor<Packet8i>(a.i, b.i);
+  return combined;
+}
+
+// Bitwise AND of two packets, done on the integer view.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
+  Packet16bf combined;
+  combined.i = pand<Packet8i>(a.i, b.i);
+  return combined;
+}
+
+// Bitwise and-not of two packets, delegating to pandnot<Packet8i> on the
+// integer view.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
+                                       const Packet16bf& b) {
+  Packet16bf combined;
+  combined.i = pandnot<Packet8i>(a.i, b.i);
+  return combined;
+}
+
+// Lane-wise select: takes `a` where `mask` is set and `b` elsewhere.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
+                                       const Packet16bf& a,
+                                       const Packet16bf& b) {
+  // The mask lanes are expected to be all-zeros or all-ones, so a byte-wise
+  // blend is sufficient and is used for performance.
+  Packet16bf selected;
+  selected.i = _mm256_blendv_epi8(b.i, a.i, mask.i);
+  return selected;
+}
+
+// Lane-wise a == b, compared in float and packed back to 16-bit flags.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
+                                       const Packet16bf& b) {
+  const Packet16f flags_f32 = pcmp_eq(Bf16ToF32(a), Bf16ToF32(b));
+  Packet16bf flags;
+  flags.i = Pack32To16(flags_f32);
+  return flags;
+}
+
+// Lane-wise a <= b, compared in float and packed back to 16-bit flags.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
+                                       const Packet16bf& b) {
+  const Packet16f flags_f32 = pcmp_le(Bf16ToF32(a), Bf16ToF32(b));
+  Packet16bf flags;
+  flags.i = Pack32To16(flags_f32);
+  return flags;
+}
+
+// Lane-wise a < b, compared in float and packed back to 16-bit flags.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
+                                       const Packet16bf& b) {
+  const Packet16f flags_f32 = pcmp_lt(Bf16ToF32(a), Bf16ToF32(b));
+  Packet16bf flags;
+  flags.i = Pack32To16(flags_f32);
+  return flags;
+}
+
+// Lane-wise pcmp_lt_or_nan, compared in float and packed back to 16-bit flags.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
+                                              const Packet16bf& b) {
+  const Packet16f flags_f32 = pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b));
+  Packet16bf flags;
+  flags.i = Pack32To16(flags_f32);
+  return flags;
+}
+
+// Negates every lane by flipping its sign bit (bit 15 of each 16-bit lane).
+template <>
+EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
+  const __m256i sign_bits =
+      _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
+  Packet16bf negated;
+  negated.i = _mm256_xor_si256(a.i, sign_bits);
+  return negated;
+}
+
+// Complex conjugate of a bfloat16 packet: returns the input unchanged.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
+ return a;
+}
+
+// Lane-wise absolute value, evaluated in float precision.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
+  const Packet16f widened = Bf16ToF32(a);
+  const Packet16f abs_f32 = pabs<Packet16f>(widened);
+  return F32ToBf16(abs_f32);
+}
+
+// Lane-wise a + b, computed in float and rounded back to bfloat16.
+template <>
+EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(padd<Packet16f>(lhs, rhs));
+}
+
+// Lane-wise a - b, computed in float and rounded back to bfloat16.
+template <>
+EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(psub<Packet16f>(lhs, rhs));
+}
+
+// Lane-wise a * b, computed in float and rounded back to bfloat16.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmul<Packet16f>(lhs, rhs));
+}
+
+// Lane-wise a / b, computed in float and rounded back to bfloat16.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(pdiv<Packet16f>(lhs, rhs));
+}
+
+// Lane-wise minimum, delegating to pmin<Packet16f> on the widened inputs.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmin<Packet16f>(lhs, rhs));
+}
+
+// Lane-wise maximum, delegating to pmax<Packet16f> on the widened inputs.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
+                                                const Packet16bf& b) {
+  const Packet16f lhs = Bf16ToF32(a);
+  const Packet16f rhs = Bf16ToF32(b);
+  return F32ToBf16(pmax<Packet16f>(lhs, rhs));
+}
+
+// Reduction of all lanes via predux<Packet16f>, with the float result cast
+// to bfloat16.
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
+  const Packet16f widened = Bf16ToF32(p);
+  return static_cast<bfloat16>(predux<Packet16f>(widened));
+}
+
+// Reduction of all lanes via predux_mul<Packet16f>, with the float result
+// cast to bfloat16.
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
+  const Packet16f widened = Bf16ToF32(from);
+  return static_cast<bfloat16>(predux_mul<Packet16f>(widened));
+}
+
+// Reduction of all lanes via predux_min<Packet16f>, with the float result
+// cast to bfloat16.
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
+  const Packet16f widened = Bf16ToF32(from);
+  return static_cast<bfloat16>(predux_min<Packet16f>(widened));
+}
+
+// Reduction of all lanes via predux_max<Packet16f>, with the float result
+// cast to bfloat16.
+template <>
+EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
+  const Packet16f widened = Bf16ToF32(from);
+  return static_cast<bfloat16>(predux_max<Packet16f>(widened));
+}
+
+// Reverses the order of the 16 bfloat16 lanes of `a`
+// (lane 0 becomes lane 15, lane 1 becomes lane 14, ...).
+template <>
+EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
+  // Byte-shuffle control that reverses the eight 16-bit elements inside each
+  // 128-bit lane (the two bytes of an element stay in order).
+  __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,
+                               14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+
+  Packet16bf res;
+  // Swap hi and lo first because shuffle is in 128-bit lanes.
+  res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
+  // Shuffle 8-bit values within 2*128-bit lanes. This must operate on the
+  // lane-swapped value; shuffling the original input would discard the swap
+  // and only reverse each half in place.
+  res.i = _mm256_shuffle_epi8(res.i, m);
+  return res;
+}
+
+// Strided gather: lane k of the result is read from from[k * stride],
+// for k = 0..15.
+template <>
+EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
+ Index stride) {
+ Packet16bf result;
+ // _mm256_set_epi16 lists its arguments from the highest lane (15) down to
+ // the lowest (0), hence the descending indices.
+ result.i = _mm256_set_epi16(
+ from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
+ from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
+ from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
+ from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
+ return result;
+}
+
+// Strided scatter: lane k of `from` is written to to[k * stride],
+// for k = 0..15.
+template <>
+EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
+                                                        const Packet16bf& from,
+                                                        Index stride) {
+  // Spill the packet to an aligned scratch buffer, then copy lane by lane.
+  EIGEN_ALIGN64 bfloat16 lanes[16];
+  pstore(lanes, from);
+  for (Index lane = 0; lane < 16; ++lane) {
+    to[stride * lane].value = lanes[lane].value;
+  }
+}
+
+// In-place transpose of a 16x16 block of bfloat16 values held in 16 packets:
+// on return, kernel.packet[r] holds element r of each of the 16 input packets
+// in input-packet order.
+//
+// The unpack intrinsics work independently on each 128-bit lane, so every
+// intermediate holds data for one element index in its lower lane and for
+// (index + 8) in its upper lane. The final 128-bit permutes therefore pick
+// the lower lanes (selector 0x20) for output rows 0..7 and the upper lanes
+// (selector 0x31) for output rows 8..15.
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
+  __m256i a = kernel.packet[0].i;
+  __m256i b = kernel.packet[1].i;
+  __m256i c = kernel.packet[2].i;
+  __m256i d = kernel.packet[3].i;
+  __m256i e = kernel.packet[4].i;
+  __m256i f = kernel.packet[5].i;
+  __m256i g = kernel.packet[6].i;
+  __m256i h = kernel.packet[7].i;
+  __m256i i = kernel.packet[8].i;
+  __m256i j = kernel.packet[9].i;
+  __m256i k = kernel.packet[10].i;
+  __m256i l = kernel.packet[11].i;
+  __m256i m = kernel.packet[12].i;
+  __m256i n = kernel.packet[13].i;
+  __m256i o = kernel.packet[14].i;
+  __m256i p = kernel.packet[15].i;
+
+  // Stage 1: interleave 16-bit elements of neighbouring rows.
+  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
+  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
+  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
+  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
+  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
+  __m256i op_07 = _mm256_unpacklo_epi16(o, p);
+
+  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
+  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
+  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
+  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
+  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
+  __m256i op_8f = _mm256_unpackhi_epi16(o, p);
+
+  // Stage 2: interleave 32-bit pairs to group four rows.
+  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
+  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
+  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
+  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
+  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
+  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
+
+  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
+  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
+  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
+  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
+  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
+  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
+
+  // Stage 3: interleave 64-bit quads to group eight rows.
+  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
+  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
+  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
+  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
+  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
+  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
+  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
+  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
+  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
+  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
+  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
+  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
+  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
+  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
+  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
+  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
+
+  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+  // Rows 0..7: concatenate the lower 128-bit lanes (selector 0x20).
+  kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
+                                                 0x20);
+  kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
+                                                 0x20);
+  kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
+                                                 0x20);
+  kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
+                                                 0x20);
+  kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
+                                                 0x20);
+  kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
+                                                 0x20);
+  kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
+                                                 0x20);
+  kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
+                                                 0x20);
+  // Rows 8..15: concatenate the upper 128-bit lanes (selector 0x31). Using
+  // 0x20 here would duplicate rows 0..7.
+  kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
+                                                 0x31);
+  kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
+                                                 0x31);
+  kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
+                                                  0x31);
+  kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
+                                                  0x31);
+  kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
+                                                  0x31);
+  kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
+                                                  0x31);
+  kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
+                                                  0x31);
+  kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
+                                                  0x31);
+}
+
+// In-place transpose of a 4x16 block of bfloat16 values held in 4 packets.
+// On return, kernel.packet[r] holds elements 4r..4r+3 of the four input
+// packets, interleaved as {a, b, c, d} for each element index.
+EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
+ __m256i a = kernel.packet[0].i;
+ __m256i b = kernel.packet[1].i;
+ __m256i c = kernel.packet[2].i;
+ __m256i d = kernel.packet[3].i;
+
+ // Interleave 16-bit elements of row pairs; the unpack intrinsics operate
+ // independently on each 128-bit lane.
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
+
+ // Interleave 32-bit pairs so each 64-bit group holds {a, b, c, d} for one
+ // element index.
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
+
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
+ // 0x20 concatenates the lower 128-bit lanes of both operands, 0x31 the
+ // upper ones.
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
+ kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
+ kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
+ kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
+}
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h
index a82176941..e643b18a7 100644
--- a/Eigen/src/Core/arch/AVX512/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h
@@ -40,6 +40,32 @@ template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packe
return float2half(a);
}
+// Enables the vectorized bfloat16 -> float cast.
+// Both coefficient ratios are 1; presumably this means one source packet
+// maps to exactly one target packet -- confirm against the generic
+// type_casting_traits documentation.
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+// Packet cast bfloat16 -> float: widens all 16 lanes via Bf16ToF32.
+template<>
+EIGEN_STRONG_INLINE Packet16f
+pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
+  const Packet16f widened = Bf16ToF32(a);
+  return widened;
+}
+
+// Enables the vectorized float -> bfloat16 cast.
+// Both coefficient ratios are 1; presumably this means one source packet
+// maps to exactly one target packet -- confirm against the generic
+// type_casting_traits documentation.
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+// Packet cast float -> bfloat16: narrows all 16 lanes via F32ToBf16.
+template<>
+EIGEN_STRONG_INLINE Packet16bf
+pcast<Packet16f, Packet16bf>(const Packet16f& a) {
+  const Packet16bf narrowed = F32ToBf16(a);
+  return narrowed;
+}
+
} // end namespace internal
} // end namespace Eigen