diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-11-24 16:28:07 -0800 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-11-30 16:28:57 +0000 |
commit | 89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch) | |
tree | c29344e3c03752faaaf2f8eee847811091688262 | |
parent | c5985c46f5de0a7a381262c5a8a973806db92f40 (diff) |
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake.
Made `half` and `bfloat16` consistent in terms of ops they support.
Note the `log` tests are currently disabled for `bfloat16` since
they fail due to poor precision (they were previously disabled for
`Packet8bf` via test function specialization -- I just removed that
specialization and disabled it in the generic test).
-rw-r--r-- | Eigen/src/Core/arch/AVX/PacketMath.h | 21 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX512/MathFunctions.h | 12 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX512/PacketMath.h | 280 | ||||
-rw-r--r-- | Eigen/src/Core/arch/AVX512/TypeCasting.h | 16 | ||||
-rw-r--r-- | test/packetmath.cpp | 129 |
5 files changed, 280 insertions, 178 deletions
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index e9eaaa9e0..a9fc33791 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -105,7 +105,8 @@ template<> struct packet_traits<double> : default_packet_traits HasBlend = 1, HasRound = 1, HasFloor = 1, - HasCeil = 1 + HasCeil = 1, + HasRint = 1 }; }; @@ -278,7 +279,15 @@ template<> EIGEN_STRONG_INLINE Packet8i pconj(const Packet8i& a) { return a; } template<> EIGEN_STRONG_INLINE Packet8f pmul<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet4d pmul<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); } - +template<> EIGEN_STRONG_INLINE Packet8i pmul<Packet8i>(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_mullo_epi32(a,b); +#else + const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); } template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); } @@ -499,14 +508,14 @@ template<> EIGEN_STRONG_INLINE Packet8i pandnot<Packet8i>(const Packet8i& a, con template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a) { - const Packet8f mask = pset1frombits<Packet8f>(0x80000000u); - const Packet8f prev0dot5 = pset1frombits<Packet8f>(0x3EFFFFFFu); + const Packet8f mask = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x80000000u)); + const Packet8f prev0dot5 = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x3EFFFFFFu)); return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); } template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a) { - const Packet4d mask = _mm256_castsi256_pd(_mm256_set_epi64x(0x8000000000000000ull, 0x8000000000000000ull, 0x8000000000000000ull, 0x8000000000000000ull)); - const Packet4d prev0dot5 = _mm256_castsi256_pd(_mm256_set_epi64x(0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull)); + const Packet4d mask = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x8000000000000000ull)); + const Packet4d prev0dot5 = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull)); return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); } diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index bfd30c01a..2c34868a7 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -48,6 +48,7 @@ plog<Packet8d>(const Packet8d& _x) { return plog_double(_x); } +F16_PACKET_FUNCTION(Packet16f, Packet16h, plog) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog) #endif @@ -174,6 +175,7 @@ pexp<Packet8d>(const Packet8d& _x) { return pmax(pmul(x, e), _x); }*/ +F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) // Functions for sqrt. @@ -232,6 +234,7 @@ EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) { } #endif +F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) // prsqrt for float. @@ -256,7 +259,7 @@ prsqrt<Packet16f>(const Packet16f& _x) { __mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ); __mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ); __mmask16 not_finite_pos_mask = not_pos_mask | inf_mask; - + // Compute an approximate result using the rsqrt intrinsic, forcing +inf // for denormals for consistency with AVX and SSE implementations. Packet16f y_approx = _mm512_rsqrt14_ps(_x); @@ -281,6 +284,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) { } #endif +F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt) // prsqrt for double. @@ -336,6 +340,7 @@ Packet16f plog1p<Packet16f>(const Packet16f& _x) { return generic_plog1p(_x); } +F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED @@ -343,6 +348,7 @@ Packet16f pexpm1<Packet16f>(const Packet16f& _x) { return generic_expm1(_x); } +F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1) #endif @@ -367,6 +373,10 @@ ptanh<Packet16f>(const Packet16f& _x) { return internal::generic_fast_tanh_float(_x); } +F16_PACKET_FUNCTION(Packet16f, Packet16h, psin) +F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos) +F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh) + BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index bf7f0db4f..9acec3439 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -58,23 +58,35 @@ struct packet_traits<half> : default_packet_traits { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, + + HasCmp = 1, HasAdd = 1, HasSub = 1, HasMul = 1, HasDiv = 1, HasNegate = 1, - HasAbs = 0, + HasAbs = 1, HasAbs2 = 0, - HasMin = 0, - HasMax = 0, - HasConj = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasSetLinear = 0, - HasSqrt = 0, - HasRsqrt = 0, - HasExp = 0, - HasLog = 0, - HasBlend = 0 + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -87,6 +99,11 @@ template<> struct packet_traits<float> : default_packet_traits AlignedOnScalar = 1, size = 16, HasHalfPacket = 1, + + HasAbs = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasBlend = 0, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, @@ -105,7 +122,11 @@ template<> struct packet_traits<float> : default_packet_traits HasErf = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; template<> struct packet_traits<double> : default_packet_traits @@ -125,7 +146,11 @@ template<> struct packet_traits<double> : default_packet_traits HasRsqrt = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -165,7 +190,7 @@ struct unpacket_traits<Packet16i> { template<> struct unpacket_traits<Packet16h> { typedef Eigen::half type; - typedef Packet16h half; + typedef Packet8h half; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; @@ -188,10 +213,14 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) { } template <> -EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) { +EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) { return _mm512_castsi512_pd(_mm512_set1_epi64(from)); } +template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } + template <> EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) { return _mm512_broadcastss_ps(_mm_load_ps1(from)); @@ -281,7 +310,7 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, template <> EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) { - return _mm512_mul_epi32(a, b); + return _mm512_mullo_epi32(a, b); } template <> @@ -482,6 +511,15 @@ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); } +template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); } + +template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); } + template <> EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) { return _mm512_set1_epi32(0xffffffffu); @@ -598,6 +636,21 @@ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& #endif } +template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) +{ + // Work-around for default std::round rounding mode. + const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u)); + const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu)); + return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) +{ + // Work-around for default std::round rounding mode. + const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull)); + const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull)); + return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) { return _mm512_srai_epi32(a, N); } @@ -840,7 +893,24 @@ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& expon const Packet8d cst_half = pset1<Packet8d>(0.5); const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull)); exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); + return por(pand(a, cst_inv_mant_mask), cst_half); +} + +template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) { + return pldexp_float(a,exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) { + // Build e=2^n by constructing the exponents in a 256-bit vector and + // shifting them to where they belong in double-precision values. + Packet8i cst_1023 = pset1<Packet8i>(1023); + __m256i emm0 = _mm512_cvtpd_epi32(exponent); + emm0 = _mm256_add_epi32(emm0, cst_1023); + emm0 = _mm256_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i lo = _mm256_slli_epi64(emm0, 52); + __m256i hi = _mm256_slli_epi64(_mm256_srli_epi64(emm0, 32), 52); + __m512d b = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + return pmul(a, b); } #ifdef EIGEN_VECTORIZE_AVX512DQ @@ -1270,22 +1340,6 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, return _mm512_mask_blend_pd(m, elsePacket, thenPacket); } -template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) { - return _mm512_cvttps_epi32(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) { - return _mm512_cvtepi32_ps(a); -} - -template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i,Packet16f>(const Packet16f& a) { - return _mm512_castps_si512(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const Packet16i& a) { - return _mm512_castsi512_ps(a); -} - // Packet math for Eigen::half template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) { return _mm256_set1_epi16(from.x); @@ -1398,6 +1452,29 @@ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { return ptrue(Packet8i(a)); } +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmin<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmax<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) { + return float2half(plset<Packet16f>(static_cast<float>(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { // in some cases Packet8i is a wrapper around __m256i, so we need to // cast to Packet8i to call the correct overload. @@ -1417,12 +1494,42 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) { + return float2half(pround<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) { + return float2half(print<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) { + return float2half(pceil<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) { + return float2half(pfloor<Packet16f>(half2float(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { Packet16f af = half2float(a); Packet16f bf = half2float(b); return Pack32To16(pcmp_eq(af, bf)); } +template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; } + template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); return _mm256_xor_si256(a, sign_mask); @@ -1461,6 +1568,25 @@ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) { return half(predux(from_float)); } +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) { + Packet8h lane0 = _mm256_extractf128_si256(a, 0); + Packet8h lane1 = _mm256_extractf128_si256(a, 1); + return padd<Packet8h>(lane0, lane1); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_max<Packet16f>(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_min<Packet16f>(af); + return Eigen::half(reduced); +} + template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) { Packet16f from_float = half2float(from); return half(predux_mul(from_float)); @@ -1487,22 +1613,22 @@ template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Pa { EIGEN_ALIGN64 half aux[16]; pstore(aux, from); - to[stride*0].x = aux[0].x; - to[stride*1].x = aux[1].x; - to[stride*2].x = aux[2].x; - to[stride*3].x = aux[3].x; - to[stride*4].x = aux[4].x; - to[stride*5].x = aux[5].x; - to[stride*6].x = aux[6].x; - to[stride*7].x = aux[7].x; - to[stride*8].x = aux[8].x; - to[stride*9].x = aux[9].x; - to[stride*10].x = aux[10].x; - to[stride*11].x = aux[11].x; - to[stride*12].x = aux[12].x; - to[stride*13].x = aux[13].x; - to[stride*14].x = aux[14].x; - to[stride*15].x = aux[15].x; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void @@ -1694,7 +1820,7 @@ struct packet_traits<bfloat16> : default_packet_traits { HasCos = EIGEN_FAST_MATH, #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) #ifdef EIGEN_VECTORIZE_AVX512DQ - HasLog = 1, + HasLog = 1, // Currently fails test with bad accuracy. HasLog1p = 1, HasExpm1 = 1, HasNdtri = 1, @@ -1859,6 +1985,23 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) +{ + return F32ToBf16(pround<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(print<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a))); +} + template <> EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) { @@ -1885,9 +2028,7 @@ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, template <> EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { - Packet16bf sign_mask; - sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); - Packet16bf result; + Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); return _mm256_xor_si256(a, sign_mask); } @@ -1898,7 +2039,8 @@ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { template <> EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { - return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a))); + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); } template <> @@ -1997,22 +2139,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, Index stride) { EIGEN_ALIGN64 bfloat16 aux[16]; pstore(aux, from); - to[stride*0].value = aux[0].value; - to[stride*1].value = aux[1].value; - to[stride*2].value = aux[2].value; - to[stride*3].value = aux[3].value; - to[stride*4].value = aux[4].value; - to[stride*5].value = aux[5].value; - to[stride*6].value = aux[6].value; - to[stride*7].value = aux[7].value; - to[stride*8].value = aux[8].value; - to[stride*9].value = aux[9].value; - to[stride*10].value = aux[10].value; - to[stride*11].value = aux[11].value; - to[stride*12].value = aux[12].value; - to[stride*13].value = aux[13].value; - to[stride*14].value = aux[14].value; - to[stride*15].value = aux[15].value; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) { diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index e643b18a7..330412729 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -14,6 +14,22 @@ namespace Eigen { namespace internal { +template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) { + return _mm512_cvttps_epi32(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) { + return _mm512_cvtepi32_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) { + return _mm512_castps_si512(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) { + return _mm512_castsi512_ps(a); +} + template <> struct type_casting_traits<half, float> { enum { diff --git a/test/packetmath.cpp b/test/packetmath.cpp index d52f997dc..ae0ead820 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -618,7 +618,10 @@ void packetmath_real() { test::packet_helper<PacketTraits::HasLog, Packet> h; h.store(data2, internal::plog(h.load(data1))); VERIFY((numext::isnan)(data2[0])); - VERIFY_IS_APPROX(std::log(std::numeric_limits<Scalar>::epsilon()), data2[1]); + // TODO(cantonios): Re-enable for bfloat16. + if (!internal::is_same<Scalar, bfloat16>::value) { + VERIFY_IS_APPROX(std::log(data1[1]), data2[1]); + } data1[0] = -std::numeric_limits<Scalar>::epsilon(); data1[1] = Scalar(0); @@ -629,7 +632,10 @@ void packetmath_real() { data1[0] = (std::numeric_limits<Scalar>::min)(); data1[1] = -(std::numeric_limits<Scalar>::min)(); h.store(data2, internal::plog(h.load(data1))); - VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]); + // TODO(cantonios): Re-enable for bfloat16. + if (!internal::is_same<Scalar, bfloat16>::value) { + VERIFY_IS_APPROX(std::log((std::numeric_limits<Scalar>::min)()), data2[0]); + } VERIFY((numext::isnan)(data2[1])); // Note: 32-bit arm always flushes denorms to zero. @@ -731,54 +737,6 @@ void packetmath_real() { VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \ } -template <> -void packetmath_real<bfloat16, typename internal::packet_traits<bfloat16>::type>(){ - typedef internal::packet_traits<bfloat16> PacketTraits; - typedef internal::packet_traits<bfloat16>::type Packet; - - const int PacketSize = internal::unpacket_traits<Packet>::size; - const int size = PacketSize * 4; - EIGEN_ALIGN_MAX bfloat16 data1[PacketSize * 4]; - EIGEN_ALIGN_MAX bfloat16 data2[PacketSize * 4]; - EIGEN_ALIGN_MAX bfloat16 ref[PacketSize * 4]; - - for (int i = 0; i < size; ++i) { - data1[i] = bfloat16(internal::random<float>(0, 1) * std::pow(float(10), internal::random<float>(-6, 6))); - data2[i] = bfloat16(internal::random<float>(0, 1) * std::pow(float(10), internal::random<float>(-6, 6))); - data1[i] = bfloat16(0); - } - - if (internal::random<float>(0, 1) < 0.1f) data1[internal::random<int>(0, PacketSize)] = bfloat16(0); - - CAST_CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog, bfloat16, float); - CAST_CHECK_CWISE1_IF(PacketTraits::HasRsqrt, float(1) / std::sqrt, internal::prsqrt, bfloat16, float); - - for (int i = 0; i < size; ++i) { - data1[i] = bfloat16(internal::random<float>(-1, 1) * std::pow(float(10), internal::random<float>(-3, 3))); - data2[i] = bfloat16(internal::random<float>(-1, 1) * std::pow(float(10), internal::random<float>(-3, 3))); - } - CAST_CHECK_CWISE1_IF(PacketTraits::HasSin, std::sin, internal::psin, bfloat16, float); - CAST_CHECK_CWISE1_IF(PacketTraits::HasCos, std::cos, internal::pcos, bfloat16, float); - CAST_CHECK_CWISE1_IF(PacketTraits::HasTan, std::tan, internal::ptan, bfloat16, float); - - CAST_CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround, bfloat16, float); - CAST_CHECK_CWISE1_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil, bfloat16, float); - CAST_CHECK_CWISE1_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor, bfloat16, float); - - for (int i = 0; i < size; ++i) { - data1[i] = bfloat16(-1.5 + i); - data2[i] = bfloat16(-1.5 + i); - } - CAST_CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround, bfloat16, float); - - for (int i = 0; i < size; ++i) { - data1[i] = bfloat16(internal::random<float>(-87, 88)); - data2[i] = bfloat16(internal::random<float>(-87, 88)); - } - CAST_CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp, bfloat16, float); - -} - template <typename Scalar> Scalar propagate_nan_max(const Scalar& a, const Scalar& b) { if ((numext::isnan)(a)) return a; @@ -793,6 +751,20 @@ Scalar propagate_nan_min(const Scalar& a, const Scalar& b) { return (numext::mini)(a,b); } +template <typename Scalar> +Scalar propagate_number_max(const Scalar& a, const Scalar& b) { + if ((numext::isnan)(a)) return b; + if ((numext::isnan)(b)) return a; + return (numext::maxi)(a,b); +} + +template <typename Scalar> +Scalar propagate_number_min(const Scalar& a, const Scalar& b) { + if ((numext::isnan)(a)) return b; + if ((numext::isnan)(b)) return a; + return (numext::mini)(a,b); +} + template <typename Scalar, typename Packet> void packetmath_notcomplex() { typedef internal::packet_traits<Scalar> PacketTraits; @@ -809,15 +781,9 @@ void packetmath_notcomplex() { CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin); CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax); -#if EIGEN_HAS_CXX11_MATH - using std::fmin; - using std::fmax; -#else - using ::fmin; - using ::fmax; -#endif - CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>)); - CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>); + + CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_number_min, internal::pmin<PropagateNumbers>); + CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_number_max, internal::pmax<PropagateNumbers>); CHECK_CWISE1(numext::abs, internal::pabs); CHECK_CWISE2_IF(PacketTraits::HasAbsDiff, REF_ABS_DIFF, internal::pabsdiff); @@ -890,54 +856,13 @@ void packetmath_notcomplex() { data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0); } // Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here. - CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, (internal::pmin<PropagateNumbers>)); - CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pmax<PropagateNumbers>); + CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_number_min, (internal::pmin<PropagateNumbers>)); + CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_number_max, internal::pmax<PropagateNumbers>); CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, (internal::pmin<PropagateNaN>)); CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pmax<PropagateNaN>); } } -template <> -void packetmath_notcomplex<bfloat16, typename internal::packet_traits<bfloat16>::type>(){ - typedef bfloat16 Scalar; - typedef internal::packet_traits<bfloat16>::type Packet; - typedef internal::packet_traits<Scalar> PacketTraits; - const int PacketSize = internal::unpacket_traits<Packet>::size; - - EIGEN_ALIGN_MAX Scalar data1[PacketSize * 4]; - EIGEN_ALIGN_MAX Scalar data2[PacketSize * 4]; - EIGEN_ALIGN_MAX Scalar ref[PacketSize * 4]; - Array<Scalar, Dynamic, 1>::Map(data1, PacketSize * 4).setRandom(); - - ref[0] = data1[0]; - for (int i = 0; i < PacketSize; ++i) ref[0] = (std::min)(ref[0], data1[i]); - VERIFY(internal::isApprox(ref[0], internal::predux_min(internal::pload<Packet>(data1))) && "internal::predux_min"); - - VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMin); - VERIFY((!PacketTraits::Vectorizable) || PacketTraits::HasMax); - - CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin); - CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax); - CHECK_CWISE1(numext::abs, internal::pabs); - CHECK_CWISE2_IF(PacketTraits::HasAbsDiff, REF_ABS_DIFF, internal::pabsdiff); - - ref[0] = data1[0]; - for (int i = 0; i < PacketSize; ++i) ref[0] = (std::max)(ref[0], data1[i]); - VERIFY(internal::isApprox(ref[0], internal::predux_max(internal::pload<Packet>(data1))) && "internal::predux_max"); - - { - unsigned char* data1_bits = reinterpret_cast<unsigned char*>(data1); - // predux_any - for (unsigned int i = 0; i < PacketSize * sizeof(Scalar); ++i) data1_bits[i] = 0x0; - VERIFY((!internal::predux_any(internal::pload<Packet>(data1))) && "internal::predux_any(0000)"); - for (int k = 0; k < PacketSize; ++k) { - for (unsigned int i = 0; i < sizeof(Scalar); ++i) data1_bits[k * sizeof(Scalar) + i] = 0xff; - VERIFY(internal::predux_any(internal::pload<Packet>(data1)) && "internal::predux_any(0101)"); - for (unsigned int i = 0; i < sizeof(Scalar); ++i) data1_bits[k * sizeof(Scalar) + i] = 0x00; - } - } -} - template <typename Scalar, typename Packet, bool ConjLhs, bool ConjRhs> void test_conj_helper(Scalar* data1, Scalar* data2, Scalar* ref, Scalar* pval) { const int PacketSize = internal::unpacket_traits<Packet>::size; |