aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/AVX512/PacketMath.h
diff options
context:
space:
mode:
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h73
1 files changed, 34 insertions, 39 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index f55a50596..2b6693eed 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1633,13 +1633,13 @@ template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
// There is no half-size packet for current Packet16bf.
- // TODO: support as SSE/AVX path.
- typedef Packet16bf half;
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
@@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf>
{
typedef bfloat16 type;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
- typedef Packet16bf half;
+ typedef Packet8bf half;
};
template <>
@@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
}
-// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+#if defined(EIGEN_VECTORIZE_AVX512DQ)
__m512 flush = _mm512_and_ps(a, a);
+#else
+ __m512 flush = _mm512_max_ps(a, a);
+#endif // EIGEN_VECTORIZE_AVX512DQ
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
#if defined(EIGEN_VECTORIZE_AVX512BF16)
@@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
// output.value = static_cast<uint16_t>(input);
r.i = _mm512_cvtepi32_epi16(t);
-#endif
+#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
}
@@ -1912,6 +1916,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
}
template <>
+EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
+ Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
+ return padd<Packet8bf>(lane0, lane1);
+}
+
+template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
@@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
// Swap hi and lo first because shuffle is in 128-bit lanes.
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
- res.i = _mm256_shuffle_epi8(a.i, m);
+ res.i = _mm256_shuffle_epi8(res.i, m);
return res;
}
@@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
- kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {