aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512
diff options
context:
space:
mode:
authorGravatar Teng Lu <teng.lu@intel.com>2020-07-25 12:28:59 +0800
committerGravatar Teng Lu <teng.lu@intel.com>2020-07-29 02:20:21 +0000
commit3ec4f0b64185ef1ae220dfdad93fe5ca1257bf3f (patch)
treeb14805f3ba6969b528cda8e20232655dfbd5efac /Eigen/src/Core/arch/AVX512
parentb92206676c06c5a139c7c5eaa13455b0d9f40581 (diff)
Fix undefined BF16 union behavior in AVX512.
Diffstat (limited to 'Eigen/src/Core/arch/AVX512')
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h181
1 files changed, 73 insertions, 108 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 2b6693eed..76f3366d7 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -32,6 +32,7 @@ typedef __m512 Packet16f;
typedef __m512i Packet16i;
typedef __m512d Packet8d;
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
+typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
template <>
struct is_arithmetic<__m512> {
@@ -1620,13 +1621,6 @@ ptranspose(PacketBlock<Packet16h,4>& kernel) {
kernel.packet[3] = pload<Packet16h>(out[3]);
}
-typedef union {
-#ifdef EIGEN_VECTORIZE_AVX512BF16
- __m256bh bh;
-#endif
- Packet8i i; // __m256i;
-} Packet16bf;
-
template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
template <>
@@ -1673,42 +1667,36 @@ struct unpacket_traits<Packet16bf>
template <>
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
- Packet16bf r;
- r.i = _mm256_set1_epi16(from.value);
- return r;
+ return _mm256_set1_epi16(from.value);
}
template <>
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
bfloat16 t;
- t.value = static_cast<unsigned short>(_mm256_extract_epi16(from.i, 0));
+ t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
return t;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
- Packet16bf r;
- r.i = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
- return r;
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
- Packet16bf r;
- r.i = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
- return r;
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
const Packet16bf& from) {
- _mm256_store_si256(reinterpret_cast<__m256i*>(to), from.i);
+ _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
const Packet16bf& from) {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from.i);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}
template<> EIGEN_STRONG_INLINE Packet16bf
@@ -1722,8 +1710,7 @@ ploaddup<Packet16bf>(const bfloat16* from) {
unsigned short f = from[5].value;
unsigned short g = from[6].value;
unsigned short h = from[7].value;
- r.i = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
- return r;
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
}
template<> EIGEN_STRONG_INLINE Packet16bf
@@ -1733,12 +1720,11 @@ ploadquad(const bfloat16* from) {
unsigned short b = from[1].value;
unsigned short c = from[2].value;
unsigned short d = from[3].value;
- r.i = _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
- return r;
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
}
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
- return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
+ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}
// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
@@ -1754,8 +1740,11 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
#endif // EIGEN_VECTORIZE_AVX512DQ
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
-#if defined(EIGEN_VECTORIZE_AVX512BF16)
- r.bh = _mm512_cvtneps_pbh(flush);
+#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
+ // GCC 10.1 and later support avx512bf16 and a C-style explicit cast
+ // (C++ static_cast is not supported yet), so do the conversion via the
+ // intrinsic and keep the value on the register path for performance.
+ r = (__m256i)(_mm512_cvtneps_pbh(flush));
#else
__m512i t;
__m512i input = _mm512_castps_si512(flush);
@@ -1775,7 +1764,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
t = _mm512_mask_blend_epi32(mask, nan, t);
// output.value = static_cast<uint16_t>(input);
- r.i = _mm512_cvtepi32_epi16(t);
+ r = _mm512_cvtepi32_epi16(t);
#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
@@ -1783,38 +1772,28 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
template <>
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
- Packet16bf r;
- r.i = ptrue<Packet8i>(a.i);
- return r;
+ return ptrue<Packet8i>(a);
}
template <>
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
- Packet16bf r;
- r.i = por<Packet8i>(a.i, b.i);
- return r;
+ return por<Packet8i>(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
- Packet16bf r;
- r.i = pxor<Packet8i>(a.i, b.i);
- return r;
+ return pxor<Packet8i>(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
- Packet16bf r;
- r.i = pand<Packet8i>(a.i, b.i);
- return r;
+ return pand<Packet8i>(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
const Packet16bf& b) {
- Packet16bf r;
- r.i = pandnot<Packet8i>(a.i, b.i);
- return r;
+ return pandnot<Packet8i>(a, b);
}
template <>
@@ -1823,50 +1802,39 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
const Packet16bf& b) {
// Input mask is expected to be all 0/1, handle it with 8-bit
// intrinsic for performance.
- Packet16bf r;
- r.i = _mm256_blendv_epi8(b.i, a.i, mask.i);
- return r;
+ return _mm256_blendv_epi8(b, a, mask);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
const Packet16bf& b) {
- Packet16bf result;
- result.i = Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
- return result;
+ return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
const Packet16bf& b) {
- Packet16bf result;
- result.i = Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
- return result;
+ return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
const Packet16bf& b) {
- Packet16bf result;
- result.i = Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
- return result;
+ return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
const Packet16bf& b) {
- Packet16bf result;
- result.i = Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
- return result;
+ return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
Packet16bf sign_mask;
- sign_mask.i = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
+ sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
Packet16bf result;
- result.i = _mm256_xor_si256(a.i, sign_mask.i);
- return result;
+ return _mm256_xor_si256(a, sign_mask);
}
template <>
@@ -1917,8 +1885,8 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
template <>
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
- Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
- Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
+ Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
return padd<Packet8bf>(lane0, lane1);
}
@@ -1949,22 +1917,19 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
Packet16bf res;
// Swap hi and lo first because shuffle is in 128-bit lanes.
- res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
+ res = _mm256_permute2x128_si256(a, a, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
- res.i = _mm256_shuffle_epi8(res.i, m);
- return res;
+ return _mm256_shuffle_epi8(res, m);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
Index stride) {
- Packet16bf result;
- result.i = _mm256_set_epi16(
+ return _mm256_set_epi16(
from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
- return result;
}
template <>
@@ -1992,22 +1957,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
- __m256i a = kernel.packet[0].i;
- __m256i b = kernel.packet[1].i;
- __m256i c = kernel.packet[2].i;
- __m256i d = kernel.packet[3].i;
- __m256i e = kernel.packet[4].i;
- __m256i f = kernel.packet[5].i;
- __m256i g = kernel.packet[6].i;
- __m256i h = kernel.packet[7].i;
- __m256i i = kernel.packet[8].i;
- __m256i j = kernel.packet[9].i;
- __m256i k = kernel.packet[10].i;
- __m256i l = kernel.packet[11].i;
- __m256i m = kernel.packet[12].i;
- __m256i n = kernel.packet[13].i;
- __m256i o = kernel.packet[14].i;
- __m256i p = kernel.packet[15].i;
+ __m256i a = kernel.packet[0];
+ __m256i b = kernel.packet[1];
+ __m256i c = kernel.packet[2];
+ __m256i d = kernel.packet[3];
+ __m256i e = kernel.packet[4];
+ __m256i f = kernel.packet[5];
+ __m256i g = kernel.packet[6];
+ __m256i h = kernel.packet[7];
+ __m256i i = kernel.packet[8];
+ __m256i j = kernel.packet[9];
+ __m256i k = kernel.packet[10];
+ __m256i l = kernel.packet[11];
+ __m256i m = kernel.packet[12];
+ __m256i n = kernel.packet[13];
+ __m256i o = kernel.packet[14];
+ __m256i p = kernel.packet[15];
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
@@ -2063,29 +2028,29 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
- kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
- kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
- kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
- kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
- kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
- kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
- kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
- kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
- kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
- kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
- kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
- kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
+ kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
- __m256i a = kernel.packet[0].i;
- __m256i b = kernel.packet[1].i;
- __m256i c = kernel.packet[2].i;
- __m256i d = kernel.packet[3].i;
+ __m256i a = kernel.packet[0];
+ __m256i b = kernel.packet[1];
+ __m256i c = kernel.packet[2];
+ __m256i d = kernel.packet[3];
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
@@ -2098,10 +2063,10 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
+ kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
+ kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
+ kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
+ kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
} // end namespace internal