diff options
author | Antonio Sanchez <cantonios@google.com> | 2020-11-24 16:28:07 -0800 |
---|---|---|
committer | Antonio Sánchez <cantonios@google.com> | 2020-11-30 16:28:57 +0000 |
commit | 89f90b585d24b3c07946b4ffd8064e66ad5af94a (patch) | |
tree | c29344e3c03752faaaf2f8eee847811091688262 /Eigen/src/Core/arch/AVX512/PacketMath.h | |
parent | c5985c46f5de0a7a381262c5a8a973806db92f40 (diff) |
AVX512 missing ops.
This allows the `packetmath` tests to pass for AVX512 on skylake.
Made `half` and `bfloat16` consistent in terms of ops they support.
Note the `log` tests are currently disabled for `bfloat16` since
they fail due to poor precision (they were previously disabled for
`Packet8bf` via test function specialization -- I just removed that
specialization and disabled it in the generic test).
Diffstat (limited to 'Eigen/src/Core/arch/AVX512/PacketMath.h')
-rw-r--r-- | Eigen/src/Core/arch/AVX512/PacketMath.h | 280 |
1 files changed, 211 insertions, 69 deletions
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index bf7f0db4f..9acec3439 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -58,23 +58,35 @@ struct packet_traits<half> : default_packet_traits { Vectorizable = 1, AlignedOnScalar = 1, size = 16, - HasHalfPacket = 0, + HasHalfPacket = 1, + + HasCmp = 1, HasAdd = 1, HasSub = 1, HasMul = 1, HasDiv = 1, HasNegate = 1, - HasAbs = 0, + HasAbs = 1, HasAbs2 = 0, - HasMin = 0, - HasMax = 0, - HasConj = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasSetLinear = 0, - HasSqrt = 0, - HasRsqrt = 0, - HasExp = 0, - HasLog = 0, - HasBlend = 0 + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -87,6 +99,11 @@ template<> struct packet_traits<float> : default_packet_traits AlignedOnScalar = 1, size = 16, HasHalfPacket = 1, + + HasAbs = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, HasBlend = 0, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, @@ -105,7 +122,11 @@ template<> struct packet_traits<float> : default_packet_traits HasErf = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; template<> struct packet_traits<double> : default_packet_traits @@ -125,7 +146,11 @@ template<> struct packet_traits<double> : default_packet_traits HasRsqrt = EIGEN_FAST_MATH, #endif HasCmp = 1, - HasDiv = 1 + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 }; }; @@ -165,7 +190,7 @@ struct unpacket_traits<Packet16i> { template<> struct unpacket_traits<Packet16h> { typedef Eigen::half type; - typedef Packet16h half; + typedef Packet8h half; enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; @@ -188,10 +213,14 @@ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) { } template <> -EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(uint64_t from) { +EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) { return _mm512_castsi512_pd(_mm512_set1_epi64(from)); } +template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } + template <> EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) { return _mm512_broadcastss_ps(_mm_load_ps1(from)); @@ -281,7 +310,7 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, template <> EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) { - return _mm512_mul_epi32(a, b); + return _mm512_mullo_epi32(a, b); } template <> @@ -482,6 +511,15 @@ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); } +template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); } + +template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); } + template <> EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) { return _mm512_set1_epi32(0xffffffffu); @@ -598,6 +636,21 @@ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& #endif } +template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) +{ + // Work-around for default std::round rounding mode. + const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u)); + const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu)); + return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) +{ + // Work-around for default std::round rounding mode. + const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull)); + const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull)); + return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) { return _mm512_srai_epi32(a, N); } @@ -840,7 +893,24 @@ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& expon const Packet8d cst_half = pset1<Packet8d>(0.5); const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull)); exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d); - return por(pand(a, cst_inv_mant_mask), cst_half); + return por(pand(a, cst_inv_mant_mask), cst_half); +} + +template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) { + return pldexp_float(a,exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) { + // Build e=2^n by constructing the exponents in a 256-bit vector and + // shifting them to where they belong in double-precision values. + Packet8i cst_1023 = pset1<Packet8i>(1023); + __m256i emm0 = _mm512_cvtpd_epi32(exponent); + emm0 = _mm256_add_epi32(emm0, cst_1023); + emm0 = _mm256_shuffle_epi32(emm0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i lo = _mm256_slli_epi64(emm0, 52); + __m256i hi = _mm256_slli_epi64(_mm256_srli_epi64(emm0, 32), 52); + __m512d b = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + return pmul(a, b); } #ifdef EIGEN_VECTORIZE_AVX512DQ @@ -1270,22 +1340,6 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, return _mm512_mask_blend_pd(m, elsePacket, thenPacket); } -template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) { - return _mm512_cvttps_epi32(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) { - return _mm512_cvtepi32_ps(a); -} - -template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i,Packet16f>(const Packet16f& a) { - return _mm512_castps_si512(a); -} - -template<> EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f,Packet16i>(const Packet16i& a) { - return _mm512_castsi512_ps(a); -} - // Packet math for Eigen::half template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) { return _mm256_set1_epi16(from.x); @@ -1398,6 +1452,29 @@ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { return ptrue(Packet8i(a)); } +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmin<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, + const Packet16h& b) { + return float2half(pmax<Packet16f>(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) { + return float2half(plset<Packet16f>(static_cast<float>(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { // in some cases Packet8i is a wrapper around __m256i, so we need to // cast to Packet8i to call the correct overload. @@ -1417,12 +1494,42 @@ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Pa return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) { + return float2half(pround<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) { + return float2half(print<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) { + return float2half(pceil<Packet16f>(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) { + return float2half(pfloor<Packet16f>(half2float(a))); +} + template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { Packet16f af = half2float(a); Packet16f bf = half2float(b); return Pack32To16(pcmp_eq(af, bf)); } +template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; } + template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); return _mm256_xor_si256(a, sign_mask); @@ -1461,6 +1568,25 @@ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) { return half(predux(from_float)); } +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) { + Packet8h lane0 = _mm256_extractf128_si256(a, 0); + Packet8h lane1 = _mm256_extractf128_si256(a, 1); + return padd<Packet8h>(lane0, lane1); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_max<Packet16f>(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_min<Packet16f>(af); + return Eigen::half(reduced); +} + template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) { Packet16f from_float = half2float(from); return half(predux_mul(from_float)); @@ -1487,22 +1613,22 @@ template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Pa { EIGEN_ALIGN64 half aux[16]; pstore(aux, from); - to[stride*0].x = aux[0].x; - to[stride*1].x = aux[1].x; - to[stride*2].x = aux[2].x; - to[stride*3].x = aux[3].x; - to[stride*4].x = aux[4].x; - to[stride*5].x = aux[5].x; - to[stride*6].x = aux[6].x; - to[stride*7].x = aux[7].x; - to[stride*8].x = aux[8].x; - to[stride*9].x = aux[9].x; - to[stride*10].x = aux[10].x; - to[stride*11].x = aux[11].x; - to[stride*12].x = aux[12].x; - to[stride*13].x = aux[13].x; - to[stride*14].x = aux[14].x; - to[stride*15].x = aux[15].x; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void @@ -1694,7 +1820,7 @@ struct packet_traits<bfloat16> : default_packet_traits { HasCos = EIGEN_FAST_MATH, #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) #ifdef EIGEN_VECTORIZE_AVX512DQ - HasLog = 1, + HasLog = 1, // Currently fails test with bad accuracy. HasLog1p = 1, HasExpm1 = 1, HasNdtri = 1, @@ -1859,6 +1985,23 @@ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, return _mm256_blendv_epi8(b, a, mask); } +template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) +{ + return F32ToBf16(pround<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(print<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) { + return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a))); +} + template <> EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) { @@ -1885,9 +2028,7 @@ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, template <> EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { - Packet16bf sign_mask; - sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); - Packet16bf result; + Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000)); return _mm256_xor_si256(a, sign_mask); } @@ -1898,7 +2039,8 @@ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { template <> EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { - return F32ToBf16(pabs<Packet16f>(Bf16ToF32(a))); + const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000)); + return _mm256_andnot_si256(sign_mask, a); } template <> @@ -1997,22 +2139,22 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, Index stride) { EIGEN_ALIGN64 bfloat16 aux[16]; pstore(aux, from); - to[stride*0].value = aux[0].value; - to[stride*1].value = aux[1].value; - to[stride*2].value = aux[2].value; - to[stride*3].value = aux[3].value; - to[stride*4].value = aux[4].value; - to[stride*5].value = aux[5].value; - to[stride*6].value = aux[6].value; - to[stride*7].value = aux[7].value; - to[stride*8].value = aux[8].value; - to[stride*9].value = aux[9].value; - to[stride*10].value = aux[10].value; - to[stride*11].value = aux[11].value; - to[stride*12].value = aux[12].value; - to[stride*13].value = aux[13].value; - to[stride*14].value = aux[14].value; - to[stride*15].value = aux[15].value; + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; } EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) { |