From 89f90b585d24b3c07946b4ffd8064e66ad5af94a Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 24 Nov 2020 16:28:07 -0800 Subject: AVX512 missing ops. This allows the `packetmath` tests to pass for AVX512 on skylake. Made `half` and `bfloat16` consistent in terms of ops they support. Note the `log` tests are currently disabled for `bfloat16` since they fail due to poor precision (they were previously disabled for `Packet8bf` via test function specialization -- I just removed that specialization and disabled it in the generic test). --- Eigen/src/Core/arch/AVX512/PacketMath.h | 280 ++++++++++++++++++++++++-------- 1 file changed, 211 insertions(+), 69 deletions(-) (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h') diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index bf7f0db4f..9acec3439 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -58,23 +58,35 @@ struct packet_traits : default_packet_traits { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, + + HasCmp = 1, HasAdd = 1, HasSub = 1, HasMul = 1, HasDiv = 1, HasNegate = 1, - HasAbs = 0, + HasAbs = 1, HasAbs2 = 0, - HasMin = 0, - HasMax = 0, - HasConj = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasSetLinear = 0, - HasSqrt = 0, - HasRsqrt = 0, - HasExp = 0, - HasLog = 0, - HasBlend = 0 + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -87,6 +99,11 @@ template<> struct packet_traits : default_packet_traits AlignedOnScalar = 1, size = 16, HasHalfPacket = 1, + + HasAbs = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasBlend = 0, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, @@ -105,7 +122,11 @@ template<> struct packet_traits : default_packet_traits HasErf = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; template<> struct packet_traits : default_packet_traits @@ -125,7 +146,11 @@ template<> struct packet_traits : default_packet_traits HasRsqrt = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -165,7 +190,7 @@ struct unpacket_traits { template<> struct unpacket_traits { typedef Eigen::half type; - typedef Packet16h half; + typedef Packet8h half; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; @@ -188,10 +213,14 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits(unsigned int from) { } template <> -EIGEN_STRONG_INLINE Packet8d pset1frombits(uint64_t from) { +EIGEN_STRONG_INLINE Packet8d pset1frombits(const numext::uint64_t from) { return _mm512_castsi512_pd(_mm512_set1_epi64(from)); } +template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } + template <> EIGEN_STRONG_INLINE Packet16f pload1(const float* from) { return _mm512_broadcastss_ps(_mm_load_ps1(from)); @@ -281,7 +310,7 @@ EIGEN_STRONG_INLINE Packet8d pmul(const Packet8d& a, template <> EIGEN_STRONG_INLINE Packet16i pmul(const Packet16i& a, const Packet16i& b) { - return _mm512_mul_epi32(a, b); + return _mm512_mullo_epi32(a, b); } template <> @@ -482,6 +511,15 @@ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); } +template<> EIGEN_STRONG_INLINE Packet16f print(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet8d print(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet16f pceil(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pceil(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); } + +template<> EIGEN_STRONG_INLINE Packet16f pfloor(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pfloor(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); } + template <> EIGEN_STRONG_INLINE Packet16i ptrue(const Packet16i& /*a*/) { return _mm512_set1_epi32(0xffffffffu); @@ -598,6 +636,21 @@ EIGEN_STRONG_INLINE Packet8d pandnot(const Packet8d& a,const Packet8d& #endif } +template<> EIGEN_STRONG_INLINE Packet16f pround(const Packet16f& a) +{ + // Work-around for default std::round rounding mode. + const Packet16f mask = pset1frombits(static_cast(0x80000000u)); + const Packet16f prev0dot5 = pset1frombits(static_cast(0x3EFFFFFFu)); + return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet8d pround(const Packet8d& a) +{ + // Work-around for default std::round rounding mode. + const Packet8d mask = pset1frombits(static_cast(0x8000000000000000ull)); + const Packet8d prev0dot5 = pset1frombits(static_cast(0x3FDFFFFFFFFFFFFFull)); + return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + template EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) { return _mm512_srai_epi32(a, N); } @@ -840,7 +893,24 @@ EIGEN_STRONG_INLINE Packet8d pfrexp(const Packet8d& a, Packet8d& expon const Packet8d cst_half = pset1(0.5); const Packet8d cst_inv_mant_mask = pset1frombits(static_cast(~0x7ff0000000000000ull)); exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); + return por(pand(a, cst_inv_mant_mask), cst_half); +} + +template<> EIGEN_STRONG_INLINE Packet16f pldexp(const Packet16f& a, const Packet16f& exponent) { + return pldexp_float(a,exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, const Packet8d& exponent) { + // Build e=2^n by constructing the exponents in a 256-bit vector and + // shifting them to where they belong in double-precision values. + Packet8i cst_1023 = pset1(1023); + __m256i emm0 = _mm512_cvtpd_epi32(exponent); + emm0 = _mm256_add_epi32(emm0, cst_1023); + emm0 = _mm256_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i lo = _mm256_slli_epi64(emm0, 52); + __m256i hi = _mm256_slli_epi64(_mm256_srli_epi64(emm0, 32), 52); + __m512d b = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + return pmul(a, b); } #ifdef EIGEN_VECTORIZE_AVX512DQ @@ -1270,22 +1340,6 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, return _mm512_mask_blend_pd(m, elsePacket, thenPacket); } -template<> EIGEN_STRONG_INLINE Packet16i pcast(const Packet16f& a) { - return _mm512_cvttps_epi32(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16i& a) { - return _mm512_cvtepi32_ps(a); -} - -template<> EIGEN_STRONG_INLINE Packet16i preinterpret(const Packet16f& a) { - return _mm512_castps_si512(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16i& a) { - return _mm512_castsi512_ps(a); -} - // Packet math for Eigen::half template<> EIGEN_STRONG_INLINE Packet16h pset1(const Eigen::half& from) { return _mm256_set1_epi16(from.x); @@ -1398,6 +1452,29 @@ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { return ptrue(Packet8i(a)); } +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmin(const Packet16h& a, + const Packet16h& b) { + return float2half(pmin(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmax(const Packet16h& a, + const Packet16h& b) { + return float2half(pmax(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h plset(const half& a) { + return float2half(plset(static_cast(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { // in some cases Packet8i is a wrapper around __m256i, so we need to // cast to Packet8i to call the correct overload. @@ -1417,12 +1494,42 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16h pround(const Packet16h& a) { + return float2half(pround(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h print(const Packet16h& a) { + return float2half(print(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pceil(const Packet16h& a) { + return float2half(pceil(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pfloor(const Packet16h& a) { + return float2half(pfloor(half2float(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { Packet16f af = half2float(a); Packet16f bf = half2float(b); return Pack32To16(pcmp_eq(af, bf)); } +template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; } + template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { Packet16h sign_mask = _mm256_set1_epi16(static_cast(0x8000)); return _mm256_xor_si256(a, sign_mask); @@ -1461,6 +1568,25 @@ template<> EIGEN_STRONG_INLINE half predux(const Packet16h& from) { return half(predux(from_float)); } +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4(const Packet16h& a) { + Packet8h lane0 = _mm256_extractf128_si256(a, 0); + Packet8h lane1 = _mm256_extractf128_si256(a, 1); + return padd(lane0, lane1); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_max(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_min(af); + return Eigen::half(reduced); +} + template<> EIGEN_STRONG_INLINE half predux_mul(const Packet16h& from) { Packet16f from_float = half2float(from); return half(predux_mul(from_float)); @@ -1487,22 +1613,22 @@ template<> EIGEN_STRONG_INLINE void pscatter(half* to, const Pa { EIGEN_ALIGN64 half aux[16]; pstore(aux, from); - to[stride*0].x = aux[0].x; - to[stride*1].x = aux[1].x; - to[stride*2].x = aux[2].x; - to[stride*3].x = aux[3].x; - to[stride*4].x = aux[4].x; - to[stride*5].x = aux[5].x; - to[stride*6].x = aux[6].x; - to[stride*7].x = aux[7].x; - to[stride*8].x = aux[8].x; - to[stride*9].x = aux[9].x; - to[stride*10].x = aux[10].x; - to[stride*11].x = aux[11].x; - to[stride*12].x = aux[12].x; - to[stride*13].x = aux[13].x; - to[stride*14].x = aux[14].x; - to[stride*15].x = aux[15].x; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void @@ -1694,7 +1820,7 @@ struct packet_traits : default_packet_traits { HasCos = EIGEN_FAST_MATH, #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) #ifdef EIGEN_VECTORIZE_AVX512DQ - HasLog = 1, + HasLog = 1, // Currently fails test with bad accuracy. HasLog1p = 1, HasExpm1 = 1, HasNdtri = 1, @@ -1859,6 +1985,23 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16bf pround(const Packet16bf& a) +{ + return F32ToBf16(pround(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf print(const Packet16bf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pceil(const Packet16bf& a) { + return F32ToBf16(pceil(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pfloor(const Packet16bf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + template <> EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) { @@ -1885,9 +2028,7 @@ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, template <> EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { - Packet16bf sign_mask; - sign_mask = _mm256_set1_epi16(static_cast(0x8000)); - Packet16bf result; + Packet16bf sign_mask = _mm256_set1_epi16(static_cast(0x8000)); return _mm256_xor_si256(a, sign_mask); } @@ -1898,7 +2039,8 @@ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { template <> EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { - return F32ToBf16(pabs(Bf16ToF32(a))); + const __m256i sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_andnot_si256(sign_mask, a); } template <> @@ -1997,22 +2139,22 @@ EIGEN_STRONG_INLINE void pscatter(bfloat16* to, Index stride) { EIGEN_ALIGN64 bfloat16 aux[16]; pstore(aux, from); - to[stride*0].value = aux[0].value; - to[stride*1].value = aux[1].value; - to[stride*2].value = aux[2].value; - to[stride*3].value = aux[3].value; - to[stride*4].value = aux[4].value; - to[stride*5].value = aux[5].value; - to[stride*6].value = aux[6].value; - to[stride*7].value = aux[7].value; - to[stride*8].value = aux[8].value; - to[stride*9].value = aux[9].value; - to[stride*10].value = aux[10].value; - to[stride*11].value = aux[11].value; - to[stride*12].value = aux[12].value; - to[stride*13].value = aux[13].value; - to[stride*14].value = aux[14].value; - to[stride*15].value = aux[15].value; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { -- cgit v1.2.3