aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sheng Yang <yang.sheng@intel.com>2020-07-14 01:34:03 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-07-14 01:34:03 +0000
commit56b3e3f3f8ca9972ca390c8296fde363bdab271c (patch)
tree5d06bf0995ed07dd232e346369e71f70561b5d9c
parent4ab32e2de2511746e2108563a43cbbeb1922fbf2 (diff)
AVX path for BF16
-rw-r--r--CMakeLists.txt6
-rw-r--r--Eigen/src/Core/arch/AVX/MathFunctions.h33
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h345
-rw-r--r--Eigen/src/Core/arch/AVX/TypeCasting.h26
-rw-r--r--Eigen/src/Core/arch/AVX512/MathFunctions.h47
-rw-r--r--Eigen/src/Core/arch/AVX512/PacketMath.h73
-rw-r--r--Eigen/src/Core/arch/Default/BFloat16.h7
-rw-r--r--cmake/EigenTesting.cmake12
8 files changed, 460 insertions, 89 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 28103856e..867d068d2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -239,6 +239,12 @@ if(NOT MSVC)
message(STATUS "Enabling FMA in tests/examples")
endif()
+ option(EIGEN_TEST_AVX2 "Enable/Disable AVX2 in tests/examples" OFF)
+ if(EIGEN_TEST_AVX2)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma")
+ message(STATUS "Enabling AVX2 in tests/examples")
+ endif()
+
option(EIGEN_TEST_AVX512 "Enable/Disable AVX512 in tests/examples" OFF)
if(EIGEN_TEST_AVX512)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma")
diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
index c5394430f..461696170 100644
--- a/Eigen/src/Core/arch/AVX/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
@@ -58,15 +58,15 @@ pexp<Packet8f>(const Packet8f& _x) {
// Hyperbolic Tangent function.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
-ptanh<Packet8f>(const Packet8f& x) {
- return internal::generic_fast_tanh_float(x);
+ptanh<Packet8f>(const Packet8f& _x) {
+ return internal::generic_fast_tanh_float(_x);
}
// Exponential function for doubles.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d
-pexp<Packet4d>(const Packet4d& x) {
- return pexp_double(x);
+pexp<Packet4d>(const Packet4d& _x) {
+ return pexp_double(_x);
}
// Functions for sqrt.
@@ -96,13 +96,13 @@ psqrt<Packet8f>(const Packet8f& _x) {
}
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f psqrt<Packet8f>(const Packet8f& x) {
- return _mm256_sqrt_ps(x);
+Packet8f psqrt<Packet8f>(const Packet8f& _x) {
+ return _mm256_sqrt_ps(_x);
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d psqrt<Packet4d>(const Packet4d& x) {
- return _mm256_sqrt_pd(x);
+Packet4d psqrt<Packet4d>(const Packet4d& _x) {
+ return _mm256_sqrt_pd(_x);
}
#if EIGEN_FAST_MATH
@@ -140,18 +140,27 @@ Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
#else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet8f prsqrt<Packet8f>(const Packet8f& x) {
+Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
- return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x));
+ return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x));
}
#endif
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
-Packet4d prsqrt<Packet4d>(const Packet4d& x) {
+Packet4d prsqrt<Packet4d>(const Packet4d& _x) {
_EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
- return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x));
+ return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x));
}
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
+BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index b17c34015..cf7146cbc 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -32,11 +32,13 @@ typedef __m256 Packet8f;
typedef __m256i Packet8i;
typedef __m256d Packet4d;
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
+typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
template<> struct is_arithmetic<__m256> { enum { value = true }; };
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
+template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \
const Packet8f p8f_##NAME = pset1<Packet8f>(X)
@@ -134,6 +136,40 @@ struct packet_traits<Eigen::half> : default_packet_traits {
HasBlend = 0
};
};
+
+template <>
+struct packet_traits<bfloat16> : default_packet_traits {
+ typedef Packet8bf type;
+ // There is no half-size packet for current Packet8bf.
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
+ enum {
+ Vectorizable = 1,
+ AlignedOnScalar = 1,
+ size = 8,
+ HasHalfPacket = 0,
+
+ HasCmp = 1,
+ HasDiv = 1,
+ HasSin = EIGEN_FAST_MATH,
+ HasCos = EIGEN_FAST_MATH,
+ HasLog = 1,
+ HasLog1p = 1,
+ HasExpm1 = 1,
+ HasExp = 1,
+ HasNdtri = 1,
+ HasBessel = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasTanh = EIGEN_FAST_MATH,
+ HasErf = EIGEN_FAST_MATH,
+ HasBlend = 0,
+ HasRound = 1,
+ HasFloor = 1,
+ HasCeil = 1,
+ HasRint = 1
+ };
+};
#endif
template<> struct scalar_div_cost<float,true> { enum { value = 14 }; };
@@ -165,6 +201,14 @@ template<> struct unpacket_traits<Packet4d> {
enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
};
template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; };
+template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; };
+
+// Helper function for bit packing snippet of low precision comparison.
+// It packs the flags from 16x16 to 8x16.
+EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) {
+ return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0),
+ _mm256_extractf128_si256(_mm256_castps_si256(rf), 1));
+}
template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); }
template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); }
@@ -1032,6 +1076,307 @@ ptranspose(PacketBlock<Packet8h,4>& kernel) {
kernel.packet[3] = pload<Packet8h>(out[3]);
}
+EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
+#ifdef EIGEN_VECTORIZE_AVX2
+ __m256i extend = _mm256_cvtepu16_epi32(a);
+ return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16));
+#else
+ __m128i lo = _mm_cvtepu16_epi32(a);
+ __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8));
+ __m128i lo_shift = _mm_slli_epi32(lo, 16);
+ __m128i hi_shift = _mm_slli_epi32(hi, 16);
+ return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1));
+#endif
+}
+
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
+ Packet8bf r;
+
+ // Flush input denormals value to zero with hardware capability.
+ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+ __m256 flush = _mm256_and_ps(a, a);
+ _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
+
+ __m256i input = _mm256_castps_si256(flush);
+
+#ifdef EIGEN_VECTORIZE_AVX2
+ // uint32_t lsb = (input >> 16);
+ __m256i t = _mm256_srli_epi32(input, 16);
+ // uint32_t lsb = lsb & 1;
+ t = _mm256_and_si256(t, _mm256_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ t = _mm256_add_epi32(t, input);
+ // input = input >> 16;
+ t = _mm256_srli_epi32(t, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
+ __m256i nan = _mm256_set1_epi32(0x7fc0);
+ t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
+ // output.value = static_cast<uint16_t>(input);
+ return _mm_packus_epi32(_mm256_extractf128_si256(t, 0),
+ _mm256_extractf128_si256(t, 1));
+#else
+ // uint32_t lsb = (input >> 16);
+ __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16);
+ __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16);
+ // uint32_t lsb = lsb & 1;
+ lo = _mm_and_si128(lo, _mm_set1_epi32(1));
+ hi = _mm_and_si128(hi, _mm_set1_epi32(1));
+ // uint32_t rounding_bias = 0x7fff + lsb;
+ lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff));
+ hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff));
+ // input += rounding_bias;
+ lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0));
+ hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1));
+ // input = input >> 16;
+ lo = _mm_srli_epi32(lo, 16);
+ hi = _mm_srli_epi32(hi, 16);
+ // Check NaN before converting back to bf16
+ __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
+ __m128i nan = _mm_set1_epi32(0x7fc0);
+ lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
+ hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
+ // output.value = static_cast<uint16_t>(input);
+ return _mm_packus_epi32(lo, hi);
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) {
+ return _mm_set1_epi16(from.value);
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) {
+ return bfloat16_impl::raw_uint16_to_bfloat16(static_cast<unsigned short>(_mm_extract_epi16(from, 0)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) {
+ return _mm_load_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) {
+ return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from));
+}
+
+template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_store_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploaddup<Packet8bf>(const bfloat16* from) {
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ unsigned short c = from[2].value;
+ unsigned short d = from[3].value;
+ return _mm_set_epi16(d, d, c, c, b, b, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf
+ploadquad<Packet8bf>(const bfloat16* from) {
+ unsigned short a = from[0].value;
+ unsigned short b = from[1].value;
+ return _mm_set_epi16(b, b, b, b, a, a, a, a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) {
+ return _mm_cmpeq_epi32(a, a);
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
+ return F32ToBf16(pabs<Packet8f>(Bf16ToF32(a)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template <>
+EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a,
+ const Packet8bf& b) {
+ return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_or_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_xor_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_and_si128(a,b);
+}
+template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) {
+ return _mm_andnot_si128(b,a);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) {
+ return _mm_blendv_epi8(b, a, mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a)
+{
+ return F32ToBf16(pround<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(print<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) {
+ return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) {
+ return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; }
+
+template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) {
+ Packet8bf sign_mask = _mm_set1_epi16(static_cast<unsigned short>(0x8000));
+ return _mm_xor_si128(a, sign_mask);
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) {
+ return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b)));
+}
+
+
+template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride)
+{
+ return _mm_set_epi16(from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
+}
+
+template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride)
+{
+ EIGEN_ALIGN32 bfloat16 aux[8];
+ pstore(aux, from);
+ to[stride*0] = aux[0];
+ to[stride*1] = aux[1];
+ to[stride*2] = aux[2];
+ to[stride*3] = aux[3];
+ to[stride*4] = aux[4];
+ to[stride*5] = aux[5];
+ to[stride*6] = aux[6];
+ to[stride*7] = aux[7];
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
+ return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a)
+{
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
+ return _mm_shuffle_epi8(a,m);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,8>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+ __m128i e = kernel.packet[4];
+ __m128i f = kernel.packet[5];
+ __m128i g = kernel.packet[6];
+ __m128i h = kernel.packet[7];
+
+ __m128i a03b03 = _mm_unpacklo_epi16(a, b);
+ __m128i c03d03 = _mm_unpacklo_epi16(c, d);
+ __m128i e03f03 = _mm_unpacklo_epi16(e, f);
+ __m128i g03h03 = _mm_unpacklo_epi16(g, h);
+ __m128i a47b47 = _mm_unpackhi_epi16(a, b);
+ __m128i c47d47 = _mm_unpackhi_epi16(c, d);
+ __m128i e47f47 = _mm_unpackhi_epi16(e, f);
+ __m128i g47h47 = _mm_unpackhi_epi16(g, h);
+
+ __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
+ __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
+ __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
+ __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
+ __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
+ __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
+ __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
+ __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
+
+ kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
+ kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
+ kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
+ kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
+ kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
+}
+
+EIGEN_STRONG_INLINE void
+ptranspose(PacketBlock<Packet8bf,4>& kernel) {
+ __m128i a = kernel.packet[0];
+ __m128i b = kernel.packet[1];
+ __m128i c = kernel.packet[2];
+ __m128i d = kernel.packet[3];
+
+ __m128i ab_03 = _mm_unpacklo_epi16(a, b);
+ __m128i cd_03 = _mm_unpacklo_epi16(c, d);
+ __m128i ab_47 = _mm_unpackhi_epi16(a, b);
+ __m128i cd_47 = _mm_unpackhi_epi16(c, d);
+
+ kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03);
+ kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03);
+ kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47);
+ kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h
index 181043588..c669a7f60 100644
--- a/Eigen/src/Core/arch/AVX/TypeCasting.h
+++ b/Eigen/src/Core/arch/AVX/TypeCasting.h
@@ -76,12 +76,38 @@ struct type_casting_traits<float, Eigen::half> {
};
};
+template <>
+struct type_casting_traits<bfloat16, float> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
+template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
+ return Bf16ToF32(a);
+}
+
+template <>
+struct type_casting_traits<float, bfloat16> {
+ enum {
+ VectorizedCast = 1,
+ SrcCoeffRatio = 1,
+ TgtCoeffRatio = 1
+ };
+};
+
#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
+template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
+ return F32ToBf16(a);
+}
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h
index b86afced6..83af5f5de 100644
--- a/Eigen/src/Core/arch/AVX512/MathFunctions.h
+++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h
@@ -135,10 +135,7 @@ plog<Packet16f>(const Packet16f& _x) {
p16f_minus_inf);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf plog<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
#endif
// Exponential function. Works by writing "x = m*log(2) + r" where
@@ -264,10 +261,7 @@ pexp<Packet8d>(const Packet8d& _x) {
return pmax(pmul(x, e), _x);
}*/
-template <>
-EIGEN_STRONG_INLINE Packet16bf pexp<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexp<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
@@ -325,10 +319,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf psqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(psqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
@@ -377,10 +368,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
}
#endif
-template <>
-EIGEN_STRONG_INLINE Packet16bf prsqrt<Packet16bf>(const Packet16bf& x) {
- return F32ToBf16(prsqrt<Packet16f>(Bf16ToF32(x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
// prsqrt for double.
#if EIGEN_FAST_MATH
@@ -435,20 +423,14 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf plog1p<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(plog1p<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
-template<>
-EIGEN_STRONG_INLINE Packet16bf pexpm1<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pexpm1<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
#endif
#endif
@@ -461,31 +443,20 @@ psin<Packet16f>(const Packet16f& _x) {
}
template <>
-EIGEN_STRONG_INLINE Packet16bf psin<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(psin<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
template <>
-EIGEN_STRONG_INLINE Packet16bf pcos<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(pcos<Packet16f>(Bf16ToF32(_x)));
-}
-
-template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
ptanh<Packet16f>(const Packet16f& _x) {
return internal::generic_fast_tanh_float(_x);
}
-template <>
-EIGEN_STRONG_INLINE Packet16bf ptanh<Packet16bf>(const Packet16bf& _x) {
- return F32ToBf16(ptanh<Packet16f>(Bf16ToF32(_x)));
-}
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
} // end namespace internal
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index f55a50596..2b6693eed 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1633,13 +1633,13 @@ template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
// There is no half-size packet for current Packet16bf.
- // TODO: support as SSE/AVX path.
- typedef Packet16bf half;
+ // TODO: support as SSE path.
+ typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
- HasHalfPacket = 0,
+ HasHalfPacket = 1,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
@@ -1668,7 +1668,7 @@ struct unpacket_traits<Packet16bf>
{
typedef bfloat16 type;
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
- typedef Packet16bf half;
+ typedef Packet8bf half;
};
template <>
@@ -1741,13 +1741,17 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a.i), 16));
}
-// Convert float to bfloat16 according to round-to-even/denormals alogrithm.
+// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
+#if defined(EIGEN_VECTORIZE_AVX512DQ)
__m512 flush = _mm512_and_ps(a, a);
+#else
+ __m512 flush = _mm512_max_ps(a, a);
+#endif // EIGEN_VECTORIZE_AVX512DQ
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
#if defined(EIGEN_VECTORIZE_AVX512BF16)
@@ -1772,7 +1776,7 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
// output.value = static_cast<uint16_t>(input);
r.i = _mm512_cvtepi32_epi16(t);
-#endif
+#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
}
@@ -1912,6 +1916,13 @@ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
}
template <>
+EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
+ Packet8bf lane0 = _mm256_extractf128_si256(a.i, 0);
+ Packet8bf lane1 = _mm256_extractf128_si256(a.i, 1);
+ return padd<Packet8bf>(lane0, lane1);
+}
+
+template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
@@ -1940,7 +1951,7 @@ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
// Swap hi and lo first because shuffle is in 128-bit lanes.
res.i = _mm256_permute2x128_si256(a.i, a.i, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
- res.i = _mm256_shuffle_epi8(a.i, m);
+ res.i = _mm256_shuffle_epi8(res.i, m);
return res;
}
@@ -2052,38 +2063,22 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
- kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
- kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01,
- 0x20);
- kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23,
- 0x20);
- kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45,
- 0x20);
- kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67,
- 0x20);
- kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89,
- 0x20);
- kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab,
- 0x20);
- kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd,
- 0x20);
- kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef,
- 0x20);
+ kernel.packet[0].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
+ kernel.packet[1].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
+ kernel.packet[2].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
+ kernel.packet[3].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
+ kernel.packet[4].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
+ kernel.packet[5].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
+ kernel.packet[6].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
+ kernel.packet[7].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
+ kernel.packet[8].i = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
+ kernel.packet[9].i = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
+ kernel.packet[10].i = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
+ kernel.packet[11].i = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
+ kernel.packet[12].i = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
+ kernel.packet[13].i = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
+ kernel.packet[14].i = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
+ kernel.packet[15].i = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h
index abf2ac933..34a4f0ced 100644
--- a/Eigen/src/Core/arch/Default/BFloat16.h
+++ b/Eigen/src/Core/arch/Default/BFloat16.h
@@ -23,6 +23,13 @@ limitations under the License.
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
#endif
+#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
+ template <> \
+ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
+ PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
+ return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
+ }
+
namespace Eigen {
struct bfloat16;
diff --git a/cmake/EigenTesting.cmake b/cmake/EigenTesting.cmake
index 524223717..c98393e5e 100644
--- a/cmake/EigenTesting.cmake
+++ b/cmake/EigenTesting.cmake
@@ -310,6 +310,12 @@ macro(ei_testing_print_summary)
message(STATUS "AVX: Using architecture defaults")
endif()
+ if(EIGEN_TEST_AVX2)
+ message(STATUS "AVX2: ON")
+ else()
+ message(STATUS "AVX2: Using architecture defaults")
+ endif()
+
if(EIGEN_TEST_FMA)
message(STATUS "FMA: ON")
else()
@@ -322,6 +328,12 @@ macro(ei_testing_print_summary)
message(STATUS "AVX512: Using architecture defaults")
endif()
+ if(EIGEN_TEST_AVX512DQ)
+ message(STATUS "AVX512DQ: ON")
+ else()
+ message(STATUS "AVX512DQ: Using architecture defaults")
+ endif()
+
if(EIGEN_TEST_ALTIVEC)
message(STATUS "Altivec: ON")
else()